diff --git a/litellm/__init__.py b/litellm/__init__.py index cffdbacf597..275b60dcf83 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -162,6 +162,7 @@ def _dev_env_hot_reload_enabled() -> bool: "levo", "compression_interception", "newrelic", + "asqav", ] cold_storage_custom_logger: Optional[_custom_logger_compatible_callbacks_literal] = None logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None diff --git a/litellm/integrations/asqav/__init__.py b/litellm/integrations/asqav/__init__.py new file mode 100644 index 00000000000..d16aaec6b7d --- /dev/null +++ b/litellm/integrations/asqav/__init__.py @@ -0,0 +1,3 @@ +from litellm.integrations.asqav.asqav import AsqavLogger + +__all__ = ["AsqavLogger"] diff --git a/litellm/integrations/asqav/asqav.py b/litellm/integrations/asqav/asqav.py new file mode 100644 index 00000000000..975976efd39 --- /dev/null +++ b/litellm/integrations/asqav/asqav.py @@ -0,0 +1,458 @@ +"""Asqav local-first audit-log callback for LiteLLM. + +Each LLM call appends one record to a local JSONL file. Every record carries +a SHA-256 chain hash over its own canonical fields plus the previous record's +hash, giving a tamper-evident sequence that can be verified entirely offline +with stdlib tools. + +Design goals (matching the on-device ask from litellm#25329): +- Zero runtime dependencies beyond Python stdlib + litellm itself. +- Never breaks an LLM call: every code path is wrapped fail-soft. +- Does not log message content by default; logs content digests so + auditors can prove a payload was present without reconstructing it. 
# Sentinel used as the genesis prev_hash (no predecessor).
_GENESIS_HASH = "0" * 64

# Default log path; can be overridden via ASQAV_LOG_PATH.
_DEFAULT_LOG_PATH = os.path.join(os.path.expanduser("~"), ".litellm_asqav_audit.jsonl")

# Metadata keys (compared lower-cased) whose values are credentials. They are
# stripped from every metadata dict before anything is logged.
_SENSITIVE_KEYS = frozenset({"user_api_key", "authorization", "token", "api_key"})

# The only keys copied out of litellm_params.metadata; everything else in that
# dict is ignored to avoid unexpected bleed into the audit log.
_PROXY_IDENTITY_KEYS = frozenset(
    {
        "user_api_key_user_id",
        "user_api_key_team_id",
        "user_api_key_org_id",
        "user_api_key_alias",
        "user_id",
        "team_id",
        "org_id",
    }
)


def _sha256_hex(data: bytes) -> str:
    """Return the lowercase hex SHA-256 digest of ``data``."""
    return hashlib.sha256(data).hexdigest()


def _canonical_bytes(record: dict[str, Any]) -> bytes:
    """Stable canonical serialisation for hashing.

    Keys are sorted and separators are fixed to ``(',', ':')`` so the byte
    sequence is deterministic for a given record.
    """
    return json.dumps(record, sort_keys=True, separators=(",", ":")).encode("utf-8")


def _read_tail(fh: BinaryIO, size: int) -> bytes:
    """Return the end of ``fh`` containing at least the whole last line.

    ``size`` is the total file size in bytes. The window starts at 4096 bytes
    and doubles until it either covers the whole file or contains a newline
    that is not the trailing one, so a last record of any length is read in
    full.
    """
    window = 4096
    while True:
        read_size = min(window, size)
        fh.seek(size - read_size)
        tail = fh.read(read_size)
        # A newline left after stripping the trailing one(s) marks the start
        # of the last line, so the window is wide enough.
        if read_size == size or b"\n" in tail.rstrip(b"\n"):
            return tail
        window *= 2


def _content_digest(value: Any) -> Optional[str]:
    """Return a SHA-256 hex digest of a content value, or None if it is None.

    JSON-native values are digested over their canonical JSON form. Values
    ``json`` cannot encode are stringified via ``default=str`` instead of
    raising, so an unusual payload never causes the audit record to be lost.
    """
    if value is None:
        return None
    try:
        raw = json.dumps(value, sort_keys=True, separators=(",", ":"), default=str)
    except (TypeError, ValueError):
        # Unsortable/invalid dict keys or circular references: fall back to
        # repr() so a digest is still recorded.
        raw = repr(value)
    return _sha256_hex(raw.encode("utf-8"))


def _collect_metadata(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Merge caller metadata with the proxy identity fields, minus credentials."""
    merged: dict[str, Any] = {}

    root: Any = kwargs.get("metadata") or kwargs.get("litellm_metadata") or {}
    try:
        root = dict(root)
    except (TypeError, ValueError):
        root = {}
    for key, value in root.items():
        # Non-string keys are dropped; credential keys are filtered here too,
        # not only in litellm_params.metadata.
        if isinstance(key, str) and key.lower() not in _SENSITIVE_KEYS:
            merged[key] = value

    litellm_params: Any = kwargs.get("litellm_params") or {}
    proxy_meta: Any = (
        litellm_params.get("metadata") if isinstance(litellm_params, dict) else None
    )
    if isinstance(proxy_meta, dict):
        for key in _PROXY_IDENTITY_KEYS:
            if key in proxy_meta:
                # Caller-supplied values win over proxy-supplied ones.
                merged.setdefault(key, proxy_meta[key])
    return merged


