"""LLM client adapter using LiteLLM for unified provider access."""
import logging
import re
from typing import cast
import openai
from litellm import acompletion, completion_cost
from pydantic import BaseModel, ValidationError
from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_exponential
from qfa.domain import AnalysisError, FeedbackTooLargeError
from qfa.domain.errors import LLMError, LLMRateLimitError, LLMTimeoutError
from qfa.domain.models import LLMResponse, T_Response
from qfa.domain.ports import LLMPort
logger = logging.getLogger(__name__)
class LiteLLMClient(LLMPort):
    """LLM adapter satisfying LLMPort via LiteLLM.

    Routes to any LLM provider based on the model string prefix
    (e.g. ``"azure/gpt-4"``, ``"azure_ai/mistral-large-2411"``).
    Calculates per-call cost using LiteLLM's built-in cost map
    or custom pricing registered via ``litellm.register_model()``.

    Parameters
    ----------
    model : str
        LiteLLM model identifier (e.g. ``"azure_ai/mistral-large-2411"``).
    api_key : str
        API key for the provider.
    api_base : str
        Base URL for the provider endpoint. Empty string if not needed.
    api_version : str
        API version string. Empty string if not needed.
    chars_per_token : int
        Number of characters counted as one token when estimating the
        request size. Must be positive.
    max_total_tokens : int
        Largest estimated token count (system plus user message) that is
        still sent to the provider.

    Raises
    ------
    ValueError
        When ``chars_per_token`` is not positive.
    """

    # Compiled once at class creation rather than on every request.
    # NOTE(review): ``role_prefix`` is anchored with ``^`` and no
    # ``re.MULTILINE``, so it only matches at the very start of the message,
    # not at the start of later lines -- confirm this is intended.
    _INJECTION_PATTERNS: tuple[tuple[str, re.Pattern[str]], ...] = (
        (
            "role_prefix",
            re.compile(r"^\s*(SYSTEM|ASSISTANT|USER)\s*:", re.IGNORECASE),
        ),
        ("null_byte", re.compile(r"\x00")),
        ("repeated_chars", re.compile(r"(.)\1{199,}")),
    )

    def __init__(
        self,
        model: str,
        api_key: str,
        api_base: str,
        api_version: str,
        chars_per_token: int,
        max_total_tokens: int,
    ) -> None:
        # Fail fast: zero would raise ZeroDivisionError on every request and a
        # negative value would make the estimate negative, silently disabling
        # the token limit.
        if chars_per_token <= 0:
            msg = f"chars_per_token must be a positive integer, got {chars_per_token}"
            raise ValueError(msg)
        self._model = model
        self._api_key = api_key
        self._api_base = api_base
        self._api_version = api_version
        self._chars_per_token = chars_per_token
        self._max_total_tokens = max_total_tokens

    def _check_injection(self, user_message: str) -> None:
        """Scan user_message for known prompt injection strings.

        Parameters
        ----------
        user_message : str
            The prompt.

        Raises
        ------
        AnalysisError
            When the message matches an injection pattern.
        """
        for pattern_name, pattern in self._INJECTION_PATTERNS:
            if pattern.search(user_message):
                logger.warning(
                    "Prompt injection detected: pattern=%s",
                    pattern_name,
                )
                msg = f"Prompt injection detected pattern={pattern_name}"
                raise AnalysisError(msg)

    def _check_token_limit(self, system_message: str, user_message: str) -> None:
        """Estimate total tokens and raise if over the limit.

        The estimate is the combined character count of both messages,
        floor-divided by ``chars_per_token``.

        Parameters
        ----------
        system_message : str
            The assembled system message.
        user_message : str
            The assembled user message containing the feedback records.

        Raises
        ------
        FeedbackTooLargeError
            When estimated tokens exceed the configured limit.
        """
        # Summing the lengths avoids building a throwaway concatenated string.
        total_chars = len(system_message) + len(user_message)
        estimated_tokens = total_chars // self._chars_per_token
        if estimated_tokens > self._max_total_tokens:
            msg = (
                f"Estimated tokens ({estimated_tokens}) exceed limit "
                f"({self._max_total_tokens})"
            )
            raise FeedbackTooLargeError(
                msg,
                estimated_tokens=estimated_tokens,
                limit=self._max_total_tokens,
            )

    @retry(
        wait=wait_exponential(multiplier=1, max=10),
        stop=stop_after_delay(60),
        retry=retry_if_exception_type((LLMTimeoutError, LLMRateLimitError)),
        # Without reraise, exhausted retries surface as tenacity.RetryError
        # instead of the LLM errors documented below.
        reraise=True,
    )
    async def complete(
        self,
        system_message: str,
        user_message: str,
        tenant_id: str,
        response_model: type[T_Response],
        timeout: float = 20.0,
    ) -> LLMResponse[T_Response]:
        """Send a completion request via LiteLLM.

        Timeouts and rate-limit errors are retried with exponential backoff
        (waits capped at 10 seconds) until 60 seconds have elapsed; after that
        the last error is re-raised.

        Parameters
        ----------
        system_message : str
            The system-level instruction for the model.
        user_message : str
            The user-level message to complete.
        tenant_id : str
            Tenant identifier passed as ``user`` for audit trail.
        response_model : type
            ``str`` (or a subclass) to receive the raw text, or a pydantic
            ``BaseModel`` subclass to have the response requested in that
            format and validated against it.
        timeout : float
            Maximum time in seconds to wait for a response.

        Returns
        -------
        LLMResponse
            The model's response including token usage and cost. ``cost`` is
            NaN when it cannot be computed.

        Raises
        ------
        AnalysisError
            When ``user_message`` matches a prompt injection pattern.
        FeedbackTooLargeError
            When the estimated token count exceeds the configured limit.
        ValueError
            When ``response_model`` is neither a ``str`` nor a ``BaseModel``
            subclass.
        LLMTimeoutError
            When the provider does not respond in time.
        LLMRateLimitError
            When the provider returns a rate-limit response.
        LLMError
            For any other provider error, an empty response, missing usage
            data, or content that fails validation against ``response_model``.
        """
        self._check_injection(user_message)
        self._check_token_limit(system_message, user_message)
        # Reject unsupported response models before making a billed request.
        if not issubclass(response_model, (BaseModel, str)):
            raise ValueError(
                "The `response_model` is not a string or BaseModel subclass."
            )

        try:
            response = await acompletion(
                model=self._model,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message},
                ],
                api_key=self._api_key,
                # Empty strings mean "not configured"; pass None instead.
                api_base=self._api_base or None,
                api_version=self._api_version or None,
                user=tenant_id,
                timeout=timeout,
                response_format=response_model
                if issubclass(response_model, BaseModel)
                else None,
            )
        except openai.APITimeoutError as exc:
            logger.error("LLM request timed out: %s", exc)
            raise LLMTimeoutError(str(exc)) from exc
        except openai.RateLimitError as exc:
            logger.error("LLM request was rate limited: %s", exc)
            raise LLMRateLimitError(str(exc)) from exc
        except openai.APIError as exc:
            logger.error("LLM request failed: %s", exc)
            raise LLMError(str(exc)) from exc

        # An empty choices list would otherwise escape as a bare IndexError.
        choices = response.choices
        if not choices:
            raise LLMError("LLM response contains no choices")
        content = choices[0].message.content
        if content is None:
            raise LLMError("LLM response missing content")
        if not isinstance(content, str):
            msg = f"LLM response content must be a string, got {type(content).__name__}"
            raise LLMError(msg)

        usage = response.usage
        if usage is None:
            raise LLMError("LLM response missing usage data")

        # Deliberately best-effort: a cost failure must not fail the request.
        try:
            cost = completion_cost(completion_response=response)
        except Exception:
            logger.error(
                "No pricing data for model %s", self._model, exc_info=True
            )
            cost = float("nan")

        parsed_data: T_Response
        if issubclass(response_model, BaseModel):
            try:
                parsed_data = cast(
                    T_Response, response_model.model_validate_json(content)
                )
            except ValidationError as exc:
                raise LLMError(
                    f"LLM response validation failed for {response_model.__name__}: {exc}"
                ) from exc
        else:
            # Guaranteed to be a str subclass by the check before the request.
            parsed_data = cast(T_Response, content)

        return LLMResponse[T_Response](
            structured=parsed_data,
            model=response.model,
            prompt_tokens=usage.prompt_tokens,
            completion_tokens=usage.completion_tokens,
            cost=cost,
        )