diff --git a/CHANGELOG.md b/CHANGELOG.md
index 142a8738a9..8714c3641b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -58,6 +58,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added geometry functionals in `physicsnemo.nn.functional` for
   `mesh_poisson_disk_sample`, `mesh_to_voxel_fraction`, and
   `signed_distance_field`.
+- Added embedded OOD guardrail `OODGuard` at
+  `physicsnemo.experimental.guardrails.embedded`, optionally
+  wired into `GeoTransolver` via a new `guard_config` constructor argument.
+  The guard calibrates per-channel global bounds and a geometry-latent
+  kNN threshold during training, and emits warnings on out-of-distribution
+  inputs at inference.
 - In PhysicsNeMo-Mesh, `physicsnemo.mesh.geometry` now publicly exposes
   `stable_angle_between_vectors` and `compute_triangle_angles` (previously
   only available via the private `physicsnemo.mesh.curvature._utils`).
diff --git a/examples/minimal/guardrails/geometry_validation.py b/examples/minimal/guardrails/geometry_validation.py
index 0f80f0c83c..8679e67406 100644
--- a/examples/minimal/guardrails/geometry_validation.py
+++ b/examples/minimal/guardrails/geometry_validation.py
@@ -30,7 +30,7 @@ import multiprocessing as mp
 from pathlib import Path

-from physicsnemo.experimental.guardrails import GeometryGuardrail
+from physicsnemo.experimental.guardrails.geometry import GeometryGuardrail


 def prepare_datasets(
diff --git a/examples/structural_mechanics/crash/README.md b/examples/structural_mechanics/crash/README.md
index 514c1c792e..b335bed150 100644
--- a/examples/structural_mechanics/crash/README.md
+++ b/examples/structural_mechanics/crash/README.md
@@ -211,6 +211,75 @@ torchrun --nproc_per_node= inference.py --config-name=bumper_geotranso
 Runs are sharded across ranks: rank `r` processes `run_items[r::world_size]`.
 Predicted meshes are written as .vtp files under `./predicted_vtps/`, and can
 be opened using ParaView.
+## Guardrails (OOD detection)
+
+GeoTransolver ships with an optional embedded out-of-distribution (OOD) guardrail
+that calibrates during training and emits warnings at inference when inputs
+drift outside the training distribution. It watches two surfaces:
+
+- **Global parameters** — per-channel bounding box on the global embedding
+  (e.g. `velocity_x`, `thickness_scale`).
+- **Geometry** — k-nearest-neighbour distance on a pooled geometry latent.
+
+Enable it through the model config by setting `guard_config` to a mapping
+(leave `null` to disable):
+
+```yaml
+# conf/my_experiment.yaml
+model:
+  guard_config:
+    buffer_size: 121   # FIFO buffer; typically = num_training_samples
+    knn_k: 10          # k for geometry kNN distance
+    sensitivity: 1.5   # threshold multiplier on 99th-percentile kNN dist
+```
+
+What each parameter controls:
+
+- **`buffer_size`** — Capacity of the FIFO ring buffer that stores pooled
+  geometry latents during training; the kNN threshold is computed over its
+  contents. Set it to at least the training-set size so calibration sees every
+  sample. Under DDP each rank keeps its own buffer and the distributed sampler
+  shuffles, so after a few epochs each rank's FIFO covers most of the data.
+  Memory cost is `buffer_size × head_dim × 4` bytes — typically well under 1 MB.
+- **`knn_k`** — Number of nearest neighbours used in the kNN distance. Smaller
+  `k` is more sensitive to isolated training-set outliers; larger `k` is
+  smoother but can blur multi-modal cluster boundaries. The default of `10`
+  works for buffer sizes from ~100 up to several thousand.
+- **`sensitivity`** — Multiplier on the 99th-percentile training kNN distance + used as the OOD threshold. **Higher = less sensitive** (fewer warnings). + Raise it if known in-distribution validation data triggers warnings; lower + it if known-OOD inputs are being missed. + +Recommended starting points by training-set size: + +| Training samples | `buffer_size` | `knn_k` | `sensitivity` | +|------------------|-------------------------|---------|---------------| +| ~100 | 100–200 | 5–10 | 1.5 | +| ~500 | 500–1000 | 10 | 1.5 | +| 5000+ | = dataset size | 10–15 | 1.5 | + +If validation data trips warnings, raise `sensitivity` toward 2.0–3.0; if +known-OOD inputs slip through, lower it toward 1.0. + +No changes to `train.py` / `inference.py` are required: during training the +guard silently collects calibration statistics, and during inference it emits +warnings of the form `OOD Guard: geometry sample ...` or +`OOD Guard: global_embedding dim ...` to the Python logger whenever a sample +falls outside the calibrated training envelope. Warnings do not halt inference. + +Two inference drivers are provided to synthesise OOD samples for testing the +guard end-to-end: + +```bash +# Scale every global-feature scalar by 1.5x (default): +python inference_ood_global.py --config-name=bumper_geotransolver_oneshot +# Scale the geometry uniformly by 1.10x in raw space (default): +python inference_ood_geometry.py --config-name=bumper_geotransolver_oneshot +``` + +Override the perturbation factor from the CLI, e.g. +`inference.ood_global_scale=2.0` or `inference.ood_geometry_scale=1.05`. + ## Experiments Each experiment is a self-contained YAML file in `conf/`. Each config file includes all defaults and experiment-specific settings. diff --git a/examples/structural_mechanics/crash/conf/bumper_geotransolver_oneshot.yaml b/examples/structural_mechanics/crash/conf/bumper_geotransolver_oneshot.yaml index d0017f7dbb..2d0d7ad9d1 100644 --- a/examples/structural_mechanics/crash/conf/bumper_geotransolver_oneshot.yaml +++ b/examples/structural_mechanics/crash/conf/bumper_geotransolver_oneshot.yaml @@ -38,7 +38,7 @@ defaults: # └───────────────────────────────────────────┘ training: - raw_data_dir: ??? # set in config or via CLI: training.raw_data_dir=/path/to/train + raw_data_dir: ??? # set in config or via CLI: training.raw_data_dir=/path/to/train raw_data_dir_validation: ??? # set in config or via CLI: training.raw_data_dir_validation=/path/to/validation global_features_filepath: ??? # set in config or via CLI: training.global_features_filepath=/path/to/global_features.json optimizer: muon @@ -75,4 +75,12 @@ datapipe: model: functional_dim: 3 # coords (3) out_dim: 250 # (num_time_steps - 1) * 5 = 50 * 5 - global_dim: 3 # must match len(datapipe.global_features) \ No newline at end of file + global_dim: 3 # must match len(datapipe.global_features) + # OOD guard (disabled by default — set `guard_config` to a mapping to enable) + guard_config: null + # Example enabled config: + # guard_config: + # buffer_size: 121 # FIFO buffer size (= num_training_samples) + # knn_k: 10 # k for geometry kNN distance (recommended 5–15) + # sensitivity: 1.5 # threshold = sensitivity * 99th-percentile in-dist + # # kNN distance; >1 is less sensitive, <1 more. 
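The `guard_config` mapping shown in the README and YAML above can also be exercised without Hydra. Below is a minimal sketch of constructing a guard-enabled `GeoTransolver` directly in Python, mirroring the test helper added in this PR; the import path is an assumption (the class is defined in `physicsnemo/experimental/models/geotransolver/geotransolver.py`, so adjust to however your install re-exports it). Note that when going through Hydra, the constructor requires the nested mapping to arrive as a native `dict` (e.g. `_convert_=partial`).

```python
# Sketch (assumed import path): guard-enabled GeoTransolver built in Python.
from physicsnemo.experimental.models.geotransolver.geotransolver import GeoTransolver

model = GeoTransolver(
    functional_dim=32,
    out_dim=4,
    geometry_dim=3,   # enables the geometry-latent kNN surface
    global_dim=16,    # enables the per-channel global-bounds surface
    n_layers=2,
    n_hidden=64,
    n_head=4,
    use_te=False,
    # Plain dict (not OODGuardConfig) so model kwargs stay JSON-serialisable.
    guard_config={"buffer_size": 8, "knn_k": 3, "sensitivity": 1.5},
)
assert model.ood_guard is not None  # guard is attached and wired into forward()
```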
diff --git a/physicsnemo/experimental/guardrails/__init__.py b/physicsnemo/experimental/guardrails/__init__.py index e3d8a1a402..72cf7b3bb7 100644 --- a/physicsnemo/experimental/guardrails/__init__.py +++ b/physicsnemo/experimental/guardrails/__init__.py @@ -19,8 +19,13 @@ This package provides utilities for detecting out-of-distribution data and validating inputs to physics-based machine learning models. -""" -from .geometry import GeometryGuardrail +Import individual guardrails from their respective subpackages, e.g.:: + + from physicsnemo.experimental.guardrails.geometry import GeometryGuardrail + from physicsnemo.experimental.guardrails.embedded import OODGuard -__all__ = ["GeometryGuardrail"] +The top-level namespace intentionally does not re-export guardrail classes so +that importing one subpackage does not force-load the others (some carry +optional dependencies like ``pyvista``). +""" diff --git a/physicsnemo/experimental/guardrails/embedded/__init__.py b/physicsnemo/experimental/guardrails/embedded/__init__.py new file mode 100644 index 0000000000..44f1bd536d --- /dev/null +++ b/physicsnemo/experimental/guardrails/embedded/__init__.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Embedded (in-forward) guardrails for PhysicsNemo models. + +This submodule provides guardrails that live inside a model's forward pass, +calibrating during training and checking during inference. Contrast with +``physicsnemo.experimental.guardrails.geometry``, which operates on raw mesh +data offline prior to inference. +""" + +from .ood_guard import OODGuard, OODGuardConfig + +__all__ = ["OODGuard", "OODGuardConfig"] diff --git a/physicsnemo/experimental/guardrails/embedded/ood_guard.py b/physicsnemo/experimental/guardrails/embedded/ood_guard.py new file mode 100644 index 0000000000..0c70fa7e65 --- /dev/null +++ b/physicsnemo/experimental/guardrails/embedded/ood_guard.py @@ -0,0 +1,383 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OOD (Out-of-Distribution) Guard for runtime anomaly detection. + +Provides two complementary checks: + +1. 
**Global parameter bounds** — per-channel bounding box on an arbitrary-rank + global embedding tensor with channel as its last dimension. +2. **Geometry latent kNN** — k-nearest-neighbour distance in a user-provided + fixed-dimensional latent space. + +During training, the guard collects calibration statistics. During inference, +it compares incoming data against those statistics and emits warnings when +inputs fall outside the training distribution. + +The guard is intentionally model-agnostic: callers are responsible for pooling +any higher-rank latent tensor down to ``(B, D)`` before passing it in. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from physicsnemo.nn.functional import knn + +logger = logging.getLogger(__name__) + + +_RED = "\033[91m" +_RESET = "\033[0m" + + +def _reduce_leading( + x: torch.Tensor, + reducer: Callable[..., torch.Tensor], +) -> torch.Tensor: + """Apply ``reducer`` over all dims except the last (channel) dim.""" + if x.ndim <= 1: + return x + reduce_dims = tuple(range(x.ndim - 1)) + return reducer(x, dim=reduce_dims) + + +@dataclass +class OODGuardConfig: + """User-facing configuration for :class:`OODGuard`. + + Model-derived fields (``global_dim``, ``geometry_embed_dim``) are supplied + by the enclosing model and intentionally omitted here. + + Attributes + ---------- + buffer_size : int + Capacity of the geometry latent FIFO buffer. Typically set to the + training-set size. No default — callers must pick a value. + knn_k : int + Number of nearest neighbours for the geometry kNN distance check. + Default is ``10``. + sensitivity : float + Multiplier on the 99th-percentile kNN distance used as the OOD + threshold. Higher values are less sensitive. Default is ``1.5``. + """ + + buffer_size: int + knn_k: int = 10 + sensitivity: float = 1.5 + + +class OODGuard(nn.Module): + """Out-of-distribution guard using global-parameter bounds and geometry kNN. + + Parameters + ---------- + buffer_size : int + Capacity of the geometry latent FIFO buffer (typically = training set size). + global_dim : int | None + Channel dimension of global embeddings. ``None`` disables the global check. + geometry_embed_dim : int | None + Dimensionality of the pooled geometry latent vector. ``None`` disables + the geometry kNN check. + knn_k : int + Number of nearest neighbours for the geometry distance check. + sensitivity : float + Multiplier on the 99th-percentile kNN distance used as the OOD + threshold. Higher values are less sensitive. Default is ``1.5``. + """ + + def __init__( + self, + buffer_size: int, + global_dim: int | None = None, + geometry_embed_dim: int | None = None, + knn_k: int = 10, + sensitivity: float = 1.5, + ) -> None: + super().__init__() + self.buffer_size = buffer_size + self.sensitivity = sensitivity + + # Global parameter bounds + if global_dim is not None: + self.register_buffer( + "global_min", torch.full((global_dim,), float("inf")) + ) + self.register_buffer( + "global_max", torch.full((global_dim,), float("-inf")) + ) + else: + self.register_buffer("global_min", None) + self.register_buffer("global_max", None) + + # Geometry kNN buffer + if geometry_embed_dim is not None: + self.register_buffer( + "geo_embeddings", torch.zeros(buffer_size, geometry_embed_dim) + ) + # Write index into the FIFO, kept in [0, buffer_size). 
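+            # (A registered buffer, like the FIFO itself, so the calibration
+            # state follows ``.to(device)`` and round-trips through
+            # ``state_dict`` checkpoints.)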
+ self.register_buffer("geo_ptr", torch.zeros(1, dtype=torch.long)) + # Latches True once the FIFO has been filled at least once. + self.register_buffer("geo_full", torch.zeros(1, dtype=torch.bool)) + self.register_buffer("knn_threshold", torch.tensor(float("inf"))) + else: + self.register_buffer("geo_embeddings", None) + self.register_buffer("geo_ptr", None) + self.register_buffer("geo_full", None) + self.register_buffer("knn_threshold", None) + + self.knn_k = knn_k + self._threshold_stale = True + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + @torch.no_grad() + def collect( + self, + global_embedding: torch.Tensor | None = None, + geometry_latent: torch.Tensor | None = None, + ) -> None: + """Accumulate calibration data (call during training). + + Parameters + ---------- + global_embedding : Tensor | None + Shape ``(B, ..., C_g)`` — at least one leading (batch) dim; + last dim is channel. + geometry_latent : Tensor | None + Shape ``(B, D)`` — pre-pooled per-sample geometry latent vector. + """ + self._validate_shapes(global_embedding, geometry_latent) + self._collect_global(global_embedding) + self._collect_geometry(geometry_latent) + # Any new geometry sample invalidates the kNN threshold. + if geometry_latent is not None and self.geo_embeddings is not None: + self._threshold_stale = True + + @torch.no_grad() + def check( + self, + global_embedding: torch.Tensor | None = None, + geometry_latent: torch.Tensor | None = None, + ) -> None: + """Run OOD checks and emit warnings (call during inference). + + Parameters + ---------- + global_embedding : Tensor | None + Shape ``(B, ..., C_g)`` — at least one leading (batch) dim; + last dim is channel. + geometry_latent : Tensor | None + Shape ``(B, D)`` — pre-pooled per-sample geometry latent vector. + """ + self._validate_shapes(global_embedding, geometry_latent) + self._check_global(global_embedding) + # Lazy threshold computation on first inference call + if self._threshold_stale: + self.compute_threshold() + self._threshold_stale = False + self._check_geometry(geometry_latent) + + @torch.compiler.disable + @torch.no_grad() + def compute_threshold(self) -> None: + """Compute the kNN threshold from the accumulated geometry buffer.""" + if self.geo_embeddings is None: + return + n_valid = self._n_valid() + if n_valid == 0: + return + store = self.geo_embeddings[:n_valid] + store_norm = store / (store.norm(dim=-1, keepdim=True) + 1e-8) + k = min(self.knn_k, n_valid - 1) + if k <= 0: + return + # Leave-one-out: ask for k+1 neighbours and drop column 0 (each + # point's nearest neighbour is itself, distance 0). 
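+        # Since both sides are L2-normalised, Euclidean distance reduces to
+        # ||a - b||^2 = 2 - 2 * cos(a, b), so all kNN distances fall in
+        # [0, 2] on the unit sphere and the threshold is scale-free.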
+ _, dists = knn(store_norm, store_norm, k + 1) + avg_knn_dists = dists[:, 1:].mean(dim=-1) + base = torch.quantile(avg_knn_dists, 0.99) + threshold = base * self.sensitivity + self.knn_threshold.copy_(threshold) + logger.info( + "OOD Guard: computed kNN threshold=%.4f (base_99pct=%.4f, sensitivity=%.2f, k=%d)", + threshold.item(), + base.item(), + self.sensitivity, + self.knn_k, + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _n_valid(self) -> int: + """Number of populated rows in the geometry FIFO buffer.""" + if self.geo_full.item(): + return self.buffer_size + return self.geo_ptr.item() + + def _validate_shapes( + self, + global_embedding: torch.Tensor | None, + geometry_latent: torch.Tensor | None, + ) -> None: + """Validate caller-supplied tensor shapes against the guard's config. + + Skipped under ``torch.compile`` to avoid graph breaks on shape checks. + """ + if torch.compiler.is_compiling(): + return + if global_embedding is not None and self.global_min is not None: + if global_embedding.ndim < 2: + raise ValueError( + f"global_embedding must have at least 2 dims " + f"(batch + channel); got {global_embedding.ndim}D tensor " + f"with shape {tuple(global_embedding.shape)}. Did you mean " + f"to unsqueeze a batch dim?" + ) + expected = self.global_min.shape[0] + got = global_embedding.shape[-1] + if got != expected: + raise ValueError( + f"global_embedding last-dim mismatch: expected {expected} " + f"(from global_dim), got {got} " + f"(shape {tuple(global_embedding.shape)})" + ) + if geometry_latent is not None and self.geo_embeddings is not None: + if geometry_latent.ndim != 2: + raise ValueError( + f"geometry_latent must be rank-2 (B, D); got " + f"{geometry_latent.ndim}D tensor with shape " + f"{tuple(geometry_latent.shape)}. Pool any higher-rank " + f"latent at the caller before passing it in." + ) + expected = self.geo_embeddings.shape[1] + got = geometry_latent.shape[1] + if got != expected: + raise ValueError( + f"geometry_latent channel dim mismatch: expected " + f"{expected} (from geometry_embed_dim), got {got}" + ) + + def _collect_global(self, global_embedding: torch.Tensor | None) -> None: + if global_embedding is None or self.global_min is None: + return + # Upcast to the buffer dtype so AMP (fp16/bf16) inputs don't mismatch + # the fp32 running min/max. + vals = global_embedding.detach().to(self.global_min.dtype) + batch_min = _reduce_leading(vals, torch.amin) + batch_max = _reduce_leading(vals, torch.amax) + self.global_min.copy_(torch.minimum(self.global_min, batch_min)) + self.global_max.copy_(torch.maximum(self.global_max, batch_max)) + + def _collect_geometry(self, geometry_latent: torch.Tensor | None) -> None: + if geometry_latent is None or self.geo_embeddings is None: + return + # Upcast to the buffer dtype so AMP (fp16/bf16) inputs don't fail the + # dtype-strict indexed assignment into geo_embeddings. 
+ pooled = geometry_latent.detach().to(self.geo_embeddings.dtype) # (B, D) + B = pooled.shape[0] + ptr = self.geo_ptr[0] + indices = (ptr + torch.arange(B, device=pooled.device)) % self.buffer_size + self.geo_embeddings[indices] = pooled + wrapped = ((ptr + B) >= self.buffer_size).view(1) + self.geo_full.logical_or_(wrapped) + self.geo_ptr.fill_((ptr + B) % self.buffer_size) + + @torch.compiler.disable + def _check_global(self, global_embedding: torch.Tensor | None) -> None: + if global_embedding is None or self.global_min is None: + return + if torch.isinf(self.global_min).any(): + return + # Upcast so AMP inputs compare against the fp32 bounds cleanly. + vals = global_embedding.detach().to(self.global_min.dtype) + batch_min = _reduce_leading(vals, torch.amin) + batch_max = _reduce_leading(vals, torch.amax) + below = batch_min < self.global_min + above = batch_max > self.global_max + # Skip host transfer when nothing is violated and DEBUG is off. + if not (bool((below | above).any()) or logger.isEnabledFor(logging.DEBUG)): + return + # Single bulk transfer; then iterate in Python over dims. + bmin_l = batch_min.tolist() + bmax_l = batch_max.tolist() + lo_l = self.global_min.tolist() + hi_l = self.global_max.tolist() + below_l = below.tolist() + above_l = above.tolist() + for d, (bmin, bmax, lo, hi) in enumerate( + zip(bmin_l, bmax_l, lo_l, hi_l) + ): + logger.debug( + "OOD Guard [global] dim %d: val=[%.4f, %.4f] bounds=[%.4f, %.4f]", + d, bmin, bmax, lo, hi, + ) + if below_l[d]: + logger.warning( + f"{_RED}OOD Guard: global_embedding dim {d} value " + f"{bmin:.4f} below training min {lo:.4f}{_RESET}" + ) + if above_l[d]: + logger.warning( + f"{_RED}OOD Guard: global_embedding dim {d} value " + f"{bmax:.4f} above training max {hi:.4f}{_RESET}" + ) + + @torch.compiler.disable + def _check_geometry(self, geometry_latent: torch.Tensor | None) -> None: + if geometry_latent is None or self.geo_embeddings is None: + return + if torch.isinf(self.knn_threshold): + return + # Upcast so ``cdist`` against the fp32 store works under AMP inputs. + pooled = geometry_latent.detach().to(self.geo_embeddings.dtype) # (B, D) + z = pooled / (pooled.norm(dim=-1, keepdim=True) + 1e-8) + n_valid = self._n_valid() + if n_valid == 0: + return + store = self.geo_embeddings[:n_valid] + store_norm = store / (store.norm(dim=-1, keepdim=True) + 1e-8) + # Query is not in the store, so no -1 needed; clamp to buffer size. + k = min(self.knn_k, n_valid) + _, dists = knn(store_norm, z, k) # (B, k) + avg_knn_dists = dists.mean(dim=-1) # (B,) + over = avg_knn_dists > self.knn_threshold + # Skip host transfer when nothing is violated and DEBUG is off. + if not (bool(over.any()) or logger.isEnabledFor(logging.DEBUG)): + return + # Single bulk transfer; then iterate in Python over batch. 
+ dist_l = avg_knn_dists.tolist() + over_l = over.tolist() + threshold = self.knn_threshold.item() + for i, dist_val in enumerate(dist_l): + logger.debug( + "OOD Guard [geometry] sample %d: kNN_dist=%.4f threshold=%.4f", + i, dist_val, threshold, + ) + if over_l[i]: + logger.warning( + f"{_RED}OOD Guard: geometry sample {i} kNN distance " + f"{dist_val:.4f} above threshold {threshold:.4f}{_RESET}" + ) diff --git a/physicsnemo/experimental/guardrails/geometry/README.md b/physicsnemo/experimental/guardrails/geometry/README.md index 040a38ee5b..0af4d2e345 100644 --- a/physicsnemo/experimental/guardrails/geometry/README.md +++ b/physicsnemo/experimental/guardrails/geometry/README.md @@ -49,7 +49,7 @@ to capture non-Gaussian patterns in your geometry distribution. ```python import pyvista as pv from physicsnemo.mesh.io import from_pyvista -from physicsnemo.experimental.guardrails import GeometryGuardrail +from physicsnemo.experimental.guardrails.geometry import GeometryGuardrail # Load or create training meshes train_meshes = [ @@ -83,7 +83,7 @@ automatic parallel processing: ```python from pathlib import Path -from physicsnemo.experimental.guardrails import GeometryGuardrail +from physicsnemo.experimental.guardrails.geometry import GeometryGuardrail # Fit from directory of STL files guardrail = GeometryGuardrail( @@ -126,7 +126,7 @@ guardrail.fit_from_dir( ```python from pathlib import Path -from physicsnemo.experimental.guardrails import GeometryGuardrail +from physicsnemo.experimental.guardrails.geometry import GeometryGuardrail # Save fitted guardrail guardrail.save(Path("guardrail.npz")) @@ -141,7 +141,7 @@ results = loaded_guardrail.query(test_meshes) Both GMM and PCE methods support GPU acceleration via PyTorch: ```python -from physicsnemo.experimental.guardrails import GeometryGuardrail +from physicsnemo.experimental.guardrails.geometry import GeometryGuardrail # Create guardrail with GPU support (requires PyTorch and CUDA) guardrail_gpu = GeometryGuardrail( diff --git a/physicsnemo/experimental/guardrails/geometry/ood_detector.py b/physicsnemo/experimental/guardrails/geometry/ood_detector.py index 34d1fd468e..38a098e26f 100644 --- a/physicsnemo/experimental/guardrails/geometry/ood_detector.py +++ b/physicsnemo/experimental/guardrails/geometry/ood_detector.py @@ -113,7 +113,7 @@ class GeometryGuardrail: >>> import pyvista as pv >>> from pathlib import Path >>> from physicsnemo.mesh.io import from_pyvista - >>> from physicsnemo.experimental.guardrails import GeometryGuardrail + >>> from physicsnemo.experimental.guardrails.geometry import GeometryGuardrail >>> >>> # Create and fit guardrail from training meshes (CPU) >>> train_meshes = [from_pyvista(pv.Cube()) for _ in range(100)] diff --git a/physicsnemo/experimental/models/geotransolver/context_projector.py b/physicsnemo/experimental/models/geotransolver/context_projector.py index 938f11c00e..a57b7ecfc1 100644 --- a/physicsnemo/experimental/models/geotransolver/context_projector.py +++ b/physicsnemo/experimental/models/geotransolver/context_projector.py @@ -643,7 +643,8 @@ class GlobalContextBuilder(nn.Module): Forward ------- This class does not implement a standard ``forward`` method. Instead, use - :meth:`build_context` to construct context and local features. + :meth:`build_context` to construct context, local features, and the + detached geometry context. 
See Also -------- @@ -664,7 +665,7 @@ class GlobalContextBuilder(nn.Module): >>> local_embeddings = (torch.randn(2, 100, 64),) >>> geometry = torch.randn(2, 100, 3) >>> global_embedding = torch.randn(2, 1, 16) - >>> context, local_feats = builder.build_context( + >>> context, local_feats, geo_ctx = builder.build_context( ... local_embeddings, None, geometry, global_embedding ... ) >>> context.shape @@ -779,6 +780,7 @@ def build_context( ) -> tuple[ Float[torch.Tensor, "batch heads slices context_dim"] | None, list[Float[torch.Tensor, "batch tokens local_features"]] | None, + Float[torch.Tensor, "batch heads slices dim_head"] | None, ]: r"""Build all context and local features. @@ -798,13 +800,17 @@ def build_context( Returns ------- - tuple[torch.Tensor | None, list[torch.Tensor] | None] + tuple[torch.Tensor | None, list[torch.Tensor] | None, torch.Tensor | None] - ``context``: Concatenated context tensor of shape :math:`(B, H, S, D_c)` where :math:`D_c` is the total context dimension, or ``None`` if no context sources are provided. - ``local_features``: List of local feature tensors, one per input type, each of shape :math:`(B, N, D_l)`, or ``None`` if local features are disabled. + - ``geometry_context_detached``: Detached geometry-tokenizer output of shape + :math:`(B, H, S, D)`, intended for downstream observers such as the + embedded OOD guard. ``None`` when geometry tokenization is disabled + or no geometry was provided. Raises ------ @@ -824,6 +830,7 @@ def build_context( context_parts = [] local_features = None + geometry_context_detached: torch.Tensor | None = None if local_positions is None and self.local_extractors is not None: raise ValueError( @@ -850,7 +857,11 @@ def build_context( # Tokenize geometry features if self.geometry_tokenizer is not None and geometry is not None: - context_parts.append(self.geometry_tokenizer(geometry)) + geometry_context = self.geometry_tokenizer(geometry) + # Detach the returned copy so downstream observers (e.g. the OOD + # guard) don't keep the backward graph alive. + geometry_context_detached = geometry_context.detach() + context_parts.append(geometry_context) # Tokenize global embedding if self.global_tokenizer is not None and global_embedding is not None: @@ -859,4 +870,4 @@ def build_context( # Concatenate all context features along the last dimension context = torch.cat(context_parts, dim=-1) if context_parts else None - return context, local_features + return context, local_features, geometry_context_detached diff --git a/physicsnemo/experimental/models/geotransolver/geotransolver.py b/physicsnemo/experimental/models/geotransolver/geotransolver.py index 6460960d08..1da57c8bb3 100644 --- a/physicsnemo/experimental/models/geotransolver/geotransolver.py +++ b/physicsnemo/experimental/models/geotransolver/geotransolver.py @@ -34,6 +34,7 @@ from physicsnemo.core.meta import ModelMetaData from physicsnemo.core.module import Module from physicsnemo.core.version_check import check_version_spec +from physicsnemo.experimental.guardrails.embedded import OODGuard, OODGuardConfig from physicsnemo.models.transolver.transolver import _TransolverMlp from .context_projector import GlobalContextBuilder @@ -204,6 +205,18 @@ class GeoTransolver(Module): Neighbors in radius for the local features. Default is ``[8, 32]``. n_hidden_local : int, optional Hidden dimension for the local features. Default is 32. + guard_config : dict | None, optional + Configuration for the embedded OOD guard + (:class:`~physicsnemo.experimental.guardrails.embedded.OODGuard`). 
+ Pass a plain ``dict`` whose keys match the fields of + :class:`~physicsnemo.experimental.guardrails.embedded.OODGuardConfig` + (``buffer_size`` required; ``knn_k`` and ``sensitivity`` optional), or + ``None`` to disable the guard entirely. A ``dict`` is required (rather + than the dataclass directly) so the model kwargs remain + JSON-serialisable for ``.mdlus`` checkpointing. When set, the guard + accumulates global-parameter bounds and pooled geometry latents during + training, and emits warnings on out-of-distribution inputs during + inference. Default is ``None``. attention_type : str, optional attention_type is used to choose the attention type (GALE or GALE_FA). Default is ``"GALE"``. @@ -335,6 +348,7 @@ def __init__( radii: list[float] | None = None, neighbors_in_radius: list[int] | None = None, n_hidden_local: int = 32, + guard_config: dict | None = None, attention_type: str = "GALE", concrete_dropout: bool = False, ) -> None: @@ -462,6 +476,35 @@ def __init__( nn.Linear(n_hidden, n_hidden), ) + # OOD guard (None when disabled). + if guard_config is None: + self.ood_guard = None + else: + if not isinstance(guard_config, dict): + raise TypeError( + f"guard_config must be a dict or None; got " + f"{type(guard_config).__name__}. If using Hydra, set " + f"_convert_=partial or _convert_=all on the model config " + f"so nested mappings are passed as native dicts." + ) + if global_dim is None and geometry_dim is None: + raise ValueError( + "guard_config is set, but neither global_dim nor " + "geometry_dim is configured; the OOD guard would have " + "nothing to watch. Either set guard_config=None or " + "enable at least one of the two surfaces." + ) + # OODGuardConfig validates keys and applies defaults. + cfg = OODGuardConfig(**guard_config) + dim_head = n_hidden // n_head + self.ood_guard = OODGuard( + buffer_size=cfg.buffer_size, + global_dim=global_dim, + geometry_embed_dim=dim_head if geometry_dim is not None else None, + knn_k=cfg.knn_k, + sensitivity=cfg.sensitivity, + ) + def forward( self, local_embedding: ( @@ -562,10 +605,23 @@ def forward( ) # Build context embeddings and extract local features - embedding_states, local_embedding_bq = self.context_builder.build_context( - local_embedding, local_positions, geometry, global_embedding + embedding_states, local_embedding_bq, geo_ctx = ( + self.context_builder.build_context( + local_embedding, local_positions, geometry, global_embedding + ) ) + # --- OOD Guard --- + if self.ood_guard is not None: + # Pool (B, H, S, D) -> (B, D); guard expects pre-pooled latents. + geo_latent = ( + geo_ctx.mean(dim=(1, 2)) if geo_ctx is not None else None + ) + if self.training: + self.ood_guard.collect(global_embedding, geo_latent) + else: + self.ood_guard.check(global_embedding, geo_latent) + # Project inputs to hidden dimension: (B, N, C) -> (B, N, n_hidden) x = [self.preprocess[i](le) for i, le in enumerate(local_embedding)] @@ -591,4 +647,4 @@ def forward( if return_embedding_states: return x, embedding_states - return x \ No newline at end of file + return x diff --git a/test/experimental/guardrails/embedded/__init__.py b/test/experimental/guardrails/embedded/__init__.py new file mode 100644 index 0000000000..af85283aa4 --- /dev/null +++ b/test/experimental/guardrails/embedded/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/test/experimental/guardrails/embedded/test_ood_guard.py b/test/experimental/guardrails/embedded/test_ood_guard.py new file mode 100644 index 0000000000..e5ec4f0462 --- /dev/null +++ b/test/experimental/guardrails/embedded/test_ood_guard.py @@ -0,0 +1,347 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the embedded OOD guardrail (``OODGuard``) and its config.""" + +import logging + +import pytest +import torch + +from physicsnemo.experimental.guardrails.embedded import OODGuard, OODGuardConfig + +_GUARD_LOGGER = "physicsnemo.experimental.guardrails.embedded.ood_guard" + +_DEVICES = [ + pytest.param("cpu", id="cpu"), + pytest.param( + "cuda", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + id="cuda", + ), +] + + +def _populate( + guard: OODGuard, + n_samples: int, + device: str, + *, + global_dim: int | None = None, + geo_dim: int | None = None, + batch_size: int = 4, + seed: int = 0, +) -> None: + """Feed ``n_samples`` in-distribution samples into ``guard`` via ``collect``.""" + gen = torch.Generator(device=device).manual_seed(seed) + remaining = n_samples + while remaining > 0: + b = min(batch_size, remaining) + g = ( + torch.randn(b, global_dim, device=device, generator=gen) + if global_dim is not None + else None + ) + z = ( + torch.randn(b, geo_dim, device=device, generator=gen) + if geo_dim is not None + else None + ) + guard.collect(global_embedding=g, geometry_latent=z) + remaining -= b + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +def test_config_requires_buffer_size_and_applies_defaults(): + """``OODGuardConfig`` requires ``buffer_size``; other fields have defaults.""" + with pytest.raises(TypeError): + OODGuardConfig() # buffer_size is required + + cfg = OODGuardConfig(buffer_size=32) + assert cfg.buffer_size == 32 + assert cfg.knn_k == 10 + assert cfg.sensitivity == 1.5 + + with pytest.raises(TypeError): + OODGuardConfig(buffer_size=32, unknown_field=1) + + +# --------------------------------------------------------------------------- +# Construction: surfaces can be independently enabled/disabled +# 
--------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "global_dim,geo_dim", + [ + (3, 8), # both enabled + (3, None), # global only + (None, 8), # geometry only + (None, None), # fully disabled (still valid) + ], +) +def test_construction_buffers(global_dim, geo_dim): + """Each surface's buffers are allocated iff its dim is set.""" + guard = OODGuard(buffer_size=16, global_dim=global_dim, geometry_embed_dim=geo_dim) + + if global_dim is not None: + assert guard.global_min.shape == (global_dim,) + assert guard.global_max.shape == (global_dim,) + assert torch.isinf(guard.global_min).all() + else: + assert guard.global_min is None + assert guard.global_max is None + + if geo_dim is not None: + assert guard.geo_embeddings.shape == (16, geo_dim) + assert guard.geo_ptr.item() == 0 + assert not guard.geo_full.item() + assert torch.isinf(guard.knn_threshold) + else: + assert guard.geo_embeddings is None + assert guard.geo_ptr is None + assert guard.geo_full is None + assert guard.knn_threshold is None + + +# --------------------------------------------------------------------------- +# Global-parameter bounding box +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("device", _DEVICES) +def test_global_bounds_collect_and_detect(device, caplog): + """Collect shrinks bounds; out-of-box batch at check-time warns.""" + guard = OODGuard(buffer_size=8, global_dim=2).to(device) + + # In-distribution: all values in [-1, 1]. + for _ in range(5): + vals = torch.rand(4, 2, device=device) * 2 - 1 # uniform in [-1, 1) + guard.collect(global_embedding=vals) + + assert (guard.global_min >= -1.0).all() + assert (guard.global_max <= 1.0).all() + + # In-distribution check: no warnings emitted. + with caplog.at_level(logging.WARNING, logger=_GUARD_LOGGER): + guard.check(global_embedding=torch.zeros(2, 2, device=device)) + assert not any( + "OOD Guard: global_embedding" in r.getMessage() for r in caplog.records + ) + + # OOD check: value well above the collected max on dim 1. + caplog.clear() + ood = torch.tensor([[0.0, 10.0]], device=device) + with caplog.at_level(logging.WARNING, logger=_GUARD_LOGGER): + guard.check(global_embedding=ood) + msgs = [r.getMessage() for r in caplog.records] + assert any("dim 1" in m and "above training max" in m for m in msgs), msgs + + +# --------------------------------------------------------------------------- +# Geometry FIFO + threshold +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("device", _DEVICES) +def test_geometry_fifo_wraps_and_latches_full(device): + """Pointer wraps modulo buffer_size; ``geo_full`` latches on first wrap.""" + buffer_size = 5 + guard = OODGuard(buffer_size=buffer_size, geometry_embed_dim=4).to(device) + + # 3 samples: ptr advances, not yet full. + _populate(guard, 3, device=device, geo_dim=4, batch_size=1, seed=1) + assert guard.geo_ptr.item() == 3 + assert not guard.geo_full.item() + + # 4 more samples (total 7) → wrap past capacity. + _populate(guard, 4, device=device, geo_dim=4, batch_size=1, seed=2) + assert guard.geo_full.item() + assert guard.geo_ptr.item() == 7 % buffer_size # == 2 + assert 0 <= guard.geo_ptr.item() < buffer_size + + +@pytest.mark.parametrize("device", _DEVICES) +def test_geometry_threshold_computes_and_detects_ood(device, caplog): + """Threshold becomes finite after calibration; far-OOD queries warn. 
+ + The guard L2-normalises both buffer and query, so distances are bounded + in ``[0, 2]`` on the unit sphere. We seed the buffer with a cluster + biased toward ``+e_0`` so coverage is local; a query pointing along + ``-e_0`` is then ~2 away — reliably over any reasonable threshold. + """ + guard = OODGuard(buffer_size=32, geometry_embed_dim=8, knn_k=4, sensitivity=1.5).to( + device + ) + + # Clustered in-distribution buffer: a tight Gaussian offset along +e_0. + gen = torch.Generator(device=device).manual_seed(11) + shift = torch.zeros(8, device=device) + shift[0] = 3.0 + in_dist = torch.randn(32, 8, device=device, generator=gen) * 0.3 + shift + for i in range(0, 32, 4): + guard.collect(geometry_latent=in_dist[i : i + 4]) + + # First check triggers lazy threshold calibration. + guard.check(geometry_latent=in_dist[:1].clone()) + assert torch.isfinite(guard.knn_threshold) + + # Far-OOD query: unit vector pointing along -e_0, antipodal to the cluster. + z_ood = torch.zeros(1, 8, device=device) + z_ood[0, 0] = -1.0 + with caplog.at_level(logging.WARNING, logger=_GUARD_LOGGER): + guard.check(geometry_latent=z_ood) + assert any( + "OOD Guard: geometry sample" in r.getMessage() + and "above threshold" in r.getMessage() + for r in caplog.records + ) + + +# --------------------------------------------------------------------------- +# Shape validation +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "global_embedding,geometry_latent,match", + [ + # Rank-1 global_embedding: batch+channel confusion hazard. + (torch.zeros(3), None, "at least 2 dims"), + # Channel-dim mismatch on global_embedding. + (torch.zeros(2, 5), None, "last-dim mismatch"), + # Rank-3 geometry_latent (pooling missed at caller). + (None, torch.zeros(2, 3, 8), "must be rank-2"), + # Channel mismatch on geometry_latent. + (None, torch.zeros(2, 5), "channel dim mismatch"), + ], +) +def test_shape_validation(global_embedding, geometry_latent, match): + """Bad-shape inputs raise ``ValueError`` with actionable messages.""" + guard = OODGuard(buffer_size=8, global_dim=3, geometry_embed_dim=8) + with pytest.raises(ValueError, match=match): + guard.collect( + global_embedding=global_embedding, geometry_latent=geometry_latent + ) + with pytest.raises(ValueError, match=match): + guard.check(global_embedding=global_embedding, geometry_latent=geometry_latent) + + +# --------------------------------------------------------------------------- +# AMP dtype robustness +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "dtype", + [pytest.param(torch.float16, id="fp16"), pytest.param(torch.bfloat16, id="bf16")], +) +def test_amp_inputs_upcast_into_fp32_buffers(dtype): + """AMP (fp16/bf16) inputs are accepted; buffers stay fp32.""" + guard = OODGuard(buffer_size=4, global_dim=2, geometry_embed_dim=4) + g = torch.ones(2, 2, dtype=dtype) + z = torch.ones(2, 4, dtype=dtype) + + guard.collect(global_embedding=g, geometry_latent=z) + + assert guard.global_min.dtype == torch.float32 + assert guard.geo_embeddings.dtype == torch.float32 + # Values propagated despite dtype mismatch at the input boundary. 
+ assert torch.allclose(guard.global_min, torch.ones(2)) + assert torch.allclose(guard.geo_embeddings[:2], torch.ones(2, 4)) + + +# --------------------------------------------------------------------------- +# Threshold staleness across train → eval → train → eval +# --------------------------------------------------------------------------- + + +def test_threshold_restales_after_collect_then_check(): + """Running ``collect`` after ``check`` re-marks the threshold stale.""" + guard = OODGuard(buffer_size=16, geometry_embed_dim=4, knn_k=3) + _populate(guard, n_samples=16, device="cpu", geo_dim=4, seed=7) + + # First check computes the threshold. + guard.check(geometry_latent=torch.randn(1, 4)) + t0 = guard.knn_threshold.clone() + assert torch.isfinite(t0) + assert guard._threshold_stale is False + + # More collection invalidates the stale flag. + _populate(guard, n_samples=8, device="cpu", geo_dim=4, seed=8) + assert guard._threshold_stale is True + + # Next check recomputes. + guard.check(geometry_latent=torch.randn(1, 4)) + assert guard._threshold_stale is False + # Buffer differs from iteration 1, so threshold should change in general. + # (Identity not guaranteed; compare for finiteness and plausibility.) + assert torch.isfinite(guard.knn_threshold) + + +# --------------------------------------------------------------------------- +# Sensitivity multiplier +# --------------------------------------------------------------------------- + + +def test_sensitivity_scales_threshold_linearly(): + """Doubling ``sensitivity`` doubles the computed kNN threshold.""" + # Deterministic collection so both guards see the same buffer. + gen = torch.Generator(device="cpu").manual_seed(42) + samples = torch.randn(32, 8, generator=gen) + + def _calibrate(sensitivity: float) -> torch.Tensor: + g = OODGuard( + buffer_size=32, geometry_embed_dim=8, knn_k=4, sensitivity=sensitivity + ) + g.collect(geometry_latent=samples.clone()) + g.compute_threshold() + return g.knn_threshold.clone() + + t1 = _calibrate(1.0) + t2 = _calibrate(2.0) + assert torch.isfinite(t1) and torch.isfinite(t2) + assert torch.allclose(t2, t1 * 2.0, rtol=1e-5) + + +# --------------------------------------------------------------------------- +# Checkpoint round-trip via state_dict +# --------------------------------------------------------------------------- + + +def test_state_dict_roundtrip_preserves_calibration(): + """``state_dict`` captures all guard state; reload reproduces threshold.""" + src = OODGuard(buffer_size=16, global_dim=2, geometry_embed_dim=4, knn_k=3) + _populate(src, n_samples=16, device="cpu", global_dim=2, geo_dim=4, seed=5) + src.compute_threshold() + threshold_src = src.knn_threshold.clone() + + dst = OODGuard(buffer_size=16, global_dim=2, geometry_embed_dim=4, knn_k=3) + dst.load_state_dict(src.state_dict()) + + # Buffers transferred verbatim; threshold identical. 
+ assert torch.equal(dst.geo_embeddings, src.geo_embeddings) + assert dst.geo_full.item() == src.geo_full.item() + assert dst.geo_ptr.item() == src.geo_ptr.item() + assert torch.equal(dst.global_min, src.global_min) + assert torch.equal(dst.global_max, src.global_max) + assert torch.allclose(dst.knn_threshold, threshold_src) diff --git a/test/experimental/guardrails/geometry/test_ood_detector.py b/test/experimental/guardrails/geometry/test_ood_detector.py index cdd66b5c09..6ac8956c41 100644 --- a/test/experimental/guardrails/geometry/test_ood_detector.py +++ b/test/experimental/guardrails/geometry/test_ood_detector.py @@ -28,8 +28,11 @@ import pyvista as pv -from physicsnemo.experimental.guardrails import GeometryGuardrail -from physicsnemo.experimental.guardrails.geometry import FEATURE_NAMES, FEATURE_VERSION +from physicsnemo.experimental.guardrails.geometry import ( + FEATURE_NAMES, + FEATURE_VERSION, + GeometryGuardrail, +) from physicsnemo.mesh.io.io_pyvista import from_pyvista diff --git a/test/models/geotransolver/test_geotransolver.py b/test/models/geotransolver/test_geotransolver.py index 124827a208..6b897cc191 100644 --- a/test/models/geotransolver/test_geotransolver.py +++ b/test/models/geotransolver/test_geotransolver.py @@ -748,3 +748,107 @@ def test_geotransolver_metadata(): assert model.meta.name == "GeoTransolver" assert model.meta.amp is True assert model.__name__ == "GeoTransolver" + + +# ============================================================================= +# Embedded OOD guard (guard_config) integration +# ============================================================================= + + +def _make_guarded_model(device, guard_config): + """Minimal guard-enabled GeoTransolver used by the tests below.""" + return GeoTransolver( + functional_dim=32, + out_dim=4, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + guard_config=guard_config, + ).to(device) + + +def test_geotransolver_guard_config_none_leaves_guard_unattached(device): + """``guard_config=None`` (the default) produces no OOD guard.""" + model = _make_guarded_model(device, guard_config=None) + assert model.ood_guard is None + + +def test_geotransolver_guard_config_dict_attaches_and_runs(device): + """Dict ``guard_config`` attaches an ``OODGuard`` wired through the forward pass.""" + torch.manual_seed(42) + + model = _make_guarded_model( + device, + guard_config={"buffer_size": 8, "knn_k": 3, "sensitivity": 1.5}, + ) + assert model.ood_guard is not None + + batch_size = 2 + local_emb = torch.randn(batch_size, 50, 32, device=device) + local_positions = local_emb[:, :, :3] + geometry = torch.randn(batch_size, 80, 3, device=device) + global_emb = torch.randn(batch_size, 1, 16, device=device) + + # Training forward should populate the guard's buffers. + model.train() + _ = model( + local_emb, + local_positions=local_positions, + global_embedding=global_emb, + geometry=geometry, + ) + assert model.ood_guard.geo_ptr.item() == batch_size + assert not torch.isinf(model.ood_guard.global_min).any() + + # Eval forward should run the checks (threshold may remain inf until the + # buffer has enough samples, which is acceptable — we just verify no crash). 
+ model.eval() + _ = model( + local_emb, + local_positions=local_positions, + global_embedding=global_emb, + geometry=geometry, + ) + + +@pytest.mark.parametrize( + "bad_config,expected_exc,match", + [ + # Unknown field: OODGuardConfig rejects at construction. + ({"buffer_size": 8, "nope": 1}, TypeError, "unexpected keyword argument"), + # Missing required field. + ({}, TypeError, "buffer_size"), + # Non-dict type. + (42, TypeError, "guard_config must be a dict"), + ], +) +def test_geotransolver_guard_config_invalid_inputs(bad_config, expected_exc, match): + """Invalid ``guard_config`` values raise at construction with clear messages.""" + with pytest.raises(expected_exc, match=match): + _make_guarded_model("cpu", guard_config=bad_config) + + +def test_geotransolver_guard_config_without_any_surface_raises(): + """Enabling the guard without either ``global_dim`` or ``geometry_dim`` raises.""" + with pytest.raises(ValueError, match="nothing to watch"): + GeoTransolver( + functional_dim=32, + out_dim=4, + geometry_dim=None, + global_dim=None, + n_layers=2, + n_hidden=64, + n_head=4, + use_te=False, + guard_config={"buffer_size": 8}, + )
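Taken together, the train → eval lifecycle exercised by these tests reduces to a small standalone loop. A minimal sketch using only the `OODGuard` API added in this PR (shapes, sample counts, and the logging setup are illustrative):

```python
import logging

import torch

from physicsnemo.experimental.guardrails.embedded import OODGuard

logging.basicConfig(level=logging.WARNING)  # guard warnings go to the Python logger

guard = OODGuard(
    buffer_size=128,        # ideally >= number of training samples
    global_dim=3,           # per-channel min/max bounds on global parameters
    geometry_embed_dim=16,  # dimension of the pre-pooled geometry latent
    knn_k=10,
    sensitivity=1.5,
)

# Training: accumulate calibration statistics (collect runs under no_grad).
for _ in range(32):
    guard.collect(
        global_embedding=torch.randn(4, 3),
        geometry_latent=torch.randn(4, 16),
    )

# Inference: the first check lazily computes the kNN threshold; samples
# outside the calibrated envelope emit "OOD Guard: ..." warnings.
guard.check(
    global_embedding=torch.tensor([[0.0, 0.0, 100.0]]),  # far above training max
    geometry_latent=torch.randn(1, 16),
)
```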