Source code for qfa.api.app

"""Application factory and composition root."""

import importlib.resources
import logging
import secrets
from collections.abc import AsyncGenerator, Callable
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from typing import Any

import litellm
import yaml
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

import qfa
from qfa.adapters.llm_client import LiteLLMClient
from qfa.adapters.presidio_anonymizer import PresidioAnonymizer
from qfa.api.routes import router
from qfa.api.routes_usage import router as usage_router
from qfa.api.schemas import (
    ApiErrorDetail,
    ApiErrorFieldDetail,
    ApiErrorResponse,
)
from qfa.auth import validate_api_key
from qfa.domain.errors import (
    AnalysisError,
    AnalysisTimeoutError,
    AuthenticationError,
    AuthorizationError,
    FeedbackTooLargeError,
    LLMError,
    UsageRepositoryUnavailableError,
)
from qfa.domain.ports import LLMPort
from qfa.services.orchestrator import Orchestrator
from qfa.settings import AppSettings, LLMSettings
from qfa.utils import setup_logging

logger = logging.getLogger(__name__)


class RequestIdMiddleware:
    """Pure ASGI middleware that assigns a unique request ID to every request.

    Stores ``request_id`` and ``start_utc`` on ``scope["state"]`` and adds an
    ``X-Request-ID`` header to every HTTP response. For HTTP requests it also
    converts an exception that escapes the wrapped application *before* the
    response has started into a 500 JSON error envelope.

    Parameters
    ----------
    app : ASGIApp
        The wrapped ASGI application.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Process an ASGI request.

        Assigns a unique request ID, adds it to the response headers, and
        catches unhandled exceptions on HTTP requests to return a 500 JSON
        response.

        Parameters
        ----------
        scope : Scope
            The ASGI connection scope.
        receive : Receive
            The ASGI receive callable.
        send : Send
            The ASGI send callable.

        Raises
        ------
        Exception
            Re-raised unchanged when the response has already started (the
            status line is on the wire, so it cannot be replaced) or when the
            connection is not plain HTTP (an HTTP response message is not a
            valid message on a WebSocket connection).
        """
        # Lifespan and any other non-request scopes pass through untouched.
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        request_id = "req_" + secrets.token_urlsafe(16)
        state = scope.setdefault("state", {})
        state["request_id"] = request_id
        state["start_utc"] = datetime.now(UTC)

        response_started = False

        async def send_with_request_id(message: Message) -> None:
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
                # Copy the header list before appending so the sequence the
                # application passed in is not mutated.
                headers: list[Any] = list(message.get("headers", []))
                headers.append([b"x-request-id", request_id.encode()])
                message["headers"] = headers
            await send(message)

        try:
            await self.app(scope, receive, send_with_request_id)
        except Exception:
            # A JSON 500 can only be produced for an HTTP request whose
            # response has not started yet; otherwise let the server deal
            # with the failure.
            if response_started or scope["type"] != "http":
                raise
            logger.exception("Unhandled exception for request %s", request_id)
            body = ApiErrorResponse(
                error=ApiErrorDetail(
                    code="internal_error",
                    message="An unexpected error occurred",
                    request_id=request_id,
                )
            )
            response = JSONResponse(status_code=500, content=body.model_dump())
            # Sent through the raw ``send`` (not the wrapper), so the header
            # is set explicitly here.
            response.headers["X-Request-ID"] = request_id
            await response(scope, receive, send)