def _extract_loggable(
    kwargs: dict[str, Any],
    response_obj: Any,
    start_time: Any,
    end_time: Any,
    status: str,
) -> dict[str, Any]:
    """Pull metadata + digests out of a callback invocation.

    Message content and response text are never returned in the clear; only
    their SHA-256 digests are. The returned dict has been normalised through a
    JSON round trip, so it is always serialisable and hashes identically
    before and after being written to disk.

    Args:
        kwargs: The callback kwargs (model, messages, metadata, ...).
        response_obj: The response object; any shape is tolerated.
        start_time: Call start, expected to support ``end_time - start_time``.
        end_time: Call end.
        status: Status label stored verbatim in the record.

    Returns:
        The record fields that precede chaining (no seq/ts/hash fields).
    """
    # Timing: only recorded when the difference behaves like a timedelta.
    latency_ms: Optional[int] = None
    if start_time is not None and end_time is not None:
        try:
            latency_ms = int((end_time - start_time).total_seconds() * 1000)
        except (TypeError, AttributeError, ValueError, OverflowError):
            latency_ms = None

    # Usage: each field is read independently so one odd attribute does not
    # discard the others.
    usage = getattr(response_obj, "usage", None)
    prompt_tokens = getattr(usage, "prompt_tokens", None) if usage else None
    completion_tokens = getattr(usage, "completion_tokens", None) if usage else None
    total_tokens = getattr(usage, "total_tokens", None) if usage else None

    finish_reason: Optional[str] = None
    response_content_digest: Optional[str] = None
    try:
        choices = getattr(response_obj, "choices", None)
        first_choice = choices[0] if choices else None
    except (TypeError, IndexError, KeyError):
        first_choice = None
    if first_choice is not None:
        finish_reason = getattr(first_choice, "finish_reason", None)
        message = getattr(first_choice, "message", None)
        response_content_digest = _content_digest(getattr(message, "content", None))

    provider_request_id: Optional[str] = None
    hidden = getattr(response_obj, "_hidden_params", None)
    if isinstance(hidden, dict):
        provider_request_id = hidden.get("x-request-id") or hidden.get("cf-ray")

    # Call id: prefer the standard logging payload, then kwargs, and only
    # then generate one. A present-but-empty value never blocks the fallback.
    call_id: Any = None
    slp: Any = kwargs.get("standard_logging_object")
    if isinstance(slp, dict):
        call_id = slp.get("id") or slp.get("litellm_call_id")
    call_id = (
        call_id
        or kwargs.get("litellm_call_id")
        or kwargs.get("id")
        or str(int(time.time() * 1e6))
    )

    record: dict[str, Any] = {
        "call_id": call_id,
        "model": kwargs.get("model", ""),
        "status": status,
        "latency_ms": latency_ms,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "finish_reason": finish_reason,
        "provider_request_id": provider_request_id,
        "messages_digest": _content_digest(kwargs.get("messages")),
        "response_content_digest": response_content_digest,
        "metadata": _collect_metadata(kwargs),
    }

    # Normalise through JSON once: unsupported values become strings, invalid
    # nested keys are skipped, and what gets hashed equals what is re-read.
    try:
        return json.loads(json.dumps(record, default=str, skipkeys=True))
    except (TypeError, ValueError):
        # Only circular metadata can still fail here; keep the record.
        record["metadata"] = {}
        return json.loads(json.dumps(record, default=str, skipkeys=True))
