"""Application factory and composition root."""
import importlib.resources
import logging
import secrets
from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from datetime import UTC, datetime
from typing import Any

import litellm
import yaml
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.types import ASGIApp, Message, Receive, Scope, Send

import qfa
from qfa.adapters.llm_client import LiteLLMClient
from qfa.adapters.presidio_anonymizer import PresidioAnonymizer
from qfa.api.routes import router
from qfa.api.routes_usage import router as usage_router
from qfa.api.schemas import (
    ApiErrorDetail,
    ApiErrorFieldDetail,
    ApiErrorResponse,
)
from qfa.auth import validate_api_key
from qfa.domain.errors import (
    AnalysisError,
    AnalysisTimeoutError,
    AuthenticationError,
    AuthorizationError,
    FeedbackTooLargeError,
    LLMError,
    UsageRepositoryUnavailableError,
)
from qfa.domain.ports import LLMPort
from qfa.services.orchestrator import Orchestrator
from qfa.settings import AppSettings, LLMSettings
from qfa.utils import setup_logging
# Module-level logger shared by the middlewares and exception handlers below.
logger = logging.getLogger(__name__)
[docs]
class RequestIdMiddleware:
    """Pure ASGI middleware that assigns a unique request ID to every request.

    Stores ``request_id`` and ``start_utc`` on ``scope["state"]`` for HTTP
    and WebSocket connections, and adds an ``X-Request-ID`` header to every
    HTTP response.

    Parameters
    ----------
    app : ASGIApp
        The wrapped ASGI application.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Process an ASGI connection.

        Assigns a unique request ID, adds it to HTTP response headers, and
        converts unhandled exceptions raised before an HTTP response has
        started into a 500 JSON error envelope. Scopes other than ``http``
        and ``websocket`` are passed through untouched.

        Parameters
        ----------
        scope : Scope
            The ASGI connection scope.
        receive : Receive
            The ASGI receive callable.
        send : Send
            The ASGI send callable.

        Raises
        ------
        Exception
            The original exception is re-raised when the HTTP response has
            already started or the connection is a WebSocket, because no JSON
            error response can be sent in those cases.
        """
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        request_id = "req_" + secrets.token_urlsafe(16)
        scope.setdefault("state", {})
        scope["state"]["request_id"] = request_id
        scope["state"]["start_utc"] = datetime.now(UTC)
        response_started = False

        async def send_with_request_id(message: Message) -> None:
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
                headers: list[Any] = list(message.get("headers", []))
                # ASGI headers are (name, value) byte pairs.
                headers.append((b"x-request-id", request_id.encode()))
                message["headers"] = headers
            await send(message)

        try:
            await self.app(scope, receive, send_with_request_id)
        except Exception:
            # A JSON error envelope can only be sent as a brand-new HTTP
            # response. Once headers are on the wire, or on a WebSocket (whose
            # send channel does not accept ``http.response.*`` messages), let
            # the original exception propagate instead of masking it with a
            # protocol error from the server.
            if response_started or scope["type"] != "http":
                raise
            logger.exception("Unhandled exception for request %s", request_id)
            body = ApiErrorResponse(
                error=ApiErrorDetail(
                    code="internal_error",
                    message="An unexpected error occurred",
                    request_id=request_id,
                )
            )
            response = JSONResponse(status_code=500, content=body.model_dump())
            response.headers["X-Request-ID"] = request_id
            await response(scope, receive, send)
