Source code for qfa.adapters.tracking_llm

"""LLM port decorator that records every call attempt for usage tracking."""

import logging
import time
from datetime import UTC, datetime
from decimal import Decimal

from sqlalchemy.exc import InterfaceError, OperationalError
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from qfa.domain.errors import MissingCallScopeError
from qfa.domain.models import CallStatus, LLMCallRecord, LLMResponse, T_Response
from qfa.domain.ports import LLMPort, UsageRepositoryPort
from qfa.services.call_context import current_call_context

logger = logging.getLogger(__name__)


class TrackingLLMAdapter(LLMPort):
    """Decorator over an inner ``LLMPort`` that records every call attempt.

    Reads tenant + operation from ``current_call_context``. Persists one
    ``LLMCallRecord`` per attempt (success or failure). Recording errors —
    whether raised while building the record or while persisting it — are
    logged but never raised, so a misbehaving usage pipeline never breaks an
    analysis and never masks the inner adapter's own exception.

    Connection-class transient errors (``OperationalError``,
    ``InterfaceError``) are attempted up to 3 times in total, with an
    exponential wait between attempts that is bounded to the 0.05s–0.5s
    range, so a sustained DB outage adds at most two short waits plus the
    failing connection attempts. Non-transient errors (``IntegrityError``,
    ``ProgrammingError``, etc.) skip the retry path and are logged
    immediately.

    Parameters
    ----------
    inner : LLMPort
        The wrapped LLM adapter.
    usage_repo : UsageRepositoryPort
        Repository used to persist call records.
    """

    def __init__(self, inner: LLMPort, usage_repo: UsageRepositoryPort) -> None:
        self._inner = inner
        self._usage_repo = usage_repo

    async def complete(
        self,
        system_message: str,
        user_message: str,
        tenant_id: str,
        response_model: type[T_Response],
        timeout: float = 20.0,
    ) -> LLMResponse[T_Response]:
        """Run the inner ``complete`` and record the attempt.

        All arguments are forwarded unchanged to the inner adapter. The
        tenant and operation stored on the usage record come from
        ``current_call_context``, not from the ``tenant_id`` argument.

        Parameters
        ----------
        system_message : str
            System prompt forwarded to the inner adapter.
        user_message : str
            User prompt forwarded to the inner adapter.
        tenant_id : str
            Tenant identifier forwarded to the inner adapter.
        response_model : type[T_Response]
            Response model forwarded to the inner adapter.
        timeout : float
            Timeout forwarded to the inner adapter.

        Returns
        -------
        LLMResponse[T_Response]
            The inner adapter's response, returned untouched.

        Raises
        ------
        MissingCallScopeError
            When ``current_call_context`` is unset; indicates a wiring bug
            (the orchestrator forgot to enter ``call_scope``).
        Exception
            Whatever the inner adapter raises is re-raised after the failed
            attempt has been recorded.
        """
        ctx = current_call_context.get()
        if ctx is None:
            raise MissingCallScopeError(
                "TrackingLLMAdapter.complete called outside an active call_scope; "
                "the orchestrator must enter call_scope(...) at each public-method entry."
            )

        # Wall-clock time goes on the record; the monotonic clock measures
        # the duration so system clock adjustments cannot skew it.
        started_at = datetime.now(UTC)
        start_monotonic = time.monotonic()
        try:
            response = await self._inner.complete(
                system_message=system_message,
                user_message=user_message,
                tenant_id=tenant_id,
                response_model=response_model,
                timeout=timeout,
            )
        except Exception as exc:
            duration_ms = int((time.monotonic() - start_monotonic) * 1000)
            # Building the record is guarded too: a failure here must not
            # replace the inner adapter's exception that the caller expects.
            try:
                failure_record = LLMCallRecord(
                    tenant_id=ctx.tenant_id,
                    operation=ctx.operation,
                    timestamp=started_at,
                    call_duration_ms=duration_ms,
                    model="",
                    input_tokens=0,
                    output_tokens=0,
                    cost_usd=Decimal("0"),
                    status=CallStatus.ERROR,
                    error_class=type(exc).__name__,
                )
            except Exception:
                logger.exception(
                    "Failed to build LLM call record for tenant=%s operation=%s",
                    ctx.tenant_id,
                    ctx.operation,
                )
            else:
                await self._record_safely(failure_record)
            # Bare raise re-raises the inner adapter's exception unchanged.
            raise

        duration_ms = int((time.monotonic() - start_monotonic) * 1000)
        # The LLM call already succeeded; nothing in the bookkeeping below
        # is allowed to discard the response.
        try:
            success_record = LLMCallRecord(
                tenant_id=ctx.tenant_id,
                operation=ctx.operation,
                timestamp=started_at,
                call_duration_ms=duration_ms,
                model=response.model,
                input_tokens=response.prompt_tokens,
                output_tokens=response.completion_tokens,
                cost_usd=_to_decimal(response.cost),
                status=CallStatus.OK,
                error_class=None,
            )
        except Exception:
            logger.exception(
                "Failed to build LLM call record for tenant=%s operation=%s",
                ctx.tenant_id,
                ctx.operation,
            )
        else:
            await self._record_safely(success_record)
        return response

    @retry(
        retry=retry_if_exception_type((OperationalError, InterfaceError)),
        wait=wait_exponential(multiplier=0.05, min=0.05, max=0.5),
        stop=stop_after_attempt(3),
        reraise=True,
    )
    async def _record_with_retry(self, record: LLMCallRecord) -> None:
        """Persist with bounded retries on connection-class transient errors.

        Only ``OperationalError`` and ``InterfaceError`` trigger another
        attempt; any other exception propagates on the first failure. With
        ``reraise=True`` the last underlying exception is raised once the
        attempts are exhausted.
        """
        await self._usage_repo.record_call(record)

    async def _record_safely(self, record: LLMCallRecord) -> None:
        """Persist ``record``, logging (never raising) on any failure."""
        try:
            await self._record_with_retry(record)
        except Exception:
            # Deliberate best-effort boundary: usage tracking must not break
            # the analysis that triggered the call.
            logger.exception(
                "Failed to record LLM call for tenant=%s operation=%s",
                record.tenant_id,
                record.operation,
            )
def _to_decimal(cost: float | None) -> Decimal: """Convert a float cost to a non-negative Decimal; coerce NaN/None to 0.""" if cost is None: return Decimal("0") if cost != cost: return Decimal("0") if cost < 0: return Decimal("0") return Decimal(repr(cost))