class RequestLoggingMiddleware:
    """Pure ASGI middleware that emits one log line per HTTP request.

    The line carries method, path, status code, duration, request ID, and
    tenant name (when it can be resolved). API keys and request bodies are
    never logged.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Log method, path, status, duration, request ID, and tenant."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request_state = scope.get("state", {})
        request_id = request_state.get("request_id", "unknown")
        # Fall back to "now" when no start timestamp was stored on the scope.
        started_at = request_state.get("start_utc") or datetime.now(UTC)
        http_method = scope.get("method", "?")
        url_path = scope.get("path", "?")
        status_code: int | None = None

        async def capture_status(message: Message) -> None:
            nonlocal status_code
            if message["type"] == "http.response.start":
                status_code = message.get("status", 0)
            await send(message)

        try:
            await self.app(scope, receive, capture_status)
        finally:
            # Runs on success and on failure; ``status_code`` stays ``None``
            # when no response start message was observed.
            elapsed = datetime.now(UTC) - started_at
            logger.info(
                "%s %s status=%s duration=%.0fms request_id=%s tenant=%s",
                http_method,
                url_path,
                status_code,
                elapsed.total_seconds() * 1000,
                request_id,
                self._resolve_tenant(scope),
            )

    @staticmethod
    def _resolve_tenant(scope: Scope) -> str:
        """Derive the tenant name from the ``Authorization`` header.

        The API key itself is never returned or logged.

        Returns
        -------
        str
            The tenant name; ``"anonymous"`` when there is no bearer token,
            no application object on the scope, or no configured API keys;
            ``"invalid"`` when key validation raises.
        """
        raw_headers: list[tuple[bytes, bytes]] = scope.get("headers", [])
        bearer_token: str | None = None
        for header_name, header_value in raw_headers:
            if header_name.lower() != b"authorization":
                continue
            text = header_value.decode("latin-1", errors="replace")
            if text.lower().startswith("bearer "):
                bearer_token = text[7:]
            # Only the first Authorization header is considered.
            break

        if bearer_token is None:
            return "anonymous"

        application = scope.get("app")
        if application is None:
            return "anonymous"

        configured_keys = getattr(
            getattr(application, "state", None), "api_keys", None
        )
        if not configured_keys:
            return "anonymous"

        try:
            return validate_api_key(bearer_token, configured_keys).name
        except Exception:
            # Best-effort lookup for a log line: never let it fail a request.
            return "invalid"
def _get_request_id(request: Request) -> str:
    """Extract request_id from request state, with a fallback.

    Parameters
    ----------
    request : Request
        The incoming HTTP request.

    Returns
    -------
    str
        The request ID string, or ``"unknown"`` when none is set.
    """
    return getattr(request.state, "request_id", "unknown")


def _error_response(
    request: Request,
    *,
    status_code: int,
    code: str,
    message: str,
    fields: list[ApiErrorFieldDetail] | None = None,
    headers: dict[str, str] | None = None,
) -> JSONResponse:
    """Build the standard JSON error envelope shared by every handler.

    Parameters
    ----------
    request : Request
        The incoming HTTP request (source of the request ID).
    status_code : int
        HTTP status code of the response.
    code : str
        Machine-readable error code.
    message : str
        Human-readable error message.
    fields : list[ApiErrorFieldDetail] | None
        Per-field details; omitted from the model call when ``None`` so the
        model's own default applies.
    headers : dict[str, str] | None
        Extra response headers, if any.

    Returns
    -------
    JSONResponse
        The serialized ``ApiErrorResponse``.
    """
    detail_kwargs: dict[str, Any] = {
        "code": code,
        "message": message,
        "request_id": _get_request_id(request),
    }
    if fields is not None:
        detail_kwargs["fields"] = fields
    body = ApiErrorResponse(error=ApiErrorDetail(**detail_kwargs))
    return JSONResponse(
        status_code=status_code, content=body.model_dump(), headers=headers
    )


async def _handle_authentication_error(
    request: Request, exc: AuthenticationError
) -> JSONResponse:
    """Handle AuthenticationError exceptions.

    Parameters
    ----------
    request : Request
        The incoming HTTP request.
    exc : AuthenticationError
        The authentication error.

    Returns
    -------
    JSONResponse
        A 401 JSON response.
    """
    return _error_response(
        request,
        status_code=401,
        code="authentication_required",
        message=str(exc),
    )


async def _handle_authorization_error(
    request: Request, exc: AuthorizationError
) -> JSONResponse:
    """Handle AuthorizationError exceptions.

    Parameters
    ----------
    request : Request
        The incoming HTTP request.
    exc : AuthorizationError
        The authorization error.

    Returns
    -------
    JSONResponse
        A 403 JSON response.
    """
    return _error_response(
        request, status_code=403, code="forbidden", message=str(exc)
    )


async def _handle_validation_error(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Handle Pydantic RequestValidationError exceptions.

    Parameters
    ----------
    request : Request
        The incoming HTTP request.
    exc : RequestValidationError
        The validation error.

    Returns
    -------
    JSONResponse
        A 422 JSON response with per-field details.
    """
    fields = [
        ApiErrorFieldDetail(
            # Dotted location path, e.g. "body.items.0.text"; "unknown" when
            # the error carries no location.
            field=".".join(str(part) for part in err.get("loc", [])) or "unknown",
            issue=err.get("msg", ""),
        )
        for err in exc.errors()
    ]
    return _error_response(
        request,
        status_code=422,
        code="validation_error",
        message="Request validation failed",
        fields=fields,
    )


async def _handle_feedback_too_large(
    request: Request, exc: FeedbackTooLargeError
) -> JSONResponse:
    """Handle FeedbackTooLargeError exceptions.

    Parameters
    ----------
    request : Request
        The incoming HTTP request.
    exc : FeedbackTooLargeError
        The feedback-too-large error.

    Returns
    -------
    JSONResponse
        A 413 JSON response.
    """
    return _error_response(
        request, status_code=413, code="payload_too_large", message=str(exc)
    )