[docs]
class RequestLoggingMiddleware:
    """Pure ASGI middleware that emits one access-log line per HTTP request.

    Each line carries the method, path, response status, duration, request
    ID and tenant name (when it can be determined). API keys and request
    bodies are never written to the log.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app and log the outcome, even if it raises."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # request_id / start_utc are expected to have been stored on the scope
        # state by an outer middleware; fall back gracefully if not.
        state = scope.get("state", {})
        request_id = state.get("request_id", "unknown")
        started_at = state.get("start_utc") or datetime.now(UTC)
        method = scope.get("method", "?")
        path = scope.get("path", "?")
        observed: dict[str, int | None] = {"status": None}

        async def send_and_record_status(message: Message) -> None:
            if message["type"] == "http.response.start":
                observed["status"] = message.get("status", 0)
            await send(message)

        try:
            await self.app(scope, receive, send_and_record_status)
        finally:
            elapsed_ms = (datetime.now(UTC) - started_at).total_seconds() * 1000
            logger.info(
                "%s %s status=%s duration=%.0fms request_id=%s tenant=%s",
                method,
                path,
                observed["status"],
                elapsed_ms,
                request_id,
                self._resolve_tenant(scope),
            )

    @staticmethod
    def _resolve_tenant(scope: Scope) -> str:
        """Best-effort tenant name from the ``Authorization: Bearer`` header.

        Only the first ``authorization`` header is inspected. Returns
        ``"anonymous"`` when there is no bearer token or no configured API
        keys on ``app.state``, and ``"invalid"`` when key validation raises.
        The token itself is never logged.
        """
        token: str | None = None
        for header_name, header_value in scope.get("headers", []):
            if header_name.lower() == b"authorization":
                auth_value = header_value.decode("latin-1", errors="replace")
                if auth_value.lower().startswith("bearer "):
                    token = auth_value[len("bearer ") :]
                break

        if token is None or scope.get("app") is None:
            return "anonymous"

        app_state = getattr(scope["app"], "state", None)
        api_keys = getattr(app_state, "api_keys", None)
        if not api_keys:
            return "anonymous"

        try:
            return validate_api_key(token, api_keys).name
        except Exception:
            return "invalid"
def _get_request_id(request: Request) -> str:
    """Return the request ID stored on ``request.state``.

    Parameters
    ----------
    request : Request
        The incoming HTTP request.

    Returns
    -------
    str
        The stored ``request_id`` attribute, or ``"unknown"`` when the state
        carries no such attribute.
    """
    state = request.state
    try:
        return state.request_id
    except AttributeError:
        return "unknown"
async def _handle_authentication_error(
    request: Request, exc: AuthenticationError
) -> JSONResponse:
    """Translate an ``AuthenticationError`` into a 401 error envelope.

    Parameters
    ----------
    request : Request
        The request that failed authentication (used for its request ID).
    exc : AuthenticationError
        The raised error; its string form becomes the response message.

    Returns
    -------
    JSONResponse
        Status 401 with error code ``authentication_required``.
    """
    detail = ApiErrorDetail(
        code="authentication_required",
        message=str(exc),
        request_id=_get_request_id(request),
    )
    payload = ApiErrorResponse(error=detail).model_dump()
    return JSONResponse(content=payload, status_code=401)
async def _handle_authorization_error(
    request: Request, exc: AuthorizationError
) -> JSONResponse:
    """Translate an ``AuthorizationError`` into a 403 error envelope.

    Parameters
    ----------
    request : Request
        The request that was denied (used for its request ID).
    exc : AuthorizationError
        The raised error; its string form becomes the response message.

    Returns
    -------
    JSONResponse
        Status 403 with error code ``forbidden``.
    """
    detail = ApiErrorDetail(
        code="forbidden",
        message=str(exc),
        request_id=_get_request_id(request),
    )
    payload = ApiErrorResponse(error=detail).model_dump()
    return JSONResponse(content=payload, status_code=403)
async def _handle_validation_error(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Translate a ``RequestValidationError`` into a 422 envelope with per-field issues.

    Parameters
    ----------
    request : Request
        The request whose input failed validation (used for its request ID).
    exc : RequestValidationError
        The validation error; each entry of ``exc.errors()`` becomes one
        field detail.

    Returns
    -------
    JSONResponse
        Status 422 with error code ``validation_error``. Each field is the
        dot-joined ``loc`` path (``"unknown"`` when ``loc`` is empty) and
        each issue is the error's ``msg``.
    """

    def _field_path(error: Any) -> str:
        # Join location parts like ("body", "items", 0) into "body.items.0".
        parts = [str(part) for part in error.get("loc", [])]
        return ".".join(parts) if parts else "unknown"

    field_details = [
        ApiErrorFieldDetail(field=_field_path(error), issue=error.get("msg", ""))
        for error in exc.errors()
    ]
    detail = ApiErrorDetail(
        code="validation_error",
        message="Request validation failed",
        request_id=_get_request_id(request),
        fields=field_details,
    )
    payload = ApiErrorResponse(error=detail).model_dump()
    return JSONResponse(content=payload, status_code=422)
