
Add xDeepONet family to experimental models #1576

Merged
peterdsharpe merged 16 commits into NVIDIA:main from wdyab:pr/xdeeponet
May 15, 2026

Conversation

@wdyab
Contributor

@wdyab wdyab commented Apr 17, 2026

Summary

Introduces physicsnemo.experimental.models.xdeeponet — a config-driven,
unified implementation of eight DeepONet-based operator-learning
architectures for both 2D and 3D spatial domains:

  • deeponet, u_deeponet, fourier_deeponet, conv_deeponet,
    hybrid_deeponet — single-branch variants
  • mionet, fourier_mionet — two-branch multi-input variants
  • tno — Temporal Neural Operator (branch2 = previous solution)

This is the first of several PRs restructuring the Neural Operator
Factory per discussion with code owners. Subsequent PRs will upstream
xFNO (experimental) and refactor the reservoir-simulation NOF example
to consume these library models.

Closes Issue NVIDIA/physicsnemo-roadmap#2504

Key features

  • Composable spatial branches (Fourier, UNet, Conv in any combination)
  • Three decoder types: mlp, conv, temporal_projection
  • Automatic spatial padding
  • Automatic trunk coordinate extraction (time or grid)
  • Optional adaptive pooling for resolution-agnostic training
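As a rough construction sketch (hedged: the argument names below appear elsewhere in this thread, but the exact signature and config keys are illustrative, not definitive):

from physicsnemo.experimental.models.xdeeponet import DeepONet

# Two-branch Fourier MIONet with an MLP decoder; branch2_config is
# required for the mionet / fourier_mionet / tno variants.
model = DeepONet(
    variant="fourier_mionet",
    decoder_type="mlp",  # or "conv" / "temporal_projection"
    branch1_config={"encoder": {"type": "conv", "activation_fn": "relu"}},
    branch2_config={"encoder": {"type": "conv", "activation_fn": "relu"}},
)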

Design decisions (per @coreyjadams guidance)

  • Placed under experimental/ per the convention for new models.
  • Custom UNet dropped.
  • Tests live at test/experimental/models/ for CI coverage.

Checklist

  • I am familiar with the Contributing Guidelines
  • New tests cover these changes (29 tests under
    test/experimental/models/test_xdeeponet.py)
  • The documentation is up to date (package README, docstrings)
  • The CHANGELOG.md is up to date
  • An issue (#1575) is linked to this pull request
  • I have followed the Models Implementation Coding Standards

Test plan

  • 29 unit tests pass locally (branch shapes, wrappers, variants,
    decoder types, temporal projection, target_times override,
    gradient flow, adaptive pooling)
  • ruff check, ruff format, interrogate, markdownlint,
    license pre-commit hooks pass
  • No modifications to existing code — pure addition under
    physicsnemo/experimental/ and test/experimental/
  • All commits signed off per DCO

Related

@greptile-apps
Contributor

greptile-apps Bot commented Apr 17, 2026

Greptile Summary

This PR adds the xdeeponet experimental package: a config-driven implementation of eight DeepONet-based architectures for 2D and 3D operator learning. All previously identified issues are now addressed: physicsnemo.Module inheritance, dual-branch construction guards, deterministic state_dict via output_window, the MLPBranch num_layers guard, jaxtyping annotations with shape validation, case-insensitive decoder_type, and the _SinActivation wrapper. Two minor P2 style items remain in deeponet.py: a shared activation instance in _build_conv_encoder, and an implicit b2_out assignment pattern that static-analysis tools will flag as possibly undefined.

Important Files Changed

  • physicsnemo/experimental/models/xdeeponet/deeponet.py: Core 2D/3D DeepONet
    architectures with comprehensive construction-time guards, case-insensitive
    config handling, and jaxtyping annotations; two minor style issues remain
    (a shared activation instance in _build_conv_encoder, and b2_out possibly
    undefined for static analysis).
  • physicsnemo/experimental/models/xdeeponet/branches.py: TrunkNet, MLPBranch,
    SpatialBranch, and SpatialBranch3D building blocks with correct lazy init,
    jaxtyping annotations, shape guards, and the num_layers >= 2 guard for
    MLPBranch.
  • physicsnemo/experimental/models/xdeeponet/wrappers.py: DeepONetWrapper and
    DeepONet3DWrapper inherit from physicsnemo.Module and perform correct
    spatial padding/cropping and trunk coordinate extraction; no issues found.
  • physicsnemo/experimental/models/xdeeponet/padding.py: Dimension-agnostic
    right-side padding helpers with correct replicate and constant modes; edge
    cases handled properly.
  • test/experimental/models/test_xdeeponet.py: 29 unit tests covering branch
    shapes, all 8 variants, both 2D and 3D wrappers, temporal projection,
    target_times override, gradient flow, and error cases.
  • physicsnemo/experimental/models/xdeeponet/__init__.py: Clean package init
    exporting all public symbols; no issues.
  • physicsnemo/experimental/models/xdeeponet/README.md: Comprehensive module
    README with variant table, quick-start examples, config schema, and
    references.
  • CHANGELOG.md: CHANGELOG entry for the xDeepONet addition, correctly placed
    under Added.

Reviews (11): Last reviewed commit: "xdeeponet: fix _build_conv_encoder for "..."

Comment thread physicsnemo/experimental/models/xdeeponet/wrappers.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/branches.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
wdyab added a commit to wdyab/physicsnemo that referenced this pull request Apr 17, 2026
Fix six issues flagged by the Greptile review:

- Make DeepONetWrapper / DeepONet3DWrapper inherit from
  physicsnemo.core.module.Module (MOD-001). Core DeepONet / DeepONet3D
  also pass proper MetaData dataclasses.
- Raise ValueError at __init__ when mionet / fourier_mionet / tno are
  constructed without branch2_config (prevents silent degradation to a
  single-branch model).
- Add optional output_window constructor parameter so the
  temporal_projection decoder registers temporal_head at __init__,
  producing a deterministic state_dict that round-trips cleanly.
  set_output_window is retained for backwards compatibility.
- Raise ValueError from MLPBranch when num_layers < 2.
- Convert public docstrings to r-prefixed raw strings with
  Parameters / Forward / Outputs sections and LaTeX shape notation
  per MOD-003.
- Add jaxtyping.Float annotations and torch.compiler.is_compiling()
  guarded shape validation to all public forward methods
  (MOD-005, MOD-006).

Signed-off-by: wdyab <wdyab@nvidia.com>
Made-with: Cursor
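
A minimal sketch of the two temporal_projection initialization paths described above (other constructor arguments assumed):

# Deterministic path: temporal_head is registered at __init__, so the
# state_dict round-trips cleanly.
model = DeepONet(decoder_type="temporal_projection", output_window=4)

# Legacy path, retained for backwards compatibility:
model = DeepONet(decoder_type="temporal_projection")
model.set_output_window(4)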
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
wdyab added a commit to wdyab/physicsnemo that referenced this pull request Apr 22, 2026
Fix the new P1 issue flagged in the second Greptile review and close
two secondary gaps the summary called out:

- DeepONet.forward / DeepONet3D.forward: raise RuntimeError when
  decoder_type='temporal_projection' is used but temporal_head is
  still None (i.e. the user neither passed output_window at
  construction nor called set_output_window before forward).
  Previously the silent ``if temporal_head is not None`` skip
  returned (B, H, W, width) instead of (B, H, W, K).
- Deduplicate the VALID_VARIANTS list: pulled to a module-level
  _VALID_VARIANTS tuple; both DeepONet and DeepONet3D still expose it
  as the VALID_VARIANTS class attribute for a stable public API.
- Extend the parametrized test lists to cover fourier_deeponet,
  hybrid_deeponet, and fourier_mionet, and add a dedicated
  TestFourierBranchPaths class with num_fourier_layers > 0 so the
  spectral-conv code path in SpatialBranch / SpatialBranch3D is
  actually exercised in CI.
- Add a TestTemporalProjectionGuard::test_forward_without_output_window_raises
  regression test for the new RuntimeError.

Signed-off-by: wdyab <wdyab@nvidia.com>
Made-with: Cursor
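
The new forward-time guard plausibly amounts to the following (a sketch; exact attribute names and message assumed):

if self.decoder_type == "temporal_projection" and self.temporal_head is None:
    raise RuntimeError(
        "decoder_type='temporal_projection' requires output_window at "
        "construction or a set_output_window() call before forward()"
    )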
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
wdyab added a commit to wdyab/physicsnemo that referenced this pull request Apr 22, 2026
Two new P1 issues flagged on 85076f6:

- Case-sensitive decoder_type check: __init__ lowered ``decoder_type``
  into ``self.decoder_type`` but then branched on the raw argument
  (``if decoder_type == "temporal_projection":``) and forwarded the
  raw value to ``_build_decoder``.  A user passing
  ``decoder_type="MLP"`` or ``"Temporal_Projection"`` ended up with
  ``Unknown decoder_type: MLP`` bubbling out of ``_build_decoder``.
  Both branches of the check now use ``self.decoder_type``; same fix
  in ``DeepONet3D.__init__``.
- MLP branch + decoder_type='temporal_projection' silently returned
  (B, T, width) instead of (B, K) because the MLP-branch path in
  ``forward`` never consulted ``self._temporal_projection``.  The
  incompatibility is static, so reject it at __init__ with a
  descriptive ``ValueError`` rather than at forward.  Same guard in
  ``DeepONet3D.__init__``.

Regression tests: ``TestDecoderTypeNormalization`` (mixed-case
``"MLP"`` / ``"Temporal_Projection"`` accepted) and
``TestMLPBranchTemporalProjectionGuard`` (2D and 3D both reject the
invalid combination).

Signed-off-by: wdyab <wdyab@nvidia.com>
Made-with: Cursor
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
wdyab added a commit to wdyab/physicsnemo that referenced this pull request Apr 22, 2026
Proactive audit on top of Greptile's round-4 findings.  All
plausible silent-degradation combinations at the config boundary
now fail loudly at __init__ instead of producing wrong shapes or
cryptic PyTorch errors at forward time.

Construction-time guards added to both DeepONet and DeepONet3D:

- Unknown decoder_type is rejected up front against a new
  module-level ``_VALID_DECODER_TYPES`` set (previously deferred
  to ``_build_decoder`` and only surfaced on the non-temporal
  branch).
- MLPBranch branch1 paired with decoder_type='conv' is rejected
  (would otherwise crash inside ``Conv2d`` with a generic
  "Expected 3D or 4D input" message).  Unified with the existing
  temporal_projection guard into a single check.
- MLPBranch branch1 paired with a non-MLPBranch branch2 is
  rejected (element-wise product assumed matching ranks;
  previously broadcast nonsensically or raised a cryptic dim
  mismatch at forward).

Regression tests:

- ``TestMLPBranchConvDecoderGuard`` -- 2D/3D
- ``TestMixedBranchTypeGuard`` -- 2D/3D
- ``TestInvalidDecoderTypeGuard`` -- 2D/3D

Full suite: 47 passed.

Signed-off-by: wdyab <wdyab@nvidia.com>
Made-with: Cursor
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
wdyab added a commit to wdyab/physicsnemo that referenced this pull request Apr 22, 2026
``_build_conv_encoder`` called ``get_activation`` directly without
the ``sin`` special-case handling used in every other activation
site in ``branches.py``.  Passing
``{"encoder": {"type": "conv", "activation_fn": "sin"}}`` therefore
raised ``KeyError: Activation function sin not found``.

``torch.sin`` is a bare callable and cannot be placed inside an
``nn.Sequential`` (which requires ``nn.Module`` instances), so the
fix introduces a small ``_SinActivation`` wrapper module alongside
``_build_conv_encoder``.  The helper is module-level and is called
from both ``DeepONet`` and ``DeepONet3D``; only one fix site exists
despite the function being invoked from both classes.

Regression test ``TestConvEncoderSinActivation`` constructs a
multi-layer conv encoder with ``activation_fn="sin"`` and runs a
forward pass to confirm neither the ``KeyError`` nor a
``nn.Sequential`` ``TypeError`` resurface.

Signed-off-by: wdyab <wdyab@nvidia.com>
Made-with: Cursor
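
The wrapper this commit introduces is likely just a thin module (a sketch; the real class body is assumed):

import torch
import torch.nn as nn

class _SinActivation(nn.Module):
    # nn.Sequential only accepts nn.Module instances, so wrap torch.sin.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(x)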

@wdyab
Contributor Author

wdyab commented Apr 22, 2026

All Greptile feedback addressed across five rounds. Main fixes: wrappers inherit from physicsnemo.Module, construction-time validation for all the config combinations that previously produced wrong shapes silently, a deterministic state_dict for the temporal-projection decoder, case-insensitive decoder_type handling, and a few smaller fixes. All tests are green and pre-commit is clean.
@coreyjadams @ram-cherukuri @peterdsharpe

Comment thread physicsnemo/experimental/models/xdeeponet/README.md Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/wrappers.py Outdated
Comment thread test/experimental/models/test_xdeeponet.py Outdated
wdyab added a commit to wdyab/physicsnemo that referenced this pull request Apr 24, 2026
- Rewrite test suite to the constructor + non-regression + checkpoint
  + gradient + compile pattern; relocate to
  ``test/experimental/models/xdeeponet/`` to match the ``flare/`` /
  ``fno/`` per-model layout.  Commit ``.pth`` goldens + regeneration
  script under ``data/``.
- Add ``Examples`` sections to every user-facing class.
- Add ``jaxtyping.Float`` annotations to top-level ``forward`` methods
  (MOD-006).
- Remove ``README.md``; no model under ``physicsnemo/`` ships one.
  Design rationale now lives in module-level and class-level docstrings.

Signed-off-by: wdyab <wdyab@nvidia.com>
Made-with: Cursor
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
wdyab added a commit to wdyab/physicsnemo that referenced this pull request May 4, 2026
Address @peterdsharpe review feedback on PR NVIDIA#1576: the ``variant`` and
``decoder_type`` parameters now carry ``typing.Literal`` annotations
rather than bare ``str``, so static type checkers and IDE
auto-completion can flag unknown values at the call site instead of
deferring to the runtime ``ValueError``.

The ``Literal`` aliases ``_VariantStr`` and ``_DecoderTypeStr`` are
defined once at module scope in ``deeponet.py`` and imported into
``wrappers.py``.  ``_VALID_VARIANTS`` and ``_VALID_DECODER_TYPES`` are
now derived from those aliases via ``typing.get_args`` so the static
and runtime views cannot drift.
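
A sketch of that single-source-of-truth pattern (alias members taken from the decoder types listed in the PR description):

from typing import Literal, get_args

_DecoderTypeStr = Literal["mlp", "conv", "temporal_projection"]
_VALID_DECODER_TYPES = frozenset(get_args(_DecoderTypeStr))

# Static checkers flag DeepONet(decoder_type="mpl") at the call site;
# the runtime ValueError guard still catches dynamic strings.
assert "mlp" in _VALID_DECODER_TYPES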

Behaviour is unchanged: Python does not enforce ``Literal`` at runtime,
so the existing ``.lower()`` normalization and ``ValueError`` guards
keep mixed-case inputs working (verified by the existing
``TestDecoderTypeNormalization`` round-trip).  No tests, fixtures, or
constructor signatures were affected; the full suite remains 13 passing
(plus one deselected ``torch.compile`` smoke test, run separately).

Signed-off-by: wdyab <wdyab@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@copy-pr-bot

copy-pr-bot Bot commented May 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@wdyab wdyab requested a review from peterdsharpe May 6, 2026 14:55
Collaborator

@peterdsharpe peterdsharpe left a comment


Hey Waleed, thanks for the patience here, and for splitting this PR up.

Had a chance to take a look at this PR today. Overall, I think this is in good shape and something we should merge - nice work.

There are a few overarching design changes I'd request before doing so, which will a) help with long-term maintainability, and b) expand the range of downstream applications that can take advantage of this.

And, there are a few small things that I've noted in comments. Now, for the bigger things:

Going Dimensionally-Generic

The main concern here is pervasive 2D/3D code duplication: DeepONet and DeepONet3D are structurally ~95% identical, as are SpatialBranch / SpatialBranch3D and DeepONetWrapper / DeepONet3DWrapper. These can be made dimensionally generic with essentially no performance hit - a big boost to maintainability that shrinks the PR size by ~50% (at least for the source; the tests will stay a similar size).

As an example, the existing FNO class in the PhysicsNeMo codebase already solves this with a dimension parameter:

class FNO(Module):
    def __init__(self, ..., dimension: int = 2):
        FNOModel = self._getFNOEncoder()  # dispatch on dimension

So, xDeepONet could use the same approach. A dimension-generic DeepONet would:

  1. Accept dimension: int (2 or 3) in the constructor
  2. Dispatch to the right SpectralConv, Conv, BatchNorm, AdaptiveAvgPool via a lookup table
  3. Use dimension-parameterized permute/unsqueeze/reshape helpers
  4. Share 100% of the __init__ and _build_branch logic
  5. Share ~90% of the forward logic, differing only in the tensor manipulation shapes

This dispatch happens at construction time, so there's zero performance hit (or interaction with torch.compile). The one thing to verify is that dynamic ndim checks in forward() don't create graph breaks. But these are already guarded by torch.compiler.is_compiling(), so they're skipped during compilation. The actual tensor operations (unsqueeze, permute, etc.) can be parameterized via stored tuples without affecting compilation.
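
A construction-time lookup along these lines would cover point 2 above (a sketch; the _DIM_LAYERS name matches the later refactor commit in this thread, and the SpectralConv entry is omitted here since it is not a torch builtin):

import torch.nn as nn

_DIM_LAYERS = {
    2: {"conv": nn.Conv2d, "norm": nn.BatchNorm2d, "pool": nn.AdaptiveAvgPool2d},
    3: {"conv": nn.Conv3d, "norm": nn.BatchNorm3d, "pool": nn.AdaptiveAvgPool3d},
}

class SpatialBranch(nn.Module):
    def __init__(self, dimension: int = 2, width: int = 32):
        super().__init__()
        layers = _DIM_LAYERS[dimension]  # resolved once, at build time
        self.conv = layers["conv"](width, width, kernel_size=3, padding=1)
        self.norm = layers["norm"](width)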

Folding in the Wrapper classes

DeepONetWrapper / DeepONet3DWrapper should be folded into DeepONet via something like auto_pad: bool - padding is a constructor flag.
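
For example (illustrative; this matches the auto_pad flag that lands in the later refactor):

model = DeepONet(dimension=2, auto_pad=True)  # pad inputs internally, crop outputs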

Deduplication

TrunkNet and MLPBranch reinvent FullyConnected; these should be removed. Also, the sin activations can be factored upstream into the activation registry in physicsnemo.nn.

Test coverage

See notes in test_xdeeponet.py, but there are some important branches that aren't yet fully tested.


As mentioned - I think overall this is in pretty good shape and ready to merge once these concerns are addressed.

Comment thread physicsnemo/experimental/models/xdeeponet/__init__.py
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
Comment thread test/experimental/models/xdeeponet/test_xdeeponet.py
Comment thread test/experimental/models/xdeeponet/test_xdeeponet.py
Comment thread test/experimental/models/xdeeponet/test_xdeeponet.py
Comment thread test/experimental/models/xdeeponet/test_xdeeponet.py
Comment thread test/experimental/models/xdeeponet/test_xdeeponet.py
@peterdsharpe
Collaborator

FYI, some minor pre-commit fixes needed (see CI log)

@peterdsharpe
Collaborator

/blossom-ci

wdyab added 3 commits May 13, 2026 02:58
Introduces physicsnemo.experimental.models.xdeeponet — a config-driven,
unified implementation of eight DeepONet-based operator-learning
architectures for both 2D and 3D spatial domains:

- deeponet, u_deeponet, fourier_deeponet, conv_deeponet, hybrid_deeponet
  (single-branch variants)
- mionet, fourier_mionet (two-branch multi-input variants)
- tno (Temporal Neural Operator; branch2 = previous solution)

Features:
- Composable spatial branches (Fourier, UNet, Conv in any combination)
- Three decoder types: mlp, conv, temporal_projection
- Automatic spatial padding to multiples of 8
- Automatic trunk coordinate extraction (time or grid)
- Optional adaptive pooling (internal_resolution) for
  resolution-agnostic training and inference

Uses physicsnemo.models.unet.UNet as the UNet sub-module; a small
internal adapter tiles a short time axis to reuse the library's 3D UNet
for 2D spatial branches.  Imports spectral, convolutional, and MLP
layers from physicsnemo.nn and physicsnemo.models.mlp.

Includes 29 unit tests covering all variants (2D/3D), decoder types,
temporal projection, target_times override, gradient flow, and
adaptive pooling.

Related discussion with code owners:
- Placed under experimental/ per PhysicsNeMo convention for new models.
- Custom UNet dropped in favour of library UNet.
- Tests under test/experimental/models/ for CI coverage.

Signed-off-by: wdyab <wdyab@nvidia.com>
Made-with: Cursor
wdyab and others added 8 commits May 13, 2026 02:58
Addresses the inline-comment housekeeping items from the May-12 review
on PR NVIDIA#1576, with no public-API or semantic changes:

- Modernize type hints across xdeeponet sources: ``Dict[str, Any]`` ->
  ``dict[str, Any] | None``, ``Optional[X]`` -> ``X | None``,
  ``List`` / ``Tuple`` -> ``list`` / ``tuple``. Drop ``Dict``,
  ``Optional``, ``List``, ``Tuple`` imports from ``typing``. Annotation
  files already had ``from __future__ import annotations`` so this is
  safe on the minimum supported Python.
- Rename ``padding.py`` -> ``_padding.py`` to signal that the module is
  package-internal and may be restructured without notice. Update the
  one importer (``wrappers.py``) accordingly. Add ``jaxtyping.Shaped``
  annotations and full Parameters/Returns docstrings on
  ``pad_right_nd`` and ``pad_spatial_right``; the helpers are
  dtype-agnostic (used on real-valued tensors but the implementation is
  structural) so ``Shaped`` is preferred over ``Float``.
- Add a prominent ``.. important::`` block to ``_UNet2DFromUNet3D``
  spelling out the 8x memory/compute cost of selecting
  ``num_unet_layers > 0`` in a 2D ``SpatialBranch``, since the upstream
  library UNet is 3D-only and we tile.
- Expose ``has_temporal_projection`` as a public read-only property on
  both ``DeepONet`` and ``DeepONet3D`` to replace external reads of
  ``_temporal_projection``. Drop the dead cache of that flag in the two
  wrapper classes (the assignment had no remaining readers).
- Replace the 2D wrapper's permute/slice/permute spatial extraction
  with the cleaner ``x[:, :, :, 0, :]`` slice that mirrors the 3D
  version. Same result, fewer ops.
- Document the ``branch1_config`` / ``branch2_config`` / ``trunk_config``
  dict schemas in a new Notes section on the ``DeepONet`` class
  docstring, with ``.. code-block::`` examples enumerating all
  recognised keys and their defaults. ``DeepONet3D`` already references
  the 2D docstring so it inherits the schema for free.

Bumps no version, removes no public symbols. Behavior, state_dict keys
and forward outputs are unchanged.
Adds an elementwise sine activation as a regular ``nn.Module`` and
registers it under the key ``"sin"`` in
``physicsnemo.nn.module.activations.ACT2FN`` so it can be looked up by
name via ``get_activation("sin")``.  Also re-exports the ``Sin`` class
from ``physicsnemo.nn``.

This unblocks the xdeeponet refactor (next commit), which previously
carried a private ``_SinActivation`` shim and several
``if activation_fn.lower() == "sin"`` special-cases in branch/decoder
construction.  Both can now go away because the activation registry
handles ``"sin"`` natively.
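
After this commit the activation resolves by name (module path per the commit text):

from physicsnemo.nn.module.activations import get_activation

sin = get_activation("sin")  # resolves via ACT2FN to the registered Sin activation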
Addresses Peter Sharpe's CHANGES_REQUESTED review on PR NVIDIA#1576 in full,
and subsumes the 3D-UFNO portion of the planned xFNO PR into this one.
Net: -139 lines despite gaining trunkless mode, time-axis-extend,
multi-channel output, coord features, multi-layer lift, and 2D/3D
genericity.

Theme 1 — Dimensional unification (Peter NVIDIA#37, NVIDIA#47, NVIDIA#52, NVIDIA#53):

- ``DeepONet`` (formerly DeepONet + DeepONet3D) takes ``dimension: int``
  (2 or 3) and dispatches via ``_DIM_DEFAULTS`` and per-dim conv/spectral
  primitives, mirroring the ``FNO`` pattern.
- ``SpatialBranch`` (formerly SpatialBranch + SpatialBranch3D) takes
  ``dimension`` and uses an ``_DIM_LAYERS`` lookup for
  ``SpectralConv``/``Conv``/``BatchNorm``/``AdaptiveAvgPool``/
  ``UNetAdapter`` and the permute helpers.
- ``Conv{2,3}dFCLayer`` is selected via a one-line lookup
  (Peter NVIDIA#45).

Theme 2 — Wrappers folded into DeepONet (Peter NVIDIA#54, NVIDIA#64, NVIDIA#65):

- ``wrappers.py`` deleted (``DeepONetWrapper`` and ``DeepONet3DWrapper``
  removed).  Padding behaviour is now a constructor flag,
  ``auto_pad: bool = False``, and the model dispatches to
  ``_forward_packed`` / ``_forward_packed_trunkless`` accordingly.
- 6-cell call matrix (trunked/trunkless × packed/core × spatial/mlp)
  is documented in the class docstring.
- The previous private ``_temporal_projection`` attribute is exposed as
  a public ``has_temporal_projection`` property (Peter NVIDIA#55).

Theme 3 — Deduplication (Peter NVIDIA#43, NVIDIA#44, NVIDIA#50, NVIDIA#51, NVIDIA#40, related Greptile):

- ``TrunkNet`` and ``MLPBranch`` deleted — both duplicated
  ``physicsnemo.models.mlp.FullyConnected``; users now pass any
  ``nn.Module`` for the trunk / branches (DI-first API).
- ``_SinActivation`` deleted; the activation is registered as ``"sin"``
  in ``physicsnemo.nn.module.activations.ACT2FN`` (previous commit).
  All ``if activation_fn.lower() == "sin"`` special-cases removed.
- ``DeepONet.from_config`` and the dict-config schema removed entirely;
  Hydra-style ``_target_`` instantiation supersedes it.
- ``count_params`` collapsed from 4 duplicate copies to 1.

Theme 4 — xFNO fold-in:

- ``trunk: nn.Module | None = None`` enables trunkless mode (the 3D-UFNO
  use case from the planned xFNO PR).
- ``out_channels: int = 1`` adds multi-channel output to every path.
- ``time_modes: int | None`` enables xFNO-style time-axis-extend in
  trunkless packed mode: replicate-pads the last spatial axis to fit
  ``2 * time_modes`` and crops to the requested ``target_times``.
- ``coord_features`` and ``lift_layers``/``lift_hidden_width`` parameters
  on ``SpatialBranch`` replace the deleted dict-driven "conv encoder"
  option.

Theme 5 — Housekeeping (Peter NVIDIA#33, NVIDIA#34, NVIDIA#38, NVIDIA#41, NVIDIA#48, NVIDIA#57, NVIDIA#58, NVIDIA#59,
Charlelie NVIDIA#26, Greptile NVIDIA#5, NVIDIA#6):

- ``padding.py`` renamed to private ``_padding.py``; all functions
  carry ``jaxtyping.Shaped`` annotations.
- All public forward methods carry ``jaxtyping.Float`` annotations and
  ``torch.compiler.is_compiling`` shape-validation guards.
- ``Literal`` type aliases for ``decoder_type`` and other enums;
  case-insensitive validation against ``get_args`` (Greptile NVIDIA#15).
- Modern type hints throughout (``dict[str, Any] | None``, no
  ``Dict``/``Optional``).
- All public docstrings use ``r"""`` raw-string prefix, LaTeX math for
  tensor shapes, double backticks for inline code, and Examples
  sections.
- ``Notes`` block in ``branches.py`` documents the ``num_unet_layers``
  8x memory/compute penalty (Peter NVIDIA#49).

Theme 6 — Tests (Peter NVIDIA#60, NVIDIA#61, NVIDIA#62, NVIDIA#63, Charlelie NVIDIA#29):

- ``_FIXTURE_REGISTRY`` drives all non-regression tests across 9
  scenarios: u_deeponet 2D/3D, fourier_deeponet, mionet,
  temporal_projection, multi-channel packed 2D, xfno trunkless 3D
  (with and without time-axis-extend), and core 2D MLP-branch.
- New 3D gradient-flow test and 3D ``torch.compile`` test.
- ``fullgraph=True`` probe tests for 2D and 3D marked
  ``@pytest.mark.xfail(strict=False)`` to empirically answer Peter NVIDIA#63.
- ``_load_golden`` uses ``pytest.skip`` for missing fixtures so CI
  passes pending cluster-side golden regeneration.
- Test class structure mirrors MOD-008a/b/c: ``TestDeepONetConstructor``,
  ``TestDeepONetNonRegression``, ``TestDeepONetCheckpoint``,
  ``TestDeepONetGradientFlow``, ``TestDeepONetCompile``,
  ``TestDeepONetTimeAxisExtend``.

CHANGELOG bullet rewritten to describe the actual shipped API
(was stale, still described the old config-driven 8-variant model).
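
A rough sketch of the Theme 4 time-axis-extend in trunkless packed mode (tensor layout and variable handling assumed):

import torch
import torch.nn.functional as F

time_modes, target_times = 8, 10
x = torch.randn(2, 32, target_times)      # (B, width, T)
pad = max(0, 2 * time_modes - x.shape[-1])
x = F.pad(x, (0, pad), mode="replicate")  # extend T to fit 2 * time_modes
# ... spectral layers operate on the extended axis ...
y = x[..., :target_times]                 # crop back to the requested times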
@wdyab
Contributor Author

wdyab commented May 13, 2026

Thanks @peterdsharpe for your comments. All four themes are in a87a584 (with the Sin registry change pulled out as 5901bf3):

  • Dimensional unification: DeepONet(dimension=2|3) + SpatialBranch(dimension=2|3) via _DIM_DEFAULTS/_DIM_LAYERS lookups. DeepONet3D/SpatialBranch3D deleted.

  • Wrappers folded into DeepONet via auto_pad: bool + a 6-cell forward dispatch matrix. wrappers.py deleted.

  • Dedup: TrunkNet, MLPBranch, _SinActivation, from_config and the whole dict-schema all deleted; DI-first API. sin now lives in ACT2FN.

  • Tests: 9-scenario _FIXTURE_REGISTRY, 3D gradient + compile, fullgraph=True xfail probes.

I also folded the 3D-UFNO portion of the planned xFNO PR in here (trunk=None trunkless mode, time_modes time-axis-extend, out_channels), since dedup made it nearly free. The xFNO PR is now scoped down to FNO4D only and has been dropped to draft.

@wdyab wdyab requested a review from loliverhennigh as a code owner May 13, 2026 20:00
@wdyab wdyab requested a review from peterdsharpe May 13, 2026 20:01
Collaborator

@peterdsharpe peterdsharpe left a comment


Overall this is in great shape - excellent work addressing the comments from the first review.

There are a few minor things to address before merging, but I'm marking this as "approve" so as to not block this further.

In addition to the comments, the single remaining blocker is:

The test file (test_xdeeponet.py:56-66) defines 9 golden paths:

_GOLDEN_PACKED_2D          = .../"xdeeponet_packed_2d_v1.pth"
_GOLDEN_PACKED_3D          = .../"xdeeponet_packed_3d_v1.pth"
_GOLDEN_PACKED_2D_FOURIER  = .../"xdeeponet_packed_2d_fourier_v1.pth"
_GOLDEN_PACKED_2D_MIONET   = .../"xdeeponet_packed_2d_mionet_v1.pth"
_GOLDEN_PACKED_2D_TEMPORAL  = .../"xdeeponet_packed_2d_temporal_v1.pth"
_GOLDEN_PACKED_2D_MULTICHANNEL = .../"xdeeponet_packed_2d_multichannel_v1.pth"
_GOLDEN_XFNO_PACKED_3D     = .../"xdeeponet_xfno_packed_3d_v1.pth"
_GOLDEN_XFNO_PACKED_3D_EXTEND = .../"xdeeponet_xfno_packed_3d_extend_v1.pth"
_GOLDEN_CORE_2D_MLPBRANCH  = .../"xdeeponet_core_2d_mlpbranch_v1.pth"

None of these exist on disk. The only committed fixtures are the old-API files xdeeponet_wrapper_2d_v1.pth and xdeeponet_wrapper_3d_v1.pth, which no test references anymore. As a result:

  • All 9 TestDeepONetNonRegression scenarios will pytest.skip.
  • Both TestDeepONetCheckpoint scenarios will pytest.skip (they load
    _GOLDEN_PACKED_2D / _GOLDEN_PACKED_3D).

Two things here:
a) If golden files are missing, we should pytest.fail, not pytest.skip. I've added a commit that already fixes this.
b) We should regenerate the files - should be pretty quick, just a re-run of _generate_xdeeponet_goldens.py, commit the result, remove the old files, and we should be good to go.

Great work!
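
The skip-to-fail change in point (a) plausibly amounts to the following (a sketch; helper internals assumed):

import pytest
import torch

def _load_golden(path):
    # Hard-fail instead of silently skipping when a fixture is absent.
    if not path.exists():
        pytest.fail(
            f"Missing golden fixture {path}; rerun _generate_xdeeponet_goldens.py"
        )
    return torch.load(path)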

Comment thread physicsnemo/nn/module/activations.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/_padding.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/_padding.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/deeponet.py Outdated
Comment thread physicsnemo/experimental/models/xdeeponet/_padding.py Outdated
@peterdsharpe
Collaborator

/blossom-ci

…1576)

Addresses the inline comments left on the approving review and
regenerates the golden fixture set so the new fixture names land on
disk with current numerics.

Review fixes (peterdsharpe inline comments, 2026-05-14):
- physicsnemo/nn/module/activations.py: trim the Sin class docstring
  to focus on what the layer does instead of who consumes it;
  mirrors the style of sibling activations in the file.
- physicsnemo/experimental/models/xdeeponet/deeponet.py:
  * Type `dimension` as Literal[2, 3] in both the docstring and the
    constructor signature.
  * Expand the decoder_activation_fn docstring with a cross-reference
    to physicsnemo.nn.module.activations.get_activation / ACT2FN so
    the supported names are discoverable.
  * Drop the dead `variant` kwarg (docstring entry, two doctest
    examples, signature, attribute assignment, error-message
    interpolation).  Tests adjusted in lockstep: every `variant=`
    builder kwarg, parametrize-config entry, and `assert
    model.variant` site removed.
  * Drop the unused count_params() method (downstream code uses
    torchinfo; the one-line equivalent suffices).
- physicsnemo/experimental/models/xdeeponet/_padding.py:
  * Narrow `mode` to Literal["replicate", "constant"] on both
    pad_right_nd and pad_spatial_right.
  * Replace the `torch.tensor(rest_shape).prod().item()` shuttle
    with `math.prod(rest_shape)`: avoids a CPU<->GPU transfer and
    handles the zero-length case without a conditional.
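
The math.prod swap in the last bullet, concretely (a minimal before/after sketch):

import math

rest_shape = (4, 8, 8)
# Before: torch.tensor(rest_shape).prod().item()  (tensor round-trip)
total = math.prod(rest_shape)  # 256, pure Python
assert math.prod(()) == 1      # empty shape needs no special case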

Symmetric branch2 un-pack in _forward_packed:
- physicsnemo/experimental/models/xdeeponet/deeponet.py: when
  x_branch2 is a (B, *spatial, T, C) packed tensor (matching x's
  shape), strip its time axis with the same idx_strip_T as we apply
  to x before dispatching to _forward_core.  Without this, the
  mionet packed-input forward passes a 5D tensor into a SpatialBranch
  expecting 4D input and crashes at branch2.forward.

Regenerate golden fixtures:
- Drop the two stale wrapper fixtures the new test suite no longer
  references: xdeeponet_wrapper_{2d,3d}_v1.pth.
- Add 11 fixtures keyed off the new _FIXTURE_REGISTRY entries; the
  generator script _generate_xdeeponet_goldens.py is unchanged (it
  iterates over the registry, so new scenarios pick up
  automatically).

Add kitchen-sink stress tests:
- test_xdeeponet.py: two new builders cover code paths no other
  fixture exercises.
  * _packed_2d_kitchen_sink (2D, auto_pad=True): Fourier+UNet+Conv
    on both branches, temporal_projection decoder with
    output_window>1, trunk_input='grid', multi-channel output,
    multi-layer pointwise lift, asymmetric coord_features, Sin
    trunk.
  * _core_3d_kitchen_sink (3D, auto_pad=False): mionet + conv
    decoder + 3D core entry point + lift_layers=3 with
    lift_hidden_width + 3-layer trunk with output_activation=False
    + celu/leaky_relu/elu/tanh activation palette.  These knobs had
    zero prior coverage in the registry.
- TestDeepONetStress class: forward-shape, gradient-flow, and
  torch.compile(fullgraph=False) parity checks, one method per
  (concern, dimensionality).

Test results on the regenerated tree (login-eos01 CPU): 32 passed,
2 xpassed, 2 failed.  The 2 failures (TestDeepONetCheckpoint
roundtrips) are pre-existing on HEAD -- branch1 is a torch.nn.Module
rather than a physicsnemo.Module so Module.save rejects the
hierarchy.  Confirmed by stashing this commit's deeponet.py changes
and rerunning against bare HEAD; failure signature is identical.
That gap is tracked separately and is not in scope for this commit.
@wdyab
Contributor Author

wdyab commented May 14, 2026

Done in b01d844: regenerated the 9 fixtures referenced by _FIXTURE_REGISTRY (now 11 with two new kitchen-sink scenarios) and removed the two stale xdeeponet_wrapper_{2d,3d}_v1.pth files. Also added a kitchen-sink 2D+3D stress-test pair (TestDeepONetStress) exercising every constructor knob; details are in the commit body.

@peterdsharpe
Collaborator

peterdsharpe commented May 14, 2026

Looks great @wdyab! Pinging @loliverhennigh for review (via CODEOWNERS), and I'll kick off Blossom-CI to merge

Fixes the 2 remaining CI failures on PR NVIDIA#1576:

  FAILED TestDeepONetCheckpoint::test_wrapper_2d_roundtrip
  FAILED TestDeepONetCheckpoint::test_wrapper_3d_roundtrip
      TypeError: Submodule branch1 of module DeepONet is a PyTorch
      module, which is not supported by 'Module.save'.

This was a long-standing MOD-001 violation in the PR: SpatialBranch
inherited from torch.nn.Module rather than physicsnemo.Module.
DeepONet (correctly) inherits from physicsnemo.Module, and
Module._save_process walks the constructor-arg submodules and rejects
any plain torch.nn.Module.  The skip-to-fail change in 3bc46e9
surfaced this by replacing the silent pytest.skip on missing fixtures
with a hard fail; the regen commit (b01d844) then moved the fail
point past _load_golden into the actual Module.save call, where the
long-hidden architectural issue became the visible one.

Production change (physicsnemo/experimental/models/xdeeponet/branches.py):
- Add _SpatialBranchMetaData(ModelMetaData) dataclass.
- Change `class SpatialBranch(nn.Module)` to `class SpatialBranch(Module)`.
- Pass `meta=_SpatialBranchMetaData()` to `super().__init__`.
- Add physicsnemo.core.meta and physicsnemo.core.module imports plus
  `from dataclasses import dataclass`.

Mirrors the existing DeepONet / _DeepONetMetaData pattern verbatim.
physicsnemo.Module.__init__ records constructor args on `_args` but
does not register any parameters or buffers, so SpatialBranch's
state_dict keys are unchanged by this transition.

The internal _UNet2DFromUNet3D / _UNet3DFromUNet3D adapters stay as
plain nn.Modules -- they're constructed inside SpatialBranch.__init__
and live in self.unet_modules (i.e. in _modules, not _args), so
they're never seen by Module._save_process and don't need converting.
Keeping them as nn.Module accurately reflects their role as
implementation details private to SpatialBranch.

Test changes (test/experimental/models/xdeeponet/test_xdeeponet.py):
- New test-only _MLPWithTrailingActivation(Module) wrapper replaces
  the `nn.Sequential(FullyConnected, get_activation(...))` pattern
  previously returned by `_make_trunk` and `_make_mlp_branch`.  A bare
  nn.Sequential is rejected by Module._save_process for the same
  reason as the original SpatialBranch.  The wrapper preserves forward
  semantics byte-for-byte: `activation(fc(x))` is what
  `nn.Sequential(fc, activation)` did.
- `_make_trunk(output_activation=True)` and `_make_mlp_branch` return
  _MLPWithTrailingActivation instances.
  `_make_trunk(output_activation=False)` keeps returning bare
  FullyConnected (already a physicsnemo.Module).
- Drop the @pytest.mark.xfail(strict=False) markers from
  test_wrapper_{2d,3d}_compile_fullgraph: both compile cleanly on the
  torch versions exercised in CI (Python 3.12 + uv) and locally
  (Python 3.10 on EOS).  The markers' own comments invited removal
  once the tests passed reliably; they reported XPASS in both CI and
  local runs on b01d844.  Re-add the markers if a future torch
  update reintroduces graph breaks.

Fixture regen:
- 3 fixtures are byte-identical to b01d844: xfno_packed_3d_v1.pth,
  xfno_packed_3d_extend_v1.pth (trunkless -- no trunk submodule), and
  core_3d_kitchen_sink_v1.pth (uses output_activation=False trunk so
  already returned bare FullyConnected).
- 8 fixtures got new bytes because their saved state_dict key
  prefixes flipped from `trunk.0.layers.*` (nn.Sequential
  numbered-attr layout) to `trunk.fc.layers.*`
  (_MLPWithTrailingActivation named-attr layout), plus the mlpbranch
  fixture's branch1 keys.  The forward output tensor `y` stored in
  each fixture is numerically unchanged; only the state_dict keys
  differ.  All 10 TestDeepONetNonRegression scenarios load and match.

Test results on the regenerated tree (login-eos01 CPU):
  36 passed, 0 xpassed, 0 failed in ~260s.
Both TestDeepONetCheckpoint roundtrip tests now pass cleanly.
@peterdsharpe
Collaborator

/blossom-ci

@peterdsharpe
Collaborator

/blossom-ci

@peterdsharpe peterdsharpe added this pull request to the merge queue May 15, 2026
Merged via the queue into NVIDIA:main with commit 751d984 May 15, 2026
4 checks passed