async def _handle_analysis_timeout(
    request: Request, exc: AnalysisTimeoutError
) -> JSONResponse:
    """Handle AnalysisTimeoutError exceptions.

    Parameters
    ----------
    request : Request
        The incoming HTTP request.
    exc : AnalysisTimeoutError
        The analysis timeout error.

    Returns
    -------
    JSONResponse
        A 504 JSON response.
    """
    return _error_response(
        request, status_code=504, code="analysis_timeout", message=str(exc)
    )


async def _handle_analysis_error(request: Request, exc: AnalysisError) -> JSONResponse:
    """Handle AnalysisError exceptions.

    If the error message contains "injection", returns 422 instead of 502 to
    signal that the input was rejected.

    Parameters
    ----------
    request : Request
        The incoming HTTP request.
    exc : AnalysisError
        The analysis error.

    Returns
    -------
    JSONResponse
        A 502 or 422 JSON response depending on the error cause.
    """
    logger.debug("Analysis error: %s", exc, exc_info=True)
    message = str(exc)
    # The cause is classified by message text (case-insensitive substring).
    if "injection" in message.lower():
        return _error_response(
            request, status_code=422, code="validation_error", message=message
        )
    return _error_response(
        request, status_code=502, code="analysis_unavailable", message=message
    )


async def _handle_llm_error(request: Request, exc: LLMError) -> JSONResponse:
    """Map an LLM provider failure to 502 bad_gateway.

    LLMError signals that an upstream LLM provider call failed in a way the
    orchestrator did not recover from. From the API consumer's perspective
    this is a bad gateway, distinct from a 504 timeout (AnalysisTimeoutError)
    or a 502 analysis failure (AnalysisError).

    Parameters
    ----------
    request : Request
        The incoming HTTP request.
    exc : LLMError
        The LLM provider error.

    Returns
    -------
    JSONResponse
        A 502 JSON response.
    """
    logger.warning("LLM provider error: %s", exc, exc_info=True)
    # NOTE(review): ``str(exc)`` is returned to the client verbatim — confirm
    # LLMError messages never embed provider-internal details.
    return _error_response(
        request, status_code=502, code="llm_error", message=str(exc)
    )


async def _handle_usage_repository_unavailable(
    request: Request, exc: UsageRepositoryUnavailableError
) -> JSONResponse:
    """Map a usage-repository unavailability to 503 with a machine-readable code.

    Distinct from ``usage_tracking_disabled`` (raised by ``get_usage_repo``
    when the feature flag is off): this signals that the feature is on but
    the backing store is transiently unreachable. Consumers can use the code
    to drive retry/backoff decisions instead of treating both as the same
    opaque 503.

    Parameters
    ----------
    request : Request
        The incoming HTTP request.
    exc : UsageRepositoryUnavailableError
        The repository error; logged but not exposed to the client.

    Returns
    -------
    JSONResponse
        A 503 JSON response with a fixed, generic message.
    """
    logger.warning("Usage repository unavailable: %s", exc)
    return _error_response(
        request,
        status_code=503,
        code="usage_backend_unavailable",
        message="Usage backend is temporarily unavailable",
    )


async def _handle_http_exception(request: Request, exc: HTTPException) -> JSONResponse:
    """Wrap HTTPException with the standard error envelope.

    When ``detail`` is a dict with ``code``/``message`` keys, those are
    surfaced. Otherwise the detail string becomes the message and a generic
    ``http_error`` code is used. Headers attached to the exception (for
    example ``WWW-Authenticate`` or ``Retry-After``) are forwarded on the
    response.

    Parameters
    ----------
    request : Request
        The incoming HTTP request.
    exc : HTTPException
        The HTTP exception raised by a route or dependency.

    Returns
    -------
    JSONResponse
        A JSON response carrying ``exc.status_code`` and ``exc.headers``.
    """
    detail = exc.detail
    if isinstance(detail, dict):
        code = str(detail.get("code", "http_error"))
        message = str(detail.get("message", ""))
    else:
        code = "http_error"
        message = str(detail) if detail is not None else ""
    return _error_response(
        request,
        status_code=exc.status_code,
        code=code,
        message=message,
        # Without this, headers set by the raiser would be silently dropped.
        headers=getattr(exc, "headers", None),
    )