async def _handle_feedback_too_large(
    request: Request, exc: FeedbackTooLargeError
) -> JSONResponse:
    """Translate a ``FeedbackTooLargeError`` into a 413 error envelope.

    Parameters
    ----------
    request : Request
        The request carrying the oversized feedback (used for its request ID).
    exc : FeedbackTooLargeError
        The raised error; its string form becomes the response message.

    Returns
    -------
    JSONResponse
        Status 413 with error code ``payload_too_large``.
    """
    detail = ApiErrorDetail(
        code="payload_too_large",
        message=str(exc),
        request_id=_get_request_id(request),
    )
    payload = ApiErrorResponse(error=detail).model_dump()
    return JSONResponse(content=payload, status_code=413)
async def _handle_analysis_timeout(
    request: Request, exc: AnalysisTimeoutError
) -> JSONResponse:
    """Translate an ``AnalysisTimeoutError`` into a 504 error envelope.

    Parameters
    ----------
    request : Request
        The request whose analysis timed out (used for its request ID).
    exc : AnalysisTimeoutError
        The raised error; its string form becomes the response message.

    Returns
    -------
    JSONResponse
        Status 504 with error code ``analysis_timeout``.
    """
    detail = ApiErrorDetail(
        code="analysis_timeout",
        message=str(exc),
        request_id=_get_request_id(request),
    )
    payload = ApiErrorResponse(error=detail).model_dump()
    return JSONResponse(content=payload, status_code=504)
async def _handle_analysis_error(request: Request, exc: AnalysisError) -> JSONResponse:
    """Translate an ``AnalysisError`` into a 422 or 502 error envelope.

    The status is chosen by inspecting the error text: when it mentions
    "injection" (case-insensitive) the input is treated as rejected and a
    422 ``validation_error`` is returned. Every other analysis failure is
    reported as a 502 ``analysis_unavailable``.

    Parameters
    ----------
    request : Request
        The request whose analysis failed (used for its request ID).
    exc : AnalysisError
        The raised error; its string form becomes the response message.

    Returns
    -------
    JSONResponse
        A 422 or 502 response, as described above.
    """
    logger.debug("Analysis error: %s", exc, exc_info=True)
    message = str(exc)
    # NOTE(review): substring sniffing is fragile; a dedicated error subtype
    # for rejected (injection) input would make this mapping explicit.
    if "injection" in message.lower():
        status_code, error_code = 422, "validation_error"
    else:
        status_code, error_code = 502, "analysis_unavailable"
    detail = ApiErrorDetail(
        code=error_code,
        message=message,
        request_id=_get_request_id(request),
    )
    payload = ApiErrorResponse(error=detail).model_dump()
    return JSONResponse(content=payload, status_code=status_code)
async def _handle_llm_error(request: Request, exc: LLMError) -> JSONResponse:
    """Translate an unrecovered LLM provider failure into a 502 envelope.

    The response uses HTTP 502 (Bad Gateway) with error code ``llm_error``.
    This lets API consumers tell an upstream provider failure apart from a
    504 ``analysis_timeout`` (``AnalysisTimeoutError``) and from a 502
    ``analysis_unavailable`` (``AnalysisError``). The error is logged at
    warning level with its traceback.
    """
    logger.warning("LLM provider error: %s", exc, exc_info=True)
    detail = ApiErrorDetail(
        code="llm_error",
        message=str(exc),
        request_id=_get_request_id(request),
    )
    payload = ApiErrorResponse(error=detail).model_dump()
    return JSONResponse(content=payload, status_code=502)