class AsqavLogger(CustomLogger):
    """Tamper-evident local-first audit-log callback for LiteLLM.

    Every call appends one JSON line carrying ``seq``, ``ts``, ``prev_hash``
    and a ``record_hash`` computed over all other fields of the record.

    Configuration (constructor arguments take precedence over the
    environment):

    ASQAV_LOG_PATH
        Path to the JSONL audit log. Defaults to ~/.litellm_asqav_audit.jsonl.

    ASQAV_REDACT_CONTENT
        Set to "false" to store message/response content in the clear in
        addition to the SHA-256 digests. Defaults to "true" (digest only).

    Multi-worker limitation: the threading.Lock serializes concurrent threads
    within one process; it does NOT serialize across OS processes. Several
    processes appending to the same file will produce records with duplicate
    seq/prev_hash values and verify_chain will report a break.
    """

    def __init__(
        self,
        log_path: Optional[str] = None,
        redact_content: Optional[bool] = None,
    ) -> None:
        """Create the logger and resume the chain from an existing log.

        Args:
            log_path: Audit log path; falls back to ASQAV_LOG_PATH, then to
                the default path.
            redact_content: True stores digests only, False also stores
                content. None (the default) defers to ASQAV_REDACT_CONTENT.
                The two arguments are resolved independently of each other.
        """
        super().__init__()

        self._log_path: str = (
            log_path or os.environ.get("ASQAV_LOG_PATH") or _DEFAULT_LOG_PATH
        )
        if redact_content is None:
            env_value = os.environ.get("ASQAV_REDACT_CONTENT", "true")
            redact_content = env_value.strip().lower() != "false"
        self._redact_content: bool = bool(redact_content)

        self._lock: threading.Lock = threading.Lock()
        self._call_count: int = 0
        self._prev_hash: str = _GENESIS_HASH

        # Load chain state from an existing log file so we chain correctly
        # across process restarts.
        self._load_chain_tail()

    def __repr__(self) -> str:
        return (
            f"AsqavLogger(log_path={self._log_path!r},"
            f" redact_content={self._redact_content})"
        )

    # ------------------------------------------------------------------
    # Chain state persistence
    # ------------------------------------------------------------------

    def _load_chain_tail(self) -> None:
        """Read the last line of an existing log file to resume the chain.

        Chain state is only replaced when both ``record_hash`` and ``seq`` of
        the last record are usable, so a damaged tail can never leave the
        logger with a restored hash but a reset counter.
        """
        try:
            with open(self._log_path, "rb") as fh:
                size = fh.seek(0, os.SEEK_END)
                if size == 0:
                    return
                tail = _read_tail(fh, size)
            lines = [ln for ln in tail.split(b"\n") if ln.strip()]
            if not lines:
                return
            last_record = json.loads(lines[-1].decode("utf-8"))
            prev_hash = last_record["record_hash"]
            seq = last_record["seq"]
            if not isinstance(prev_hash, str) or type(seq) is not int:
                raise ValueError("last record has no usable record_hash/seq")
        except FileNotFoundError:
            return  # no log yet: start from genesis
        except Exception:
            # The next record will start a new chain; make that visible.
            verbose_logger.warning(
                f"[AsqavLogger] Could not load chain tail: {traceback.format_exc()}"
            )
            return
        self._prev_hash = prev_hash
        self._call_count = seq + 1

    # ------------------------------------------------------------------
    # Core record append
    # ------------------------------------------------------------------

    def _build_and_append(
        self,
        kwargs: dict[str, Any],
        response_obj: Any,
        start_time: Any,
        end_time: Any,
        status: str,
    ) -> None:
        """Build one audit record and append it to the JSONL log.

        seq/prev_hash assignment and the file write happen under _lock so the
        on-disk order always matches the chain order. Never raises.
        """
        try:
            loggable = _extract_loggable(
                kwargs, response_obj, start_time, end_time, status
            )

            if not self._redact_content:
                # Store content in the clear when the operator explicitly opts in.
                loggable["messages"] = kwargs.get("messages")
                try:
                    choices = getattr(response_obj, "choices", None)
                    if choices:
                        loggable["response_content"] = choices[0].message.content
                except (AttributeError, IndexError, KeyError, TypeError):
                    pass  # response has no readable content; digest fields remain

            # Normalise through JSON before hashing so the hash covers exactly
            # what verify_chain will parse back from disk, and so values json
            # cannot encode are stringified instead of costing the record.
            loggable = json.loads(json.dumps(loggable, default=str, skipkeys=True))

            # Chain state only advances after a successful write; a failed
            # write drops the record and the chain continues from the last
            # record actually on disk.
            with self._lock:
                hashable: dict[str, Any] = {
                    "seq": self._call_count,
                    "ts": datetime.now(tz=timezone.utc).isoformat(),
                    "prev_hash": self._prev_hash,
                    **loggable,
                }
                record_hash = _sha256_hex(_canonical_bytes(hashable))

                if not self._write_record({**hashable, "record_hash": record_hash}):
                    return

                self._prev_hash = record_hash
                self._call_count += 1

        except Exception:
            # Fail-soft, but a lost audit record should not be invisible.
            verbose_logger.warning(
                f"[AsqavLogger] Unhandled error in _build_and_append: {traceback.format_exc()}"
            )

    def _write_record(self, record: dict[str, Any]) -> bool:
        """Append one record to the log file. Returns False if the write failed."""
        try:
            parent = os.path.dirname(self._log_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            # Tighten permissions on a file left behind by an earlier run.
            # Attempt-and-handle instead of exists()+chmod avoids the race
            # where the file disappears between the two calls.
            try:
                os.chmod(self._log_path, 0o600)
            except FileNotFoundError:
                pass  # first write: os.open below creates it with mode 0o600
            fd = os.open(
                self._log_path,
                os.O_CREAT | os.O_WRONLY | os.O_APPEND,
                0o600,
            )
            try:
                fh = os.fdopen(fd, "a", encoding="utf-8")
            except Exception:
                # fdopen failed, so no file object is responsible for fd.
                try:
                    os.close(fd)
                except OSError:
                    pass
                raise
            # From here the file object owns fd and closes it exactly once,
            # also when write() raises. Closing fd again afterwards could hit
            # a descriptor another thread has since been handed.
            with fh:
                fh.write(json.dumps(record, separators=(",", ":")) + "\n")
            return True
        except Exception:
            verbose_logger.warning(
                f"[AsqavLogger] Failed to write audit record: {traceback.format_exc()}"
            )
            return False

    # ------------------------------------------------------------------
    # CustomLogger hooks
    # ------------------------------------------------------------------

    def log_success_event(
        self, kwargs: dict[str, Any], response_obj: Any, start_time: Any, end_time: Any
    ) -> None:
        """Append a record with status "success"."""
        self._build_and_append(kwargs, response_obj, start_time, end_time, "success")

    def log_failure_event(
        self, kwargs: dict[str, Any], response_obj: Any, start_time: Any, end_time: Any
    ) -> None:
        """Append a record with status "failure"."""
        self._build_and_append(kwargs, response_obj, start_time, end_time, "failure")

    async def async_log_success_event(
        self, kwargs: dict[str, Any], response_obj: Any, start_time: Any, end_time: Any
    ) -> None:
        """Append a "success" record from a worker thread so file I/O does not block the event loop."""
        await asyncio.to_thread(
            self._build_and_append,
            kwargs,
            response_obj,
            start_time,
            end_time,
            "success",
        )

    async def async_log_failure_event(
        self, kwargs: dict[str, Any], response_obj: Any, start_time: Any, end_time: Any
    ) -> None:
        """Append a "failure" record from a worker thread so file I/O does not block the event loop."""
        await asyncio.to_thread(
            self._build_and_append,
            kwargs,
            response_obj,
            start_time,
            end_time,
            "failure",
        )

    # ------------------------------------------------------------------
    # Chain verification (utility; not called on the hot path)
    # ------------------------------------------------------------------

    def verify_chain(self, log_path: Optional[str] = None) -> tuple[bool, str]:
        """Verify the integrity of the audit log at log_path.

        Returns (True, "ok") when every record's hash matches its content and
        its prev_hash matches the previous record's hash. Returns
        (False, reason) on the first violation found, including a line that
        is not a JSON object. Blank lines are ignored. Never raises.
        """
        path = log_path or self._log_path
        try:
            prev_hash = _GENESIS_HASH
            with open(path, encoding="utf-8") as fh:
                for lineno, raw_line in enumerate(fh, start=1):
                    line = raw_line.strip()
                    if not line:
                        continue
                    try:
                        record = json.loads(line)
                    except ValueError as exc:
                        return False, f"line {lineno}: invalid JSON ({exc})"
                    if not isinstance(record, dict):
                        return False, f"line {lineno}: record is not a JSON object"

                    stored_hash = str(record.get("record_hash", ""))
                    # Recompute hash over all fields except record_hash itself.
                    hashable = {k: v for k, v in record.items() if k != "record_hash"}
                    computed_hash = _sha256_hex(_canonical_bytes(hashable))

                    if computed_hash != stored_hash:
                        return (
                            False,
                            f"line {lineno}: hash mismatch"
                            f" (stored={stored_hash[:12]},"
                            f" computed={computed_hash[:12]})",
                        )

                    rec_prev = record.get("prev_hash", _GENESIS_HASH)
                    if rec_prev != prev_hash:
                        return (
                            False,
                            f"line {lineno}: prev_hash chain break"
                            f" (expected={prev_hash[:12]},"
                            f" got={str(rec_prev)[:12]})",
                        )

                    prev_hash = stored_hash

            return True, "ok"
        except FileNotFoundError:
            return False, f"log file not found: {path}"
        except Exception as exc:
            return False, f"verification error: {exc}"
ArgillaLogger from litellm.integrations.azure_sentinel.azure_sentinel import AzureSentinelLogger @@ -108,6 +109,7 @@ class CustomLoggerRegistry: "vantage": VantageLogger, "posthog": PostHogLogger, "newrelic": NewRelicLogger, + "asqav": AsqavLogger, } try: diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index d750a509054..00cc74ab65e 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -4521,6 +4521,16 @@ def _init_custom_logger_compatible_class( newrelic_logger = NewRelicLogger() _in_memory_loggers.append(newrelic_logger) return newrelic_logger # type: ignore + elif logging_integration == "asqav": + from litellm.integrations.asqav import AsqavLogger + + for callback in _in_memory_loggers: + if isinstance(callback, AsqavLogger): + return callback # type: ignore + + asqav_logger = AsqavLogger() + _in_memory_loggers.append(asqav_logger) + return asqav_logger # type: ignore return None except Exception as e: verbose_logger.exception( diff --git a/tests/test_litellm/integrations/asqav/__init__.py b/tests/test_litellm/integrations/asqav/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/test_litellm/integrations/asqav/conftest.py b/tests/test_litellm/integrations/asqav/conftest.py new file mode 100644 index 00000000000..d06e29e48ab --- /dev/null +++ b/tests/test_litellm/integrations/asqav/conftest.py @@ -0,0 +1,11 @@ +# Minimal conftest for the asqav standalone tests. +# These tests do not import litellm at module scope and do not need the +# parent conftest fixtures (which pull in the full litellm stack). +import os +import sys + +# Ensure the repo root is on sys.path so "litellm.integrations.asqav" resolves. 
# Ensure the repo root is on sys.path so "litellm.integrations.asqav" resolves.
# This conftest lives in tests/test_litellm/integrations/asqav/, i.e. four
# directories below the repo root. Joining ".." onto __file__ itself consumes
# one ".." just to drop the file name, which is how the previous six-level
# climb ended up one directory above the repo.
_REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
)
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)
in the full test suite) + + # litellm package stub + pkg = types.ModuleType("litellm") + sys.modules["litellm"] = pkg + + # litellm._logging stub + class _VL: + def debug(self, *a: object, **k: object) -> None: + pass + + def warning(self, *a: object, **k: object) -> None: + pass + + log_mod = types.ModuleType("litellm._logging") + log_mod.verbose_logger = _VL() # type: ignore[attr-defined] + sys.modules["litellm._logging"] = log_mod + + # litellm.integrations package + custom_logger stub + integrations_pkg = types.ModuleType("litellm.integrations") + sys.modules["litellm.integrations"] = integrations_pkg + pkg.integrations = integrations_pkg # type: ignore[attr-defined] + + class _CustomLogger: + def __init__(self, **kw: object) -> None: + pass + + cl_mod = types.ModuleType("litellm.integrations.custom_logger") + cl_mod.CustomLogger = _CustomLogger # type: ignore[attr-defined] + sys.modules["litellm.integrations.custom_logger"] = cl_mod + + # litellm.types stubs (imported in type annotations only) + types_pkg = types.ModuleType("litellm.types") + sys.modules["litellm.types"] = types_pkg + utils_mod = types.ModuleType("litellm.types.utils") + sys.modules["litellm.types.utils"] = utils_mod + + # litellm.llms stubs (kept minimal; asqav.py no longer imports from + # litellm.llms.custom_httpx.http_handler, but other litellm internals may + # still need the package hierarchy present during import resolution). 
+ llms_pkg = types.ModuleType("litellm.llms") + sys.modules["litellm.llms"] = llms_pkg + custom_httpx_pkg = types.ModuleType("litellm.llms.custom_httpx") + sys.modules["litellm.llms.custom_httpx"] = custom_httpx_pkg + http_handler_mod = types.ModuleType("litellm.llms.custom_httpx.http_handler") + sys.modules["litellm.llms.custom_httpx.http_handler"] = http_handler_mod + + # litellm.types.llms stub + types_llms_pkg = types.ModuleType("litellm.types.llms") + sys.modules["litellm.types.llms"] = types_llms_pkg + custom_http_mod = types.ModuleType("litellm.types.llms.custom_http") + sys.modules["litellm.types.llms.custom_http"] = custom_http_mod + + +_stub_litellm_deps() + +# Now load the integration module directly. +_asqav_path = os.path.join(_REPO_ROOT, "litellm", "integrations", "asqav", "asqav.py") +_spec = importlib.util.spec_from_file_location( + "litellm.integrations.asqav.asqav", _asqav_path +) +assert _spec and _spec.loader +_asqav_module = importlib.util.module_from_spec(_spec) +sys.modules["litellm.integrations.asqav.asqav"] = _asqav_module +_spec.loader.exec_module(_asqav_module) # type: ignore[union-attr] + +AsqavLogger = _asqav_module.AsqavLogger +_GENESIS_HASH = _asqav_module._GENESIS_HASH +_canonical_bytes = _asqav_module._canonical_bytes +_content_digest = _asqav_module._content_digest +_sha256_hex = _asqav_module._sha256_hex + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_kwargs(model: str = "gpt-4o", content: str = "hello") -> dict: + return { + "model": model, + "messages": [{"role": "user", "content": content}], + "litellm_call_id": f"test-{content[:8]}", + } + + +def _make_response(content: str = "world") -> MagicMock: + choice = MagicMock() + choice.message.content = content + choice.finish_reason = "stop" + resp = MagicMock() + resp.choices = [choice] + resp.usage.prompt_tokens = 10 + 
resp.usage.completion_tokens = 5 + resp.usage.total_tokens = 15 + resp._hidden_params = {} + return resp + + +def _make_times() -> tuple: + start = datetime(2026, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + end = datetime(2026, 1, 1, 0, 0, 1, tzinfo=timezone.utc) + return start, end + + +def _logger_at(path: str) -> AsqavLogger: + return AsqavLogger(log_path=path, redact_content=True) + + +def _append_n(logger: AsqavLogger, n: int) -> None: + start, end = _make_times() + for i in range(n): + logger.log_success_event( + kwargs=_make_kwargs(content=f"msg-{i}"), + response_obj=_make_response(content=f"resp-{i}"), + start_time=start, + end_time=end, + ) + + +def _read_records(path: str) -> list: + records = [] + with open(path, encoding="utf-8") as fh: + for line in fh: + line = line.strip() + if line: + records.append(json.loads(line)) + return records + + +# --------------------------------------------------------------------------- +# Unit tests: helpers +# --------------------------------------------------------------------------- + + +def test_sha256_hex_is_64_chars() -> None: + h = _sha256_hex(b"hello") + assert len(h) == 64 + assert h == "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824" + + +def test_canonical_bytes_is_deterministic() -> None: + d = {"b": 2, "a": 1, "c": [3, 4]} + b1 = _canonical_bytes(d) + b2 = _canonical_bytes({"c": [3, 4], "a": 1, "b": 2}) + assert b1 == b2 + + +def test_content_digest_returns_none_for_none() -> None: + assert _content_digest(None) is None + + +def test_content_digest_is_stable() -> None: + d1 = _content_digest("hello world") + d2 = _content_digest("hello world") + assert d1 == d2 + assert d1 is not None and len(d1) == 64 + + +# --------------------------------------------------------------------------- +# Chain property tests +# --------------------------------------------------------------------------- + + +def test_single_record_genesis_chain(tmp_path) -> None: + path = str(tmp_path / "audit.jsonl") + logger = 
_logger_at(path) + start, end = _make_times() + + logger.log_success_event( + kwargs=_make_kwargs(), + response_obj=_make_response(), + start_time=start, + end_time=end, + ) + + records = _read_records(path) + assert len(records) == 1 + r = records[0] + assert r["seq"] == 0 + assert r["prev_hash"] == _GENESIS_HASH + assert r["status"] == "success" + assert "record_hash" in r + assert len(r["record_hash"]) == 64 + + +def test_chain_links_correctly_for_n_records(tmp_path) -> None: + path = str(tmp_path / "audit.jsonl") + logger = _logger_at(path) + n = 10 + _append_n(logger, n) + + records = _read_records(path) + assert len(records) == n + + # First record's prev_hash is the genesis sentinel. + assert records[0]["prev_hash"] == _GENESIS_HASH + + # Each subsequent record's prev_hash equals the hash of the prior record. + for i in range(1, n): + assert ( + records[i]["prev_hash"] == records[i - 1]["record_hash"] + ), f"Chain broken between records {i-1} and {i}" + + +def test_verify_chain_passes_on_valid_log(tmp_path) -> None: + path = str(tmp_path / "audit.jsonl") + logger = _logger_at(path) + _append_n(logger, 5) + + ok, msg = logger.verify_chain(path) + assert ok is True, f"Expected valid chain but got: {msg}" + assert msg == "ok" + + +def test_verify_chain_detects_record_hash_tampering(tmp_path) -> None: + path = str(tmp_path / "audit.jsonl") + logger = _logger_at(path) + _append_n(logger, 5) + + records = _read_records(path) + # Corrupt the model field of record 2 (middle of chain). 
+ records[2]["model"] = "tampered-model" + + with open(path, "w", encoding="utf-8") as fh: + for r in records: + fh.write(json.dumps(r, separators=(",", ":")) + "\n") + + ok, msg = logger.verify_chain(path) + assert ok is False + assert "hash mismatch" in msg + + +def test_verify_chain_detects_prev_hash_tampering(tmp_path) -> None: + path = str(tmp_path / "audit.jsonl") + logger = _logger_at(path) + _append_n(logger, 5) + + records = _read_records(path) + # Recompute record_hash after tampering prev_hash to bypass the first check. + records[3]["prev_hash"] = "a" * 64 + hashable = {k: v for k, v in records[3].items() if k != "record_hash"} + records[3]["record_hash"] = _sha256_hex(_canonical_bytes(hashable)) + + with open(path, "w", encoding="utf-8") as fh: + for r in records: + fh.write(json.dumps(r, separators=(",", ":")) + "\n") + + ok, msg = logger.verify_chain(path) + assert ok is False + assert "chain break" in msg + + +def test_verify_chain_detects_deleted_record(tmp_path) -> None: + path = str(tmp_path / "audit.jsonl") + logger = _logger_at(path) + _append_n(logger, 5) + + records = _read_records(path) + # Remove record 2 - record 3's prev_hash will no longer match record 1's hash. + del records[2] + + with open(path, "w", encoding="utf-8") as fh: + for r in records: + fh.write(json.dumps(r, separators=(",", ":")) + "\n") + + ok, msg = logger.verify_chain(path) + assert ok is False + + +def test_chain_resumes_after_process_restart(tmp_path) -> None: + path = str(tmp_path / "audit.jsonl") + + # First "process": write 3 records. + logger1 = _logger_at(path) + _append_n(logger1, 3) + hash_after_first = logger1._prev_hash + + # Second "process": new logger instance reads the tail and continues. + logger2 = _logger_at(path) + assert ( + logger2._prev_hash == hash_after_first + ), "Second logger did not resume chain from tail of existing log" + _append_n(logger2, 3) + + # Full 6-record chain should verify clean. 
+ ok, msg = logger2.verify_chain(path) + assert ok is True, f"Chain broken across restart: {msg}" + + +def test_seq_counter_restored_after_restart(tmp_path) -> None: + """P1 regression: _call_count (and thus seq) must resume from the last + persisted record's seq field after a process restart. + + Before the fix, _load_chain_tail restored _prev_hash but left _call_count + at 0, so the second "process" would emit seq=0 again instead of continuing + from where the first process stopped. + """ + path = str(tmp_path / "audit.jsonl") + + # First "process": write 5 records (seq 0..4). + logger1 = _logger_at(path) + _append_n(logger1, 5) + records_after_first = _read_records(path) + assert ( + records_after_first[-1]["seq"] == 4 + ), "sanity: last seq from first process is 4" + + # Second "process": new instance reads the tail. + logger2 = _logger_at(path) + assert ( + logger2._call_count == 5 + ), f"_call_count not restored: expected 5, got {logger2._call_count}" + + # Writing one more record must produce seq=5, not seq=0. + _append_n(logger2, 1) + records = _read_records(path) + assert len(records) == 6 + assert ( + records[5]["seq"] == 5 + ), f"seq reset after restart: expected 5, got {records[5]['seq']}" + + # The full chain must also pass integrity verification. 
+    ok, msg = logger2.verify_chain(path)
+    assert ok is True, f"Chain broken after restart: {msg}"
+
+
+def test_failure_event_is_logged(tmp_path) -> None:
+    """log_failure_event appends exactly one record with status "failure"."""
+    path = str(tmp_path / "audit.jsonl")
+    logger = _logger_at(path)
+    start, end = _make_times()
+
+    logger.log_failure_event(
+        kwargs=_make_kwargs(),
+        response_obj=_make_response(),
+        start_time=start,
+        end_time=end,
+    )
+
+    records = _read_records(path)
+    assert len(records) == 1
+    assert records[0]["status"] == "failure"
+
+
+def test_logger_does_not_raise_on_malformed_response(tmp_path) -> None:
+    """A bare object() response is still logged as a single "success" record."""
+    path = str(tmp_path / "audit.jsonl")
+    logger = _logger_at(path)
+    start, end = _make_times()
+
+    bad_response = object()  # not a ModelResponse at all
+
+    # Should not raise; logger is fail-soft.
+    logger.log_success_event(
+        kwargs=_make_kwargs(),
+        response_obj=bad_response,
+        start_time=start,
+        end_time=end,
+    )
+
+    records = _read_records(path)
+    assert len(records) == 1
+    assert records[0]["status"] == "success"
+
+
+def test_content_digest_stored_not_plaintext_by_default(tmp_path) -> None:
+    """The record omits the prompt and response text but has both digest keys."""
+    path = str(tmp_path / "audit.jsonl")
+    logger = _logger_at(path)
+    start, end = _make_times()
+
+    logger.log_success_event(
+        kwargs=_make_kwargs(content="my secret prompt"),
+        response_obj=_make_response(content="my secret response"),
+        start_time=start,
+        end_time=end,
+    )
+
+    records = _read_records(path)
+    # Search the re-serialised record so a leak under any key is caught.
+    raw = json.dumps(records[0])
+    assert "my secret prompt" not in raw
+    assert "my secret response" not in raw
+    assert "messages_digest" in records[0]
+    assert "response_content_digest" in records[0]
+
+
+def test_seq_increments_across_calls(tmp_path) -> None:
+    """Records appended via _append_n(logger, 4) carry seq 0, 1, 2, 3."""
+    path = str(tmp_path / "audit.jsonl")
+    logger = _logger_at(path)
+    _append_n(logger, 4)
+
+    records = _read_records(path)
+    seqs = [r["seq"] for r in records]
+    assert seqs == [0, 1, 2, 3]
+
+
+def test_latency_ms_is_computed(tmp_path) -> None:
+    """latency_ms equals 1000 for the one-second gap from _make_times()."""
+    path = str(tmp_path / "audit.jsonl")
+    logger = _logger_at(path)
+    start, end = _make_times()  # 1-second gap
+
+    logger.log_success_event(
+        kwargs=_make_kwargs(),
+        response_obj=_make_response(),
+        start_time=start,
+        end_time=end,
+    )
+
+    records = _read_records(path)
+    assert records[0]["latency_ms"] == 1000
+
+
+def test_no_background_threads_spawned(tmp_path) -> None:
+    """Local-only logger must not spawn background threads."""
+    path = str(tmp_path / "audit.jsonl")
+    logger = AsqavLogger(log_path=path, redact_content=True)
+    start, end = _make_times()
+
+    # Sample the live thread count immediately before and after one log call.
+    before = threading.active_count()
+    logger.log_success_event(
+        kwargs=_make_kwargs(),
+        response_obj=_make_response(),
+        start_time=start,
+        end_time=end,
+    )
+    after = threading.active_count()
+    assert after <= before
+
+
+# ---------------------------------------------------------------------------
+# Concurrency, restart with large records, repr
+# ---------------------------------------------------------------------------
+
+
+def test_concurrent_callbacks_keep_chain_ordered(tmp_path) -> None:
+    """Records from concurrent threads land on disk in seq order.
+
+    Regression test for the out-of-order-write race: seq/prev_hash assignment
+    and the file write must happen under the same lock, otherwise two threads
+    can write their records in reversed order and break verify_chain.
+    """
+    path = str(tmp_path / "audit.jsonl")
+    logger = _logger_at(path)
+    start, end = _make_times()
+    n_threads = 8
+    per_thread = 25
+    # The barrier holds every worker until all n_threads have reached it, so
+    # the logging loops start together and their writes can interleave.
+    barrier = threading.Barrier(n_threads)
+
+    def _worker(tid: int) -> None:
+        barrier.wait()
+        for i in range(per_thread):
+            logger.log_success_event(
+                kwargs=_make_kwargs(content=f"t{tid}-m{i}"),
+                response_obj=_make_response(content=f"t{tid}-r{i}"),
+                start_time=start,
+                end_time=end,
+            )
+
+    threads = [
+        threading.Thread(target=_worker, args=(tid,)) for tid in range(n_threads)
+    ]
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+
+    records = _read_records(path)
+    # 8 threads x 25 calls: 200 records, with seq matching on-disk position.
+    assert len(records) == n_threads * per_thread
+    assert [r["seq"] for r in records] == list(range(n_threads * per_thread))
+    ok, msg = logger.verify_chain(path)
+    assert ok is True, f"Chain broken under concurrent writes: {msg}"
+
+
+def test_restart_resumes_chain_when_last_record_exceeds_4kb(tmp_path) -> None:
+    """A record larger than the old 4 KB tail buffer survives a restart.
+
+    Regression test for the silent chain reset: the tail read must widen until
+    it contains the whole last line instead of truncating it.
+    """
+    path = str(tmp_path / "audit.jsonl")
+    logger1 = _logger_at(path)
+    start, end = _make_times()
+
+    big_kwargs = _make_kwargs(content="big")
+    # 10,000-character metadata value meant to push the record past 4096 bytes;
+    # the file-size assertion below confirms that it did.
+    big_kwargs["metadata"] = {"blob": "x" * 10_000}
+    logger1.log_success_event(
+        kwargs=big_kwargs,
+        response_obj=_make_response(),
+        start_time=start,
+        end_time=end,
+    )
+    assert os.path.getsize(path) > 4096
+    hash_after_first = logger1._prev_hash
+
+    # A second logger on the same path stands in for a process restart.
+    logger2 = _logger_at(path)
+    assert (
+        logger2._prev_hash == hash_after_first
+    ), "Restart did not resume the chain from a record larger than 4 KB"
+    _append_n(logger2, 2)
+
+    ok, msg = logger2.verify_chain(path)
+    assert ok is True, f"Chain broken across restart with large record: {msg}"
+
+
+def test_repr_shows_log_path_and_redact(tmp_path) -> None:
+    """repr() of the logger contains the class name and the log path."""
+    path = str(tmp_path / "audit.jsonl")
+    logger = AsqavLogger(log_path=path, redact_content=True)
+
+    r = repr(logger)
+    assert "AsqavLogger" in r
+    assert path in r
+    # NOTE(review): the test name mentions redact, but nothing here asserts
+    # that redact_content shows up in the repr - confirm whether it should.
+
+
+def test_async_hooks_write_records(tmp_path) -> None:
+    """Awaiting the async success and failure hooks writes two chained records."""
+    path = str(tmp_path / "audit.jsonl")
+    logger = _logger_at(path)
+    start, end = _make_times()
+
+    async def _run() -> None:
+        await logger.async_log_success_event(
+            kwargs=_make_kwargs(content="async-ok"),
+            response_obj=_make_response(),
+            start_time=start,
+            end_time=end,
+        )
+        await logger.async_log_failure_event(
+            kwargs=_make_kwargs(content="async-fail"),
+            response_obj=_make_response(),
+            start_time=start,
+            end_time=end,
+        )
+
+    asyncio.run(_run())
+
+    records = _read_records(path)
+    # Statuses must come back in the order the hooks were awaited.
+    assert [r["status"] for r in records] == ["success", "failure"]
+    ok, msg = logger.verify_chain(path)
+    assert ok is True, msg
+
+
+def test_redact_content_false_stores_plaintext(tmp_path) -> None:
+    """With redact_content=False the record keeps message and response text."""
+    path = str(tmp_path / "audit.jsonl")
+    logger = AsqavLogger(log_path=path, redact_content=False)
+    start, end = _make_times()
+
+    logger.log_success_event(
+        kwargs=_make_kwargs(content="visible prompt"),
+        response_obj=_make_response(content="visible response"),
+        start_time=start,
+        end_time=end,
+    )
+
+    records = _read_records(path)
+    assert records[0]["messages"][0]["content"] == "visible prompt"
+    assert records[0]["response_content"] == "visible response"
+
+
+# ---------------------------------------------------------------------------
+# New regression tests (must FAIL before the corresponding fix is applied)
+# ---------------------------------------------------------------------------
+
+
+def test_audit_log_file_created_with_0600_perms(tmp_path) -> None:
+    """Veria Medium: audit log must be created with mode 0600.
+
+    With a standard 022 umask, plain open(..., 'a') produces 0644, which lets
+    other local users read the log. The fix creates via os.open with 0o600 and
+    chmods an existing file to 0600 before appending.
+    """
+    path = str(tmp_path / "audit.jsonl")
+    logger = _logger_at(path)
+    start, end = _make_times()
+
+    logger.log_success_event(
+        kwargs=_make_kwargs(),
+        response_obj=_make_response(),
+        start_time=start,
+        end_time=end,
+    )
+
+    assert os.path.exists(path), "audit log was not created"
+    # The last three octal digits of st_mode are the owner/group/other bits.
+    mode_octal = oct(os.stat(path).st_mode)[-3:]
+    assert mode_octal == "600", f"audit log has mode {mode_octal}, expected 600"
+
+
+def test_proxy_identity_metadata_attributed_in_record(tmp_path) -> None:
+    """Veria Medium: proxy identity fields from litellm_params.metadata must
+    appear in the logged record's metadata.
+
+    user_api_key_user_id, team_id, org_id, and key_alias live under
+    kwargs['litellm_params']['metadata'], not at kwargs root. Records written
+    without reading that sub-dict have no proxy attribution.
+    """
+    path = str(tmp_path / "audit.jsonl")
+    logger = _logger_at(path)
+    start, end = _make_times()
+
+    kwargs = _make_kwargs()
+    kwargs["litellm_params"] = {
+        "metadata": {
+            "user_api_key_user_id": "user-abc",
+            "user_api_key_team_id": "team-xyz",
+            "user_api_key_org_id": "org-123",
+            "user_api_key_alias": "my-key",
+            # sensitive values that must NOT be persisted
+            "user_api_key": "sk-secret-12345",
+            "Authorization": "Bearer sk-secret-12345",
+        }
+    }
+
+    logger.log_success_event(
+        kwargs=kwargs,
+        response_obj=_make_response(),
+        start_time=start,
+        end_time=end,
+    )
+
+    records = _read_records(path)
+    meta = records[0]["metadata"]
+    assert (
+        meta.get("user_api_key_user_id") == "user-abc"
+    ), "user_api_key_user_id not attributed in record"
+    assert (
+        meta.get("user_api_key_team_id") == "team-xyz"
+    ), "user_api_key_team_id not attributed in record"
+    # NOTE(review): user_api_key_org_id and user_api_key_alias are supplied
+    # above but never asserted on - confirm whether they should be checked.
+    # Sensitive fields must be filtered out
+    assert "user_api_key" not in meta, "raw api key leaked into record metadata"
+    assert "Authorization" not in meta, "auth header leaked into record metadata"
+
+
+def test_multiworker_flock_guard_documented_or_implemented(tmp_path) -> None:
+    """Veria Medium: the multi-worker limitation must be documented in the
+    class docstring (or, if fcntl is used, the lock must serialize cross-process
+    writes). This test checks for the docstring acknowledgement.
+    """
+    # NOTE(review): tmp_path is accepted but never used in this test.
+    import inspect
+
+    doc = inspect.getdoc(AsqavLogger) or ""
+    # Any one of the three keywords (case-insensitive) satisfies the check.
+    assert (
+        "single" in doc.lower() or "flock" in doc.lower() or "worker" in doc.lower()
+    ), "AsqavLogger docstring must document the single-writer / multi-worker limitation"
diff --git a/tests/test_litellm/litellm_core_utils/test_litellm_logging.py b/tests/test_litellm/litellm_core_utils/test_litellm_logging.py
index f0db0409bd7..59bf976f83f 100644
--- a/tests/test_litellm/litellm_core_utils/test_litellm_logging.py
+++ b/tests/test_litellm/litellm_core_utils/test_litellm_logging.py
@@ -3449,3 +3449,32 @@ def test_failure_handler_zeroes_spend_without_recovered_usage(logging_obj):
     assert payload["status"] == "failure"
     assert payload["response_cost"] == 0
     assert payload["total_tokens"] == 0
+
+
+def test_init_custom_logger_compatible_class_asqav_singleton(monkeypatch, tmp_path):
+    """callbacks=["asqav"] constructs one AsqavLogger and reuses it on re-init."""
+    monkeypatch.setenv("ASQAV_LOG_PATH", str(tmp_path / "audit.jsonl"))
+
+    from litellm.integrations.asqav import AsqavLogger
+    from litellm.litellm_core_utils import litellm_logging as logging_module
+
+    logging_module._in_memory_loggers.clear()  # start with no cached loggers
+    try:
+        first = logging_module._init_custom_logger_compatible_class(
+            logging_integration="asqav",
+            internal_usage_cache=None,
+            llm_router=None,
+        )
+        second = logging_module._init_custom_logger_compatible_class(
+            logging_integration="asqav",
+            internal_usage_cache=None,
+            llm_router=None,
+        )
+
+        assert type(first) is AsqavLogger
+        assert second is first  # second init must hand back the same object
+        assert any(
+            isinstance(cb, AsqavLogger) for cb in logging_module._in_memory_loggers
+        )
+    finally:
+        logging_module._in_memory_loggers.clear()  # drop the logger created here