async def _handle_unhandled_exception(request: Request, exc: Exception) -> JSONResponse:
    """Handle unexpected exceptions.

    Parameters
    ----------
    request : Request
        The incoming HTTP request.
    exc : Exception
        The unhandled exception.

    Returns
    -------
    JSONResponse
        A 500 JSON response with a generic message (the exception text is
        logged, not returned).
    """
    logger.exception("Unhandled exception: %s", exc)
    return _error_response(
        request,
        status_code=500,
        code="internal_error",
        message="An unexpected error occurred",
    )
def build_llm_client(settings: LLMSettings) -> LiteLLMClient:
    """Construct the LiteLLM-backed client described by ``settings``.

    Parameters
    ----------
    settings : LLMSettings
        The LLM configuration settings.

    Returns
    -------
    LiteLLMClient
        A configured LLM client instance.
    """
    client_kwargs = {
        "model": settings.model,
        # The secret is unwrapped only here, at the point of use.
        "api_key": settings.api_key.get_secret_value(),
        "api_base": settings.api_base,
        "api_version": settings.api_version,
        "chars_per_token": settings.chars_per_token,
        "max_total_tokens": settings.max_total_tokens,
    }
    return LiteLLMClient(**client_kwargs)
def _register_custom_model_prices() -> None:
    """Load custom model pricing from the bundled YAML resource.

    Registers models with LiteLLM so that ``completion_cost()`` works for
    models not in the built-in cost map. Does nothing when the resource is
    empty or has no ``models`` entry.
    """
    prices_path = importlib.resources.files("qfa.resources").joinpath(
        "model_prices.yaml"
    )
    with importlib.resources.as_file(prices_path) as f:
        # Explicit encoding: the platform default is not guaranteed to be
        # UTF-8.
        custom_prices = yaml.safe_load(f.read_text(encoding="utf-8"))
    if custom_prices and custom_prices.get("models"):
        litellm.register_model(custom_prices["models"])
        logger.info(
            "Registered %d custom model price(s) for %s",
            len(custom_prices["models"]),
            list(custom_prices["models"].keys()),
        )


LLMFactory = Callable[[LLMSettings], LLMPort]
"""Factory that builds an ``LLMPort`` from settings.

The default is ``build_llm_client`` (real LiteLLM client). Tests can pass
their own factory to ``create_app`` to inject a fake without monkeypatching.
"""


def _make_lifespan(llm_factory: LLMFactory):
    """Build a FastAPI lifespan context manager that closes over ``llm_factory``.

    FastAPI's ``lifespan=`` parameter accepts a single async context manager
    whose signature is fixed at ``(app: FastAPI) -> ...``. There is no
    built-in way to thread extra construction-time dependencies (like which
    ``LLMPort`` factory to use) through that signature without resorting to
    module-level globals or monkeypatching.

    This factory closes over ``llm_factory`` and returns the resulting
    lifespan, so ``create_app`` can pass a fake factory in tests and the
    lifespan picks it up at startup — wiring the same composition path
    (``llm_factory(settings.llm)`` → optional ``TrackingLLMAdapter`` wrap →
    ``Orchestrator``) regardless of whether the LLM client is real or
    stubbed. Production simply omits the override and gets the default
    ``build_llm_client``.

    Parameters
    ----------
    llm_factory : LLMFactory
        Factory invoked at startup to construct the inner ``LLMPort``.

    Returns
    -------
    Callable[[FastAPI], AsyncContextManager[None]]
        A lifespan suitable for ``FastAPI(lifespan=...)``.
    """

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
        """Compose the application graph at startup; tear it down on shutdown.

        This is the runtime composition root: it loads settings, builds every
        dependency that routes consume, and attaches the results to
        ``app.state`` so request handlers can read them without importing
        modules directly. Doing this in the lifespan (rather than at import
        time) ensures settings/env-vars are read once per process boot and
        the DB engine is created on the running event loop.

        Schema migrations are NOT run from the lifespan. They run as a
        pre-start step in ``entrypoint.sh`` (``python -m qfa.cli.migrate``)
        before this process binds the port, so the app boots against an
        already-current schema.

        Startup order is significant:

        1. Load ``AppSettings`` and configure logging — must happen before
           anything that might log.
        2. Register custom LiteLLM model prices — required for
           ``completion_cost()`` to value any model not in LiteLLM's built-in
           cost map.
        3. Build the base ``LLMPort`` via the closed-over factory.
        4. If ``settings.db.track_usage`` is set: create the async DB engine
           and wrap the base LLM in ``TrackingLLMAdapter`` so every call
           attempt is recorded.
        5. Construct the ``Orchestrator`` over whichever LLM variant was
           chosen above.
        6. Publish ``orchestrator``, ``api_keys``, ``settings``, and
           ``usage_repo`` on ``app.state`` for routes/middleware to read.

        The only resource that needs explicit cleanup is the DB engine's
        connection pool. It is disposed in a ``finally`` block, so it is
        released on normal shutdown, when an exception or cancellation is
        thrown in at the ``yield``, and when startup fails after the engine
        was created.

        Parameters
        ----------
        app : FastAPI
            The application instance whose ``state`` will be populated.
        """
        settings = AppSettings()
        setup_logging(settings.log)
        _register_custom_model_prices()

        anonymizer = PresidioAnonymizer()
        api_keys = settings.auth.api_keys

        engine = None
        usage_repo = None
        try:
            base_llm = llm_factory(settings.llm)
            llm_for_orch: LLMPort = base_llm

            if settings.db.track_usage:
                # Imported lazily so the DB stack is only loaded when usage
                # tracking is enabled.
                from qfa.adapters.db import (
                    SqlAlchemyUsageRepository,
                    create_async_engine_from_settings,
                    create_session_factory,
                )
                from qfa.adapters.tracking_llm import TrackingLLMAdapter

                engine = create_async_engine_from_settings(settings.db)
                session_factory = create_session_factory(engine)
                usage_repo = SqlAlchemyUsageRepository(session_factory)
                llm_for_orch = TrackingLLMAdapter(
                    inner=base_llm, usage_repo=usage_repo
                )
                logger.info("Usage tracking enabled (per-attempt, per-operation)")

            orchestrator = Orchestrator(
                llm=llm_for_orch,
                anonymizer=anonymizer,
                settings=settings.orchestrator,
                llm_timeout_seconds=settings.llm.timeout_seconds,
                max_total_tokens=settings.llm.max_total_tokens,
            )

            app.state.orchestrator = orchestrator
            app.state.api_keys = api_keys
            app.state.settings = settings
            app.state.usage_repo = usage_repo

            yield
        finally:
            if engine is not None:
                await engine.dispose()

    return lifespan
