Add xDeepONet family to experimental models #1576
Merged
Changes shown from 14 of 16 commits:

- 94aad70 Add xDeepONet family to experimental models (wdyab)
- 6fb34d8 xdeeponet: address PR #1576 review feedback (Greptile) (wdyab)
- 2795fa2 xdeeponet: address second Greptile review (PR #1576) (wdyab)
- fb4c4e1 xdeeponet: address third Greptile review (PR #1576) (wdyab)
- 41737cb xdeeponet: close out silent-degradation combinations (PR #1576) (wdyab)
- 5c822c1 xdeeponet: fix _build_conv_encoder for "sin" activation (PR #1576) (wdyab)
- 7bb0d10 xdeeponet: address review feedback on PR #1576 (wdyab)
- 1d53c5a xdeeponet: use Literal type aliases for enumerated string parameters (wdyab)
- 4404aaa xdeeponet: housekeeping pass (Theme 5 of PR #1576 review) (wdyab)
- 5901b f3 physicsnemo.nn: register Sin activation in ACT2FN (wdyab)
- a87a584 xdeeponet: dimensional unification, DI-first API, xFNO fold-in (wdyab)
- b56c6ea Merge branch 'main' into pr/xdeeponet (wdyab)
- 3bc46e9 missing golden files should fail (peterdsharpe)
- b01d844 xdeeponet: address review feedback and regenerate goldens (PR #1576) (wdyab)
- bcc2bc0 xdeeponet: SpatialBranch -> physicsnemo.Module (MOD-001 fix) (wdyab)
- 8f66832 fixing pre-commit violation (peterdsharpe)
The PR adds two files. File 1 (new file, 49 lines): the xdeeponet package `__init__.py`:

```python
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""xDeepONet — the extended DeepONet family.

A single :class:`DeepONet` class assembles operator-learning
architectures spanning the DeepONet and FNO families:

- ``deeponet``, ``u_deeponet``, ``fourier_deeponet``, ``conv_deeponet``,
  ``hybrid_deeponet`` — single-branch + trunk variants.
- ``mionet``, ``fourier_mionet`` — two-branch multi-input + trunk variants.
- ``tno`` — Temporal Neural Operator (branch2 = previous solution) + trunk.
- ``ufno`` / xFNO-style trunkless operators — trunkless spatial branch
  with composable Fourier / UNet / Conv layers; the last spatial axis
  can be interpreted as time for autoregressive bundling via the
  :attr:`DeepONet.time_modes` parameter.

The :class:`DeepONet` class is dimension-generic (``dimension=2|3``
constructor argument; per-dimension primitives are dispatched
internally) and dispatches forward by two flags
(:attr:`auto_pad`, :attr:`trunk`-is-None) over six valid call
conventions: packed-input vs core-input × trunked vs trunkless,
plus the ``temporal_projection`` decoder variant. See the
:class:`DeepONet` class docstring for the full matrix and worked
examples; see :class:`SpatialBranch` for the spatial-encoder
composition options (Fourier / UNet / Conv layers, multi-layer
pointwise lift, optional coordinate-feature channels).
"""

from .branches import SpatialBranch
from .deeponet import DeepONet

__all__ = [
    "DeepONet",
    "SpatialBranch",
]
```
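The two-flag dispatch described in the docstring can be sketched schematically. This is an illustration of the decision structure only: the function name `select_forward_path` and the path labels below are ours, paraphrasing the docstring; the real routing lives in `DeepONet.forward`, not here.

```python
def select_forward_path(auto_pad: bool, trunk_is_none: bool) -> str:
    """Schematic of the two-flag forward dispatch (illustrative names only)."""
    if trunk_is_none:
        # Trunkless spatial-branch operators (ufno / xFNO-style).
        return "packed-trunkless" if auto_pad else "core-trunkless"
    # Branch + trunk variants (deeponet, mionet, tno, ...).
    return "packed-trunked" if auto_pad else "core-trunked"

print(select_forward_path(True, False))  # packed-trunked
print(select_forward_path(False, True))  # core-trunkless
```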
File 2 (new file, 253 lines): the package-private right-side padding helpers. License header, module docstring, and imports:

```python
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Right-side spatial padding helpers used by the xDeepONet packed-input
forward path.

When :class:`~physicsnemo.experimental.models.xdeeponet.DeepONet` is
constructed with ``auto_pad=True`` it aligns spatial dimensions to a
multiple (typically 8) so that spectral and convolutional sub-branches
operate on compatible shapes. These helpers are dimension-agnostic and
support 2D, 3D, or 4D spatial layouts.

Tensor layouts used here:

- 2D spatial samples: ``(B, H, W, T, C)``
- 3D spatial samples: ``(B, X, Y, Z, T, C)``

This module is private (leading underscore): the helpers are part of the
xdeeponet package's internal API surface only and may be renamed or
restructured without notice.
"""

from __future__ import annotations

import math
from typing import Literal, Sequence

import torch
import torch.nn.functional as F
from jaxtyping import Shaped
from torch import Tensor
```
The alignment computation, `compute_right_pad_to_multiple`:

```python
def compute_right_pad_to_multiple(
    spatial_shape: Sequence[int],
    *,
    multiple: int = 8,
    min_right_pad: int = 0,
) -> tuple[int, ...]:
    """Compute right-side padding to reach a multiple of *multiple*.

    Parameters
    ----------
    spatial_shape : Sequence[int]
        Current spatial dimension sizes.
    multiple : int, optional
        Target alignment (default ``8``).
    min_right_pad : int, optional
        Minimum right-side padding applied per dimension (default ``0``).

    Returns
    -------
    tuple[int, ...]
        Right-side padding per dimension such that ``(d + pad)`` is a multiple
        of *multiple* and ``pad >= min_right_pad``.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be > 0, got {multiple}")
    if min_right_pad < 0:
        raise ValueError(f"min_right_pad must be >= 0, got {min_right_pad}")

    pads = []
    for d in spatial_shape:
        if d <= 0:
            raise ValueError(
                f"spatial dimensions must be positive, got {spatial_shape}"
            )
        to_mult = (multiple - (d % multiple)) % multiple
        if to_mult >= min_right_pad:
            pad = to_mult
        else:
            deficit = min_right_pad - to_mult
            k = (deficit + multiple - 1) // multiple
            pad = to_mult + k * multiple
        pads.append(int(pad))
    return tuple(pads)
```
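The padding arithmetic can be sanity-checked standalone. The snippet below restates the same computation without the module's torch/jaxtyping imports so it runs on its own; the arithmetic is identical to the helper above:

```python
def compute_right_pad_to_multiple(spatial_shape, *, multiple=8, min_right_pad=0):
    # Torch-free restatement of the helper's core arithmetic.
    pads = []
    for d in spatial_shape:
        # Smallest non-negative pad that reaches the next multiple.
        to_mult = (multiple - (d % multiple)) % multiple
        if to_mult >= min_right_pad:
            pad = to_mult
        else:
            # Bump by whole multiples until min_right_pad is satisfied.
            deficit = min_right_pad - to_mult
            k = (deficit + multiple - 1) // multiple
            pad = to_mult + k * multiple
        pads.append(pad)
    return tuple(pads)

print(compute_right_pad_to_multiple((30, 32, 33)))            # (2, 0, 7)
print(compute_right_pad_to_multiple((32,), min_right_pad=3))  # (8,)
```

Note the `min_right_pad` case: 32 is already aligned, so the zero pad is bumped by a full multiple to honor the minimum.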
The generic N-D right-pad, `pad_right_nd`:

```python
def pad_right_nd(
    x: Shaped[Tensor, "..."],
    *,
    dims: Sequence[int],
    right_pad: Sequence[int],
    mode: Literal["replicate", "constant"] = "replicate",
    constant_value: float = 0.0,
) -> Shaped[Tensor, "..."]:
    """Right-pad arbitrary dimensions of an N-D tensor.

    Implemented manually so it works for ``mode="replicate"`` even when
    :func:`torch.nn.functional.pad` does not support the tensor rank
    (e.g. 6D tensors in the 3D-spatial case).

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of any rank and dtype.
    dims : Sequence[int]
        Dimensions to right-pad. Negative indices are supported.
    right_pad : Sequence[int]
        Right-side padding amounts per ``dims`` entry. Non-positive
        entries are no-ops.
    mode : str, optional
        ``"replicate"`` (default) repeats the last slice along each
        padded dim; ``"constant"`` uses ``constant_value``.
    constant_value : float, optional
        Fill value when ``mode="constant"`` (default ``0.0``).

    Returns
    -------
    torch.Tensor
        Tensor of the same rank and dtype as ``x`` with the specified
        dimensions right-padded.
    """
    if len(dims) != len(right_pad):
        raise ValueError("dims and right_pad must have the same length")
    if not dims:
        return x

    for dim, pad in zip(dims, right_pad):
        pad = int(pad)
        if pad <= 0:
            continue
        if dim < 0:
            dim = x.dim() + dim
        if dim < 0 or dim >= x.dim():
            raise ValueError(f"invalid dim {dim} for x.dim()={x.dim()}")

        if mode == "constant":
            pad_shape = list(x.shape)
            pad_shape[dim] = pad
            pad_tensor = torch.full(
                pad_shape, float(constant_value), dtype=x.dtype, device=x.device
            )
            x = torch.cat([x, pad_tensor], dim=dim)
            continue

        if mode != "replicate":
            raise ValueError(
                f"pad_right_nd supports mode='replicate' or 'constant', got {mode}"
            )

        last = x.select(dim, x.size(dim) - 1).unsqueeze(dim)
        expand_shape = list(x.shape)
        expand_shape[dim] = pad
        pad_tensor = last.expand(*expand_shape)
        x = torch.cat([x, pad_tensor], dim=dim)

    return x
```
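The replicate branch repeats the last slice along each padded dimension. A minimal pure-Python analogue (lists in place of tensors; the helper name `replicate_pad_rows` is ours, for illustration only) shows the semantics on a 2D grid:

```python
def replicate_pad_rows(grid, pad_h, pad_w):
    # Right-pad a 2D list-of-lists: repeat the last element of each row
    # pad_w times, then repeat the (already padded) last row pad_h times.
    padded = [row + [row[-1]] * pad_w for row in grid]
    padded += [list(padded[-1]) for _ in range(pad_h)]
    return padded

g = [[1, 2], [3, 4]]
print(replicate_pad_rows(g, 1, 2))
# [[1, 2, 2, 2], [3, 4, 4, 4], [3, 4, 4, 4]]
```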
The packed-layout entry point, `pad_spatial_right`, plus the module exports. (One fix relative to the diff as posted: the 2D/3D path now clamps negative pad entries to zero so they are true no-ops, as the docstring promises; previously a negative entry would have been forwarded to `F.pad` and cropped the tensor.)

```python
def pad_spatial_right(
    x: Shaped[Tensor, "..."],
    *,
    spatial_ndim: int,
    right_pad: Sequence[int],
    mode: Literal["replicate", "constant"] = "replicate",
    constant_value: float = 0.0,
) -> Shaped[Tensor, "..."]:
    """Right-pad the first *spatial_ndim* dimensions after the batch dim.

    Assumes ``x`` is shaped ``(B, *spatial, *rest)``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor shaped ``(B, *spatial, *rest)``; any dtype is
        accepted. Must satisfy ``x.dim() >= 1 + spatial_ndim``.
    spatial_ndim : int
        Number of spatial dimensions immediately following the batch
        dim. Must be ``2``, ``3``, or ``4``.
    right_pad : Sequence[int]
        Right-side padding amounts per spatial dimension; must have
        length ``spatial_ndim``. Non-positive entries are no-ops.
    mode : str, optional
        ``"replicate"`` (default) or ``"constant"``.
    constant_value : float, optional
        Fill value when ``mode="constant"`` (default ``0.0``).

    Returns
    -------
    torch.Tensor
        Tensor of the same rank and dtype as ``x`` with the spatial
        dimensions right-padded.
    """
    if spatial_ndim not in (2, 3, 4):
        raise ValueError(f"spatial_ndim must be 2, 3, or 4, got {spatial_ndim}")
    if len(right_pad) != spatial_ndim:
        raise ValueError(
            f"right_pad must have length {spatial_ndim}, got {len(right_pad)}"
        )
    if x.dim() < 1 + spatial_ndim:
        raise ValueError(
            f"expected x.dim() >= {1 + spatial_ndim}, got x.dim()={x.dim()}"
        )
    if all(int(p) <= 0 for p in right_pad):
        return x

    # For 4 spatial dims fall back to the generic implementation (works for 6D+).
    if spatial_ndim == 4:
        dims = [1, 2, 3, 4]
        return pad_right_nd(
            x,
            dims=dims,
            right_pad=right_pad,
            mode=mode,
            constant_value=constant_value,
        )

    # For 2D/3D spatial, use a reshape trick so F.pad(replicate) applies:
    # flatten the trailing dims into a channel axis and move it next to batch.
    b = x.shape[0]
    spatial_shape = x.shape[1 : 1 + spatial_ndim]
    rest_shape = x.shape[1 + spatial_ndim :]
    rest_prod = math.prod(rest_shape)

    x_reshaped = x.reshape(b, *spatial_shape, rest_prod).permute(
        0, spatial_ndim + 1, *range(1, 1 + spatial_ndim)
    )

    # Clamp negative entries to zero so non-positive pads stay no-ops
    # (F.pad would otherwise crop on negative values).
    if spatial_ndim == 2:
        pad_h, pad_w = (max(0, int(p)) for p in right_pad)
        pad = (0, pad_w, 0, pad_h)
    else:
        pad_x, pad_y, pad_z = (max(0, int(p)) for p in right_pad)
        pad = (0, pad_z, 0, pad_y, 0, pad_x)

    if mode == "constant":
        x_padded = F.pad(x_reshaped, pad, mode="constant", value=float(constant_value))
    else:
        x_padded = F.pad(x_reshaped, pad, mode=mode)

    padded_spatial = x_padded.shape[2 : 2 + spatial_ndim]
    return x_padded.permute(0, *range(2, 2 + spatial_ndim), 1).reshape(
        b, *padded_spatial, *rest_shape
    )


__all__ = [
    "compute_right_pad_to_multiple",
    "pad_right_nd",
    "pad_spatial_right",
]
```
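End to end, the shape bookkeeping is easy to trace: a 2D-spatial packed sample of shape (B, H, W, T, C) = (2, 30, 40, 10, 3) aligned to a multiple of 8 gains pads (2, 0) and comes out (2, 32, 40, 10, 3). The sketch below (function name `padded_shape` is ours) mirrors `compute_right_pad_to_multiple` plus `pad_spatial_right` at the shape level only, with no torch dependency:

```python
def padded_shape(shape, spatial_ndim, multiple=8):
    # Batch and trailing (T, C, ...) dims pass through unchanged;
    # only the spatial dims are rounded up to the next multiple.
    b = shape[0]
    spatial = shape[1 : 1 + spatial_ndim]
    rest = shape[1 + spatial_ndim :]
    pads = tuple((multiple - d % multiple) % multiple for d in spatial)
    return (b, *[d + p for d, p in zip(spatial, pads)], *rest)

print(padded_shape((2, 30, 40, 10, 3), spatial_ndim=2))  # (2, 32, 40, 10, 3)
```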