async def _handle_usage_repository_unavailable(
    request: Request, exc: UsageRepositoryUnavailableError
) -> JSONResponse:
    """Translate a usage-repository outage into a 503 with a machine-readable code.

    The code ``usage_backend_unavailable`` is meant to be distinct from
    ``usage_tracking_disabled``, which is reported elsewhere when the feature
    flag is off. This one means the feature is on but the backing store
    cannot be reached, so consumers can drive retry/backoff decisions from
    the code instead of seeing one opaque 503.

    The client receives a fixed message; the exception text is only logged.
    """
    logger.warning("Usage repository unavailable: %s", exc)
    detail = ApiErrorDetail(
        code="usage_backend_unavailable",
        message="Usage backend is temporarily unavailable",
        request_id=_get_request_id(request),
    )
    payload = ApiErrorResponse(error=detail).model_dump()
    return JSONResponse(content=payload, status_code=503)
async def _handle_http_exception(request: Request, exc: HTTPException) -> JSONResponse:
    """Wrap an HTTP exception in the standard error envelope.

    When ``detail`` is a dict, its ``code`` and ``message`` keys are surfaced
    (defaulting to ``"http_error"`` and ``""``). Otherwise the stringified
    detail becomes the message and the generic ``http_error`` code is used.
    The exception's status code and any headers it carries are preserved.

    Parameters
    ----------
    request : Request
        The incoming HTTP request (used for its request ID).
    exc : HTTPException
        The raised HTTP exception.

    Returns
    -------
    JSONResponse
        A response with ``exc.status_code``, the error envelope as body and
        ``exc.headers`` (if any) as extra headers.
    """
    detail = exc.detail
    if isinstance(detail, dict):
        code = str(detail.get("code", "http_error"))
        message = str(detail.get("message", ""))
    else:
        code = "http_error"
        message = str(detail) if detail is not None else ""
    body = ApiErrorResponse(
        error=ApiErrorDetail(
            code=code,
            message=message,
            request_id=_get_request_id(request),
        )
    )
    # Forward headers attached to the exception (e.g. ``WWW-Authenticate``
    # on a 401, ``Allow`` on a 405, ``Retry-After`` on a 429/503). Dropping
    # them yields responses that break HTTP semantics clients rely on.
    return JSONResponse(
        status_code=exc.status_code,
        content=body.model_dump(),
        headers=exc.headers,
    )
async def _handle_unhandled_exception(request: Request, exc: Exception) -> JSONResponse:
    """Last-resort handler: log the exception and return a generic 500.

    The client only receives a fixed message. The exception and its
    traceback go to the log, never into the response body.

    Parameters
    ----------
    request : Request
        The incoming HTTP request (used for its request ID).
    exc : Exception
        The unhandled exception.

    Returns
    -------
    JSONResponse
        Status 500 with error code ``internal_error``.
    """
    logger.exception("Unhandled exception: %s", exc)
    generic_detail = ApiErrorDetail(
        code="internal_error",
        message="An unexpected error occurred",
        request_id=_get_request_id(request),
    )
    payload = ApiErrorResponse(error=generic_detail).model_dump()
    return JSONResponse(content=payload, status_code=500)
[docs]
def build_llm_client(settings: LLMSettings) -> LiteLLMClient:
    """Construct the LiteLLM-backed client described by ``settings``.

    This is the default ``LLMFactory``. The API key is unwrapped via
    ``get_secret_value()`` only at the point it is handed to the client.

    Parameters
    ----------
    settings : LLMSettings
        Model name, credentials, endpoint and token-budget configuration.

    Returns
    -------
    LiteLLMClient
        A client configured from ``settings``.
    """
    api_key = settings.api_key.get_secret_value()
    return LiteLLMClient(
        model=settings.model,
        api_key=api_key,
        api_base=settings.api_base,
        api_version=settings.api_version,
        chars_per_token=settings.chars_per_token,
        max_total_tokens=settings.max_total_tokens,
    )
def _register_custom_model_prices() -> None:
    """Load custom model pricing from the bundled YAML resource.

    Reads ``model_prices.yaml`` from the ``qfa.resources`` package. When the
    document has a non-empty top-level ``models`` mapping, those entries are
    registered with LiteLLM so that ``completion_cost()`` works for models
    not in the built-in cost map. An empty document, or a missing or empty
    ``models`` key, is a no-op.
    """
    prices_file = importlib.resources.files("qfa.resources").joinpath(
        "model_prices.yaml"
    )
    # Read directly from the Traversable. ``as_file`` is only needed when a
    # real filesystem path is required; for zipped installs it would extract
    # a temporary copy just to read it back. Decode explicitly as UTF-8 so
    # the result does not depend on the platform's locale encoding.
    custom_prices = yaml.safe_load(prices_file.read_text(encoding="utf-8"))
    if not custom_prices or not custom_prices.get("models"):
        return
    models = custom_prices["models"]
    litellm.register_model(models)
    logger.info(
        "Registered %d custom model price(s) for %s",
        len(models),
        list(models.keys()),
    )