def register_exception_handlers(app: FastAPI) -> None:
    """Register all exception handlers on the application.

    Both FastAPI's ``HTTPException`` and Starlette's base ``HTTPException``
    are mapped to ``_handle_http_exception``. Registering only the FastAPI
    subclass would leave exceptions raised by Starlette itself on the
    framework's default handler, outside the standard error envelope.

    Parameters
    ----------
    app : FastAPI
        The FastAPI application instance.
    """
    # Function-local import: the module does not import Starlette's
    # exception class at top level.
    from starlette.exceptions import HTTPException as StarletteHTTPException

    app.add_exception_handler(AuthorizationError, _handle_authorization_error)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(AuthenticationError, _handle_authentication_error)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(RequestValidationError, _handle_validation_error)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(FeedbackTooLargeError, _handle_feedback_too_large)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(AnalysisTimeoutError, _handle_analysis_timeout)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(AnalysisError, _handle_analysis_error)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(LLMError, _handle_llm_error)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(
        UsageRepositoryUnavailableError,
        _handle_usage_repository_unavailable,  # ty: ignore[invalid-argument-type]
    )
    app.add_exception_handler(HTTPException, _handle_http_exception)  # ty: ignore[invalid-argument-type]
    # The handler only reads ``status_code``, ``detail`` and ``headers``,
    # which the Starlette base class also provides.
    app.add_exception_handler(StarletteHTTPException, _handle_http_exception)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(Exception, _handle_unhandled_exception)
def create_app(*, llm_factory: LLMFactory | None = None) -> FastAPI:
    """Assemble the FastAPI application: lifespan, middleware, routes, handlers.

    Parameters
    ----------
    llm_factory : LLMFactory | None
        Optional override for the LLM-port factory. When ``None``,
        ``build_llm_client`` (the real LiteLLM client) is used. Tests pass a
        fake factory here to inject a stubbed ``LLMPort`` without
        monkeypatching; the lifespan wires it exactly like the real client.

    Returns
    -------
    FastAPI
        The fully configured application instance.
    """
    if llm_factory is None:
        llm_factory = build_llm_client

    application = FastAPI(
        title="Feedback Analysis Backend",
        lifespan=_make_lifespan(llm_factory),
        version=qfa.__version__,
    )

    # Registration order is preserved: logging middleware first, request-ID
    # middleware second.
    for middleware_cls in (RequestLoggingMiddleware, RequestIdMiddleware):
        application.add_middleware(middleware_cls)

    for api_router in (router, usage_router):
        application.include_router(api_router)

    register_exception_handlers(application)
    return application