diff --git a/pyproject.toml b/pyproject.toml index ef64c25..797263f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "servicekit" -version = "0.9.0" +version = "0.10.0" description = "Async SQLAlchemy framework with FastAPI integration - reusable foundation for building data services" readme = "README.md" authors = [{ name = "Morten Hansen", email = "morten@winterop.com" }] diff --git a/src/servicekit/api/service_builder.py b/src/servicekit/api/service_builder.py index d0b7637..0cf8c9a 100644 --- a/src/servicekit/api/service_builder.py +++ b/src/servicekit/api/service_builder.py @@ -2,6 +2,8 @@ from __future__ import annotations +import asyncio +import os import re from contextlib import asynccontextmanager from dataclasses import dataclass @@ -587,6 +589,7 @@ def _build_lifespan(self) -> LifespanFactory: job_options = self._job_options include_logging = self._include_logging registration_options = self._registration_options + health_path = self._health_options.prefix if self._health_options else None info = self.info startup_hooks = list(self._startup_hooks) shutdown_hooks = list(self._shutdown_hooks) @@ -663,70 +666,39 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: for hook in startup_hooks: await hook(app) - # Register with orchestrator if enabled - registration_info = None + # Deferred registration: always wait for app to be ready before registering. + # The task is created now and runs after yield (once uvicorn is serving). + # For fail_on_error=True we await it immediately after yield so that + # exceptions still propagate (the app shuts down on the first request cycle). + registration_task: asyncio.Task[None] | None = None + if registration_options is not None: - from .registration import register_service, start_keepalive - - registration_info = await register_service( - orchestrator_url=registration_options.orchestrator_url, - host=registration_options.host, - port=registration_options.port, - info=info, - orchestrator_url_env=registration_options.orchestrator_url_env, - host_env=registration_options.host_env, - port_env=registration_options.port_env, - max_retries=registration_options.max_retries, - retry_delay=registration_options.retry_delay, - fail_on_error=registration_options.fail_on_error, - timeout=registration_options.timeout, - service_key=registration_options.service_key, - service_key_env=registration_options.service_key_env, + registration_task = asyncio.create_task( + _register_after_ready(registration_options, info, app, health_path), + name="servicekit-deferred-registration", ) - # Start keepalive if registration succeeded and enabled - if registration_info and registration_options.enable_keepalive: - ping_url = registration_info.get("ping_url") - if ping_url: - from .registration import RegistrationConfig - - registration_config = RegistrationConfig( - orchestrator_url=registration_options.orchestrator_url, - host=registration_options.host, - port=registration_options.port, - info=info, - orchestrator_url_env=registration_options.orchestrator_url_env, - host_env=registration_options.host_env, - port_env=registration_options.port_env, - max_retries=registration_options.max_retries, - retry_delay=registration_options.retry_delay, - fail_on_error=False, - timeout=registration_options.timeout, - service_key=registration_options.service_key, - service_key_env=registration_options.service_key_env, - ) - await start_keepalive( - ping_url=ping_url, - interval=registration_options.keepalive_interval, - timeout=registration_options.timeout, - service_key=registration_options.service_key, - service_key_env=registration_options.service_key_env, - registration_config=registration_config, - re_register_grace_period=registration_options.re_register_grace_period, - ) - try: yield finally: - # Stop keepalive and deregister service if enabled + if registration_task is not None: + if not registration_task.done(): + registration_task.cancel() + try: + await registration_task + except asyncio.CancelledError: + pass + + # Read registration info stored by the background task + registration_info: dict[str, Any] | None = getattr(app.state, "registration_info", None) + + # Stop keepalive and deregister service if registration completed if registration_options is not None and registration_info: from .registration import deregister_service, stop_keepalive - # Stop keepalive task if registration_options.enable_keepalive: await stop_keepalive() - # Deregister from orchestrator if registration_options.auto_deregister: service_id = registration_info.get("service_id") orchestrator_url = registration_info.get("orchestrator_url") @@ -830,3 +802,128 @@ async def get_info() -> ServiceInfo: def create(cls, *, info: ServiceInfo, **kwargs: Any) -> FastAPI: """Create and build a FastAPI application in one call.""" return cls(info=info, **kwargs).build() + + +def _resolve_port(options: _RegistrationOptions) -> int: + """Resolve port from options or environment, matching register_service logic.""" + if options.port is not None: + return options.port + port_str = os.getenv(options.port_env) + if port_str: + try: + return int(port_str) + except ValueError: + return 8000 + return 8000 + + +async def _wait_until_ready( + port: int, *, health_path: str | None = None, poll_interval: float = 0.5, timeout: float = 30.0 +) -> bool: + """Poll the local app until it is serving requests.""" + import httpx + + if health_path: + url = f"http://127.0.0.1:{port}{health_path}" + check_kind = "health" + else: + url = f"http://127.0.0.1:{port}/" + check_kind = "tcp" + + deadline = asyncio.get_event_loop().time() + timeout + while asyncio.get_event_loop().time() < deadline: + try: + async with httpx.AsyncClient() as client: + response = await client.get(url, timeout=2.0) + if health_path: + if response.status_code == 200: + return True + else: + # Any response means the server is accepting connections + return True + except Exception: + pass + await asyncio.sleep(poll_interval) + + logger.warning("registration.readiness_timeout", port=port, check=check_kind, timeout=timeout) + return False + + +async def _register_and_start_keepalive(options: _RegistrationOptions, info: BaseModel) -> dict[str, Any] | None: + """Register with orchestrator and start keepalive. Returns registration info.""" + from .registration import RegistrationConfig, register_service, start_keepalive + + registration_info = await register_service( + orchestrator_url=options.orchestrator_url, + host=options.host, + port=options.port, + info=info, + orchestrator_url_env=options.orchestrator_url_env, + host_env=options.host_env, + port_env=options.port_env, + max_retries=options.max_retries, + retry_delay=options.retry_delay, + fail_on_error=options.fail_on_error, + timeout=options.timeout, + service_key=options.service_key, + service_key_env=options.service_key_env, + ) + + if registration_info and options.enable_keepalive: + ping_url = registration_info.get("ping_url") + if ping_url: + registration_config = RegistrationConfig( + orchestrator_url=options.orchestrator_url, + host=options.host, + port=options.port, + info=info, + orchestrator_url_env=options.orchestrator_url_env, + host_env=options.host_env, + port_env=options.port_env, + max_retries=options.max_retries, + retry_delay=options.retry_delay, + fail_on_error=False, + timeout=options.timeout, + service_key=options.service_key, + service_key_env=options.service_key_env, + ) + await start_keepalive( + ping_url=ping_url, + interval=options.keepalive_interval, + timeout=options.timeout, + service_key=options.service_key, + service_key_env=options.service_key_env, + registration_config=registration_config, + re_register_grace_period=options.re_register_grace_period, + ) + + return registration_info + + +async def _register_after_ready( + options: _RegistrationOptions, info: BaseModel, app: FastAPI, health_path: str | None +) -> None: + """Wait for the app to be ready, then register with the orchestrator.""" + port = _resolve_port(options) + ready = await _wait_until_ready(port, health_path=health_path) + if not ready: + logger.error( + "registration.aborted", + port=port, + message="App never became ready, skipping registration", + ) + return + + # Shield registration from cancellation so that if the POST succeeds, + # app.state.registration_info is always written before the task exits. + # This prevents a leaked registration when shutdown cancels the task + # between the successful POST and the state assignment. + shielded = asyncio.ensure_future(_register_and_start_keepalive(options, info)) + try: + registration_info = await asyncio.shield(shielded) + except asyncio.CancelledError: + # Shutdown cancelled us mid-registration. Wait for the shielded + # coroutine to finish so we can still store the result for cleanup. + registration_info = await shielded + if registration_info: + app.state.registration_info = registration_info diff --git a/tests/test_deferred_registration.py b/tests/test_deferred_registration.py new file mode 100644 index 0000000..2716e3f --- /dev/null +++ b/tests/test_deferred_registration.py @@ -0,0 +1,473 @@ +"""Tests for deferred service registration lifecycle.""" + +import asyncio +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI + +from servicekit.api.service_builder import ( + BaseServiceBuilder, + ServiceInfo, + _register_after_ready, + _register_and_start_keepalive, + _RegistrationOptions, + _resolve_port, + _wait_until_ready, +) + + +def _make_options() -> _RegistrationOptions: + """Create _RegistrationOptions with sensible defaults.""" + return _RegistrationOptions( + orchestrator_url="http://orchestrator:9000/services/$register", + host="test-host", + port=9999, + orchestrator_url_env="SERVICEKIT_ORCHESTRATOR_URL", + host_env="SERVICEKIT_HOST", + port_env="SERVICEKIT_PORT", + max_retries=1, + retry_delay=0.0, + fail_on_error=False, + timeout=2.0, + enable_keepalive=False, + keepalive_interval=10.0, + auto_deregister=True, + service_key=None, + service_key_env="SERVICEKIT_REGISTRATION_KEY", + re_register_grace_period=30.0, + ) + + +def _make_info() -> ServiceInfo: + """Create a minimal ServiceInfo.""" + return ServiceInfo(id="test-svc", display_name="Test Service") + + +# --------------------------------------------------------------------------- +# _wait_until_ready +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_wait_until_ready_health_success(): + """Return True when health endpoint responds 200.""" + mock_response = MagicMock(status_code=200) + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + + with patch("httpx.AsyncClient") as mock_cls: + mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_cls.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await _wait_until_ready(9999, health_path="/health", timeout=2.0) + + assert result is True + mock_client.get.assert_called_once() + assert "/health" in str(mock_client.get.call_args) + + +@pytest.mark.asyncio +async def test_wait_until_ready_tcp_fallback_any_status(): + """Return True on any HTTP response when no health_path (TCP mode).""" + mock_response = MagicMock(status_code=404) + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + + with patch("httpx.AsyncClient") as mock_cls: + mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_cls.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await _wait_until_ready(9999, health_path=None, timeout=2.0) + + assert result is True + + +@pytest.mark.asyncio +async def test_wait_until_ready_timeout(): + """Return False when the endpoint never responds within timeout.""" + mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=ConnectionError("refused")) + + with patch("httpx.AsyncClient") as mock_cls: + mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_cls.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await _wait_until_ready(9999, health_path="/health", poll_interval=0.05, timeout=0.15) + + assert result is False + + +@pytest.mark.asyncio +async def test_wait_until_ready_custom_health_path(): + """Use the custom health path, not hardcoded /health.""" + mock_response = MagicMock(status_code=200) + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + + with patch("httpx.AsyncClient") as mock_cls: + mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_cls.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await _wait_until_ready(9999, health_path="/status", timeout=2.0) + + assert result is True + url_called = str(mock_client.get.call_args) + assert "/status" in url_called + assert "/health" not in url_called + + +# --------------------------------------------------------------------------- +# _register_after_ready +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_register_after_ready_skips_when_not_ready(): + """Do not register when readiness check times out.""" + options = _make_options() + info = _make_info() + app = FastAPI() + + with ( + patch( + "servicekit.api.service_builder._wait_until_ready", + new_callable=AsyncMock, + return_value=False, + ), + patch( + "servicekit.api.service_builder._register_and_start_keepalive", + new_callable=AsyncMock, + ) as mock_register, + ): + await _register_after_ready(options, info, app, "/health") + + mock_register.assert_not_called() + assert not hasattr(app.state, "registration_info") + + +@pytest.mark.asyncio +async def test_register_after_ready_registers_when_ready(): + """Register and store info on app.state when readiness succeeds.""" + options = _make_options() + info = _make_info() + app = FastAPI() + reg_info = {"service_id": "svc-1", "orchestrator_url": "http://orch", "ping_url": None} + + with ( + patch( + "servicekit.api.service_builder._wait_until_ready", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "servicekit.api.service_builder._register_and_start_keepalive", + new_callable=AsyncMock, + return_value=reg_info, + ), + ): + await _register_after_ready(options, info, app, "/health") + + assert app.state.registration_info == reg_info + + +@pytest.mark.asyncio +async def test_register_after_ready_stores_state_on_cancellation(): + """Registration info is stored even if task is cancelled mid-registration.""" + options = _make_options() + info = _make_info() + app = FastAPI() + reg_info = {"service_id": "svc-1", "orchestrator_url": "http://orch", "ping_url": None} + + async def slow_register(*_args: object, **_kwargs: object) -> dict[str, str | None]: + """Simulate a registration that takes some time.""" + await asyncio.sleep(0.1) + return reg_info + + with ( + patch( + "servicekit.api.service_builder._wait_until_ready", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "servicekit.api.service_builder._register_and_start_keepalive", + side_effect=slow_register, + ), + ): + task = asyncio.create_task(_register_after_ready(options, info, app, "/health")) + # Let the task get past _wait_until_ready and into _register_and_start_keepalive + await asyncio.sleep(0.01) + task.cancel() + # The shielded inner coroutine should still complete + try: + await task + except asyncio.CancelledError: + pass + + assert app.state.registration_info == reg_info + + +# --------------------------------------------------------------------------- +# _resolve_port +# --------------------------------------------------------------------------- + + +def test_resolve_port_from_options(): + """Use port from options when explicitly set.""" + options = _make_options() + assert _resolve_port(options) == 9999 + + +def test_resolve_port_from_env(monkeypatch: pytest.MonkeyPatch): + """Fall back to environment variable when options.port is None.""" + options = _RegistrationOptions( + orchestrator_url="http://orch:9000/services/$register", + host="h", + port=None, + orchestrator_url_env="SERVICEKIT_ORCHESTRATOR_URL", + host_env="SERVICEKIT_HOST", + port_env="SERVICEKIT_PORT", + max_retries=1, + retry_delay=0.0, + fail_on_error=False, + timeout=2.0, + enable_keepalive=False, + keepalive_interval=10.0, + auto_deregister=True, + service_key=None, + service_key_env="SERVICEKIT_REGISTRATION_KEY", + re_register_grace_period=30.0, + ) + monkeypatch.setenv("SERVICEKIT_PORT", "7777") + assert _resolve_port(options) == 7777 + + +def test_resolve_port_default(monkeypatch: pytest.MonkeyPatch): + """Default to 8000 when port not set anywhere.""" + options = _RegistrationOptions( + orchestrator_url="http://orch:9000/services/$register", + host="h", + port=None, + orchestrator_url_env="SERVICEKIT_ORCHESTRATOR_URL", + host_env="SERVICEKIT_HOST", + port_env="SERVICEKIT_PORT", + max_retries=1, + retry_delay=0.0, + fail_on_error=False, + timeout=2.0, + enable_keepalive=False, + keepalive_interval=10.0, + auto_deregister=True, + service_key=None, + service_key_env="SERVICEKIT_REGISTRATION_KEY", + re_register_grace_period=30.0, + ) + monkeypatch.delenv("SERVICEKIT_PORT", raising=False) + assert _resolve_port(options) == 8000 + + +def test_resolve_port_invalid_env(monkeypatch: pytest.MonkeyPatch): + """Fall back to 8000 when env var is not a valid integer.""" + options = _RegistrationOptions( + orchestrator_url="http://orch:9000/services/$register", + host="h", + port=None, + orchestrator_url_env="SERVICEKIT_ORCHESTRATOR_URL", + host_env="SERVICEKIT_HOST", + port_env="SERVICEKIT_PORT", + max_retries=1, + retry_delay=0.0, + fail_on_error=False, + timeout=2.0, + enable_keepalive=False, + keepalive_interval=10.0, + auto_deregister=True, + service_key=None, + service_key_env="SERVICEKIT_REGISTRATION_KEY", + re_register_grace_period=30.0, + ) + monkeypatch.setenv("SERVICEKIT_PORT", "not-a-number") + assert _resolve_port(options) == 8000 + + +# --------------------------------------------------------------------------- +# _register_and_start_keepalive +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_register_and_start_keepalive_success(): + """Calls register_service and returns registration info.""" + options = _make_options() + info = _make_info() + reg_info: dict[str, Any] = { + "service_id": "svc-1", + "service_url": "http://test-host:9999", + "orchestrator_url": "http://orchestrator:9000/services/$register", + "ttl_seconds": 60, + "ping_url": None, + } + + with patch( + "servicekit.api.registration.register_service", + new_callable=AsyncMock, + return_value=reg_info, + ): + result = await _register_and_start_keepalive(options, info) + + assert result == reg_info + + +@pytest.mark.asyncio +async def test_register_and_start_keepalive_with_keepalive(): + """Starts keepalive when registration returns a ping_url.""" + options = _RegistrationOptions( + orchestrator_url="http://orchestrator:9000/services/$register", + host="test-host", + port=9999, + orchestrator_url_env="SERVICEKIT_ORCHESTRATOR_URL", + host_env="SERVICEKIT_HOST", + port_env="SERVICEKIT_PORT", + max_retries=1, + retry_delay=0.0, + fail_on_error=False, + timeout=2.0, + enable_keepalive=True, + keepalive_interval=10.0, + auto_deregister=True, + service_key=None, + service_key_env="SERVICEKIT_REGISTRATION_KEY", + re_register_grace_period=30.0, + ) + info = _make_info() + reg_info: dict[str, Any] = { + "service_id": "svc-1", + "service_url": "http://test-host:9999", + "orchestrator_url": "http://orchestrator:9000/services/$register", + "ttl_seconds": 60, + "ping_url": "http://orchestrator:9000/services/svc-1/$ping", + } + + with ( + patch( + "servicekit.api.registration.register_service", + new_callable=AsyncMock, + return_value=reg_info, + ), + patch( + "servicekit.api.registration.start_keepalive", + new_callable=AsyncMock, + ) as mock_keepalive, + ): + result = await _register_and_start_keepalive(options, info) + + assert result == reg_info + mock_keepalive.assert_called_once() + + +@pytest.mark.asyncio +async def test_register_and_start_keepalive_failure(): + """Returns None when registration fails.""" + options = _make_options() + info = _make_info() + + with patch( + "servicekit.api.registration.register_service", + new_callable=AsyncMock, + return_value=None, + ): + result = await _register_and_start_keepalive(options, info) + + assert result is None + + +# --------------------------------------------------------------------------- +# Lifespan integration: builder creates deferred registration task +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_lifespan_deferred_registration_and_deregistration(): + """Builder with registration creates a deferred task; shutdown deregisters.""" + reg_info: dict[str, Any] = { + "service_id": "svc-1", + "service_url": "http://test-host:9999", + "orchestrator_url": "http://orchestrator:9000/services/$register", + "ttl_seconds": 60, + "ping_url": None, + } + + builder = ( + BaseServiceBuilder(info=ServiceInfo(id="test-svc", display_name="Test")) + .with_health() + .with_registration( + orchestrator_url="http://orchestrator:9000/services/$register", + host="test-host", + port=9999, + enable_keepalive=False, + ) + ) + app = builder.build() + + with ( + patch( + "servicekit.api.service_builder._wait_until_ready", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "servicekit.api.registration.register_service", + new_callable=AsyncMock, + return_value=reg_info, + ), + patch( + "servicekit.api.registration.deregister_service", + new_callable=AsyncMock, + ) as mock_deregister, + ): + async with app.router.lifespan_context(app): + # Let the background task complete + await asyncio.sleep(0.05) + assert app.state.registration_info == reg_info + + # After lifespan exit, deregister should have been called + mock_deregister.assert_called_once() + + +@pytest.mark.asyncio +async def test_lifespan_no_registration_when_not_ready(): + """Registration is skipped when readiness check fails.""" + builder = ( + BaseServiceBuilder(info=ServiceInfo(id="test-svc", display_name="Test")) + .with_health() + .with_registration( + orchestrator_url="http://orchestrator:9000/services/$register", + host="test-host", + port=9999, + enable_keepalive=False, + ) + ) + app = builder.build() + + with ( + patch( + "servicekit.api.service_builder._wait_until_ready", + new_callable=AsyncMock, + return_value=False, + ), + patch( + "servicekit.api.registration.register_service", + new_callable=AsyncMock, + ) as mock_register, + patch( + "servicekit.api.registration.deregister_service", + new_callable=AsyncMock, + ) as mock_deregister, + ): + async with app.router.lifespan_context(app): + await asyncio.sleep(0.05) + + mock_register.assert_not_called() + mock_deregister.assert_not_called() diff --git a/uv.lock b/uv.lock index 4580a06..ecadcc2 100644 --- a/uv.lock +++ b/uv.lock @@ -1149,7 +1149,7 @@ wheels = [ [[package]] name = "servicekit" -version = "0.9.0" +version = "0.10.0" source = { editable = "." } dependencies = [ { name = "aiosqlite" },