#: Factory that builds an ``LLMPort`` from ``LLMSettings``.
#:
#: ``create_app`` defaults to ``build_llm_client`` (the real LiteLLM client);
#: tests can pass their own factory to inject a fake without monkeypatching.
LLMFactory = Callable[[LLMSettings], LLMPort]
def _make_lifespan(
    llm_factory: LLMFactory,
) -> Callable[[FastAPI], AbstractAsyncContextManager[None]]:
    """Build a FastAPI lifespan context manager that closes over ``llm_factory``.

    FastAPI's ``lifespan=`` parameter accepts a single async context
    manager whose signature is fixed at ``(app: FastAPI) -> ...``. There is
    no built-in way to thread extra construction-time dependencies (like
    which ``LLMPort`` factory to use) through that signature without
    resorting to module-level globals or monkeypatching.

    This factory closes over ``llm_factory`` and returns the resulting
    lifespan. That lets ``create_app`` pass a fake factory in tests, and the
    lifespan picks it up at startup. Real and stubbed clients go through the
    same composition path: ``llm_factory(settings.llm)``, then an optional
    ``TrackingLLMAdapter`` wrap, then ``Orchestrator``. Production simply
    omits the override and gets the default ``build_llm_client``.

    Parameters
    ----------
    llm_factory : LLMFactory
        Factory invoked at startup to construct the inner ``LLMPort``.

    Returns
    -------
    Callable[[FastAPI], AbstractAsyncContextManager[None]]
        A lifespan suitable for ``FastAPI(lifespan=...)``.
    """

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
        """Compose the application graph at startup; tear it down on shutdown.

        This is the runtime composition root. It loads settings, builds
        every dependency that routes consume, and attaches the results to
        ``app.state`` so request handlers can read them without importing
        modules directly. Doing this in the lifespan (rather than at import
        time) means settings/env-vars are read once per process boot and the
        DB engine is created on the running event loop.

        Schema migrations are NOT run from the lifespan. They run as a
        pre-start step in ``entrypoint.sh`` (``python -m qfa.cli.migrate``)
        before this process binds the port, so the app boots against an
        already-current schema.

        Startup order is significant:

        1. Load ``AppSettings`` and configure logging. This must happen
           before anything that might log.
        2. Register custom LiteLLM model prices. This is required for
           ``completion_cost()`` to value any model not in LiteLLM's
           built-in cost map.
        3. Build the base ``LLMPort`` via the closed-over factory.
        4. If ``settings.db.track_usage`` is set, create the async DB engine
           and wrap the base LLM in ``TrackingLLMAdapter`` so every call
           attempt is recorded.
        5. Construct the ``Orchestrator`` over whichever LLM variant was
           chosen above.
        6. Publish ``orchestrator``, ``api_keys``, ``settings`` and
           ``usage_repo`` on ``app.state`` for routes/middleware to read.

        The only resource that needs explicit cleanup is the DB engine's
        connection pool. It is disposed on shutdown, and also if startup
        fails after the engine was created. Everything else is plain Python
        objects that the GC handles.

        Parameters
        ----------
        app : FastAPI
            The application instance whose ``state`` will be populated.
        """
        settings = AppSettings()
        setup_logging(settings.log)
        _register_custom_model_prices()

        anonymizer = PresidioAnonymizer()
        api_keys = settings.auth.api_keys
        engine = None
        usage_repo = None

        base_llm = llm_factory(settings.llm)
        llm_for_orch: LLMPort = base_llm

        try:
            if settings.db.track_usage:
                # Imported lazily: the DB adapter stack is only loaded when
                # usage tracking is turned on.
                from qfa.adapters.db import (
                    SqlAlchemyUsageRepository,
                    create_async_engine_from_settings,
                    create_session_factory,
                )
                from qfa.adapters.tracking_llm import TrackingLLMAdapter

                engine = create_async_engine_from_settings(settings.db)
                session_factory = create_session_factory(engine)
                usage_repo = SqlAlchemyUsageRepository(session_factory)
                llm_for_orch = TrackingLLMAdapter(inner=base_llm, usage_repo=usage_repo)
                logger.info("Usage tracking enabled (per-attempt, per-operation)")

            orchestrator = Orchestrator(
                llm=llm_for_orch,
                anonymizer=anonymizer,
                settings=settings.orchestrator,
                llm_timeout_seconds=settings.llm.timeout_seconds,
                max_total_tokens=settings.llm.max_total_tokens,
            )

            app.state.orchestrator = orchestrator
            app.state.api_keys = api_keys
            app.state.settings = settings
            app.state.usage_repo = usage_repo
            yield
        finally:
            # Runs on normal shutdown, on a startup failure after the engine
            # exists, and when an exception is thrown in at ``yield``, so the
            # connection pool is never leaked.
            if engine is not None:
                await engine.dispose()

    return lifespan
