Conversation
Greptile SummaryThis PR adds two lightweight OOD guards to GeoTransolver — a per-dimension bounding box on global parameters and a kNN distance check on pooled geometry latents — that passively calibrate during training and warn at inference without changing model outputs or gradients. The
Important Files Changed
Reviews (2): Last reviewed commit: "Merge branch 'GeoT_Guard' of https://git..." | Re-trigger Greptile |
|
Note: Tests will be added after the initial review is done for the overall design and implementation of this guardrail. |
|
/blossom-ci |
|
/blossom-ci |
coreyjadams
left a comment
There was a problem hiding this comment.
Approving to unblock, overall it looks good, I do think we should use a non-brute force kNN.
I've got a unified training recipe for geotransolver that makes some small tweaks to geotransolver too, so whoever merges second will have minor conflicts to deal with but it will not be difficult I expect.
|
/ok to test 9c59ee4 |
|
/blossom-ci |
|
/ok to test a47961b |
|
/blossom-ci |
PhysicsNeMo Pull Request
Description
Problem
When users apply a pretrained GeoTransolver checkpoint to inputs outside the training distribution (e.g., running a DrivAerML-trained model on motorcycles or aircraft), the model silently produces unreliable predictions. There is no mechanism to detect or warn about out-of-distribution (OOD) inputs at inference time.
Solution
We add two lightweight OOD guards that integrate seamlessly into the existing training and inference workflow. Both guards are controlled by a single knob (
guard_buffer_size) and require no additional scripts, calibration steps, or changes to the training loop.warnings.warn()if OOD is detected.Guard 1: Global Parameters — Bounding Box
What it monitors: The raw
global_embeddinginput tensor (e.g., air density, stream velocity for DrivAerML).How it works:
Why input space: Global parameters are low-dimensional scalars (2-3 dims). A bounding box is simpler, more interpretable, and more reliable than latent-space methods at this dimensionality.
Example warning:
Guard 2: Geometry Context — kNN Distance
The implementation is based on this paper: https://arxiv.org/pdf/2204.06507
What it monitors: The geometry context vector produced by
GlobalContextBuilder.geometry_tokenizer-- a learned 32-dimensional representation of the input geometry, mean-pooled over attention heads and slices.How it works:
Why latent space: Geometry is a variable-size point cloud -- there is no fixed-dimensional input representation to bound. The post-ContextProjector embedding compresses geometry into a fixed 32-dim vector suitable for distance-based methods.
Why kNN:
Why not monitor multi-scale local features:
Example warning:
Usage
Enabling guards
Add to model config (or pass to constructor):
Both guards are enabled when
guard_buffer_sizeis set, and disabled when it isNone.Training
No changes to the training script. Guards collect data automatically during
model.train()forward passes. The kNN threshold is computed automatically when the checkpoint is saved (viastate_dict()override).Inference
No changes to the inference script. Guards run automatically during
model.eval()forward passes and emit Python warnings for OOD inputs.Configuration
guard_buffer_sizeNone(disabled)guard_knn_k10Threshold
The kNN threshold is set at the 99th percentile of training-set leave-one-out kNN distances. This means ~1% false alarm rate on in-distribution data -- near-zero warnings on validation/test sets that are in-distribution.
Multi-GPU / DDP
guard_buffer_size >= dataset_sizeto ensure good coverage per rank.Checkpoint Compatibility
guard_*keys in the state dict. Requiresstrict=Falsewhen loading.Tests
This is currently tested for the Crash recipe on the bumper beam dataset. No OOD warnings for inference on test samples. For new OOD samples which have either OOD global parameters or a scaled geometry by a small factor (1.05), the OOD warning is raised.
Implementation
Registered buffers
All guard state is stored as registered buffers (persistent, non-trainable):
guard_global_min(global_dim,)guard_global_max(global_dim,)guard_geo_embeddings(buffer_size, dim_head)guard_geo_ptr(1,)guard_knn_thresholdguard_knn_kKey methods
_guard_collect()_guard_check()compute_guard_threshold()state_dict()call (before save)Checklist
Dependencies
None