[docs]
def register_exception_handlers(app: FastAPI) -> None:
    """Register all exception handlers on the application.

    Starlette selects the handler registered for the closest class in the
    raised exception's MRO. More specific domain errors therefore win over
    their bases, and ``Exception`` acts as the catch-all.

    Parameters
    ----------
    app : FastAPI
        The FastAPI application instance.
    """
    app.add_exception_handler(AuthorizationError, _handle_authorization_error)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(AuthenticationError, _handle_authentication_error)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(RequestValidationError, _handle_validation_error)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(FeedbackTooLargeError, _handle_feedback_too_large)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(AnalysisTimeoutError, _handle_analysis_timeout)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(AnalysisError, _handle_analysis_error)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(LLMError, _handle_llm_error)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(
        UsageRepositoryUnavailableError,
        _handle_usage_repository_unavailable,  # ty: ignore[invalid-argument-type]
    )
    # Register on Starlette's base HTTPException, not FastAPI's subclass.
    # FastAPI's HTTPException still resolves here through its MRO, and errors
    # raised by the router itself (404 unknown path, 405 wrong method) are
    # Starlette HTTPExceptions. A handler keyed on the FastAPI subclass never
    # sees those, so they would fall back to FastAPI's default
    # ``{"detail": ...}`` body instead of the standard error envelope.
    app.add_exception_handler(StarletteHTTPException, _handle_http_exception)  # ty: ignore[invalid-argument-type]
    app.add_exception_handler(Exception, _handle_unhandled_exception)
[docs]
def create_app(*, llm_factory: LLMFactory | None = None) -> FastAPI:
    """Create and configure the FastAPI application.

    Parameters
    ----------
    llm_factory : LLMFactory | None
        Optional override for the LLM-port factory. ``None`` selects
        ``build_llm_client`` (the real LiteLLM client). Tests pass a fake
        factory here to inject a stubbed ``LLMPort`` without monkeypatching.
        The lifespan composes it exactly as it would the real client,
        including the ``TrackingLLMAdapter`` wrap when usage tracking is
        enabled.

    Returns
    -------
    FastAPI
        The application with middleware, routers and exception handlers
        registered. Runtime dependencies are attached later, by the lifespan.
    """
    factory: LLMFactory
    if llm_factory is None:
        factory = build_llm_client
    else:
        factory = llm_factory

    app = FastAPI(
        title="Feedback Analysis Backend",
        lifespan=_make_lifespan(factory),
        version=qfa.__version__,
    )

    # Starlette nests middleware in reverse order of registration, so the
    # last one added is outermost. RequestIdMiddleware therefore runs first
    # and has already stored request_id/start_utc on the scope state by the
    # time RequestLoggingMiddleware reads them.
    for middleware_cls in (RequestLoggingMiddleware, RequestIdMiddleware):
        app.add_middleware(middleware_cls)

    for api_router in (router, usage_router):
        app.include_router(api_router)

    register_exception_handlers(app)
    return app