diff --git a/CHANGELOG.md b/CHANGELOG.md index e7ac7f2659..00321b97ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Adds xDeepONet — the extended DeepONet family — to experimental models + (`physicsnemo.experimental.models.xdeeponet`). Provides a config-driven + assembly of eight DeepONet-based architectures (DeepONet, U-DeepONet, + Fourier-DeepONet, Conv-DeepONet, Hybrid-DeepONet, MIONet, Fourier-MIONet, + and TNO) for 2D and 3D spatial domains, with composable Fourier / UNet / + Conv branches, multiple decoder types (MLP, Conv, temporal projection), + and automatic spatial padding. Suitable for standalone operator learning + and autoregressive temporal bundling. - Adds GLOBE model (`physicsnemo.experimental.models.globe.model.GLOBE`), including new variant that uses a dual tree traversal algorithm to reduce the complexity of the kernel evaluations from O(N^2) to O(N). diff --git a/physicsnemo/experimental/models/xdeeponet/__init__.py b/physicsnemo/experimental/models/xdeeponet/__init__.py new file mode 100644 index 0000000000..40630fbb08 --- /dev/null +++ b/physicsnemo/experimental/models/xdeeponet/__init__.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""xDeepONet — the extended DeepONet family. + +Config-driven assembly of eight DeepONet-based architectures sharing a +common branch/trunk/decoder pattern: + +- ``deeponet``, ``u_deeponet``, ``fourier_deeponet``, ``conv_deeponet``, + ``hybrid_deeponet`` — single-branch variants. +- ``mionet``, ``fourier_mionet`` — two-branch multi-input variants. +- ``tno`` — Temporal Neural Operator (branch2 = previous solution). + +Both 2D and 3D spatial versions are provided. :class:`DeepONetWrapper` and +:class:`DeepONet3DWrapper` are the recommended entry points; see their class +docstrings for usage examples and the branch/trunk configuration schema. +""" + +from .branches import MLPBranch, SpatialBranch, SpatialBranch3D, TrunkNet +from .deeponet import DeepONet, DeepONet3D +from .wrappers import DeepONet3DWrapper, DeepONetWrapper + +__all__ = [ + # Core architectures + "DeepONet", + "DeepONet3D", + # Convenience wrappers (recommended entry points) + "DeepONetWrapper", + "DeepONet3DWrapper", + # Building blocks + "TrunkNet", + "MLPBranch", + "SpatialBranch", + "SpatialBranch3D", +] diff --git a/physicsnemo/experimental/models/xdeeponet/branches.py b/physicsnemo/experimental/models/xdeeponet/branches.py new file mode 100644 index 0000000000..7f8cc73985 --- /dev/null +++ b/physicsnemo/experimental/models/xdeeponet/branches.py @@ -0,0 +1,665 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Branch and trunk building blocks used by the xDeepONet family. + +Provides four sub-networks: + +- :class:`TrunkNet` — MLP trunk that encodes query coordinates (time or grid). +- :class:`MLPBranch` — fully-connected branch for scalar/vector inputs + (e.g. the scalar branch in MIONet). +- :class:`SpatialBranch` — 2D spatial encoder composable from Fourier, UNet, + and Conv layers. +- :class:`SpatialBranch3D` — 3D counterpart of ``SpatialBranch``. + +UNet sub-modules inside the spatial branches use +:class:`physicsnemo.models.unet.UNet` (3D). A small adapter +:class:`_UNet2DFromUNet3D` is provided locally for the 2D variant: it wraps +the 3D UNet with a singleton time dimension so the same library model covers +both spatial dimensionalities. +""" + +from __future__ import annotations + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from jaxtyping import Float +from torch import Tensor + +from physicsnemo.models.unet import UNet as _PhysicsNeMoUNet +from physicsnemo.nn import SpectralConv2d, SpectralConv3d, get_activation + +# --------------------------------------------------------------------------- +# UNet adapters (wrap the library's 3D UNet for reuse inside spatial branches) +# --------------------------------------------------------------------------- + + +class _UNet2DFromUNet3D(nn.Module): + r"""Adapter using :class:`physicsnemo.models.unet.UNet` for 2D inputs. + + The library UNet is 3D only. To reuse it for 2D, this adapter adds a + short tiled time axis of length :math:`2^{\text{model\_depth}}` (long + enough to survive the UNet's ``model_depth`` pooling stages), runs the + 3D UNet, and averages the result back to 2D. Channel-first layout + :math:`(B, C, H, W)` is preserved on input and output. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + model_depth: int = 3, + feature_map_channels: Optional[List[int]] = None, + ): + super().__init__() + if feature_map_channels is None: + feature_map_channels = [in_channels] * model_depth + self._t_tile = 2**model_depth + self.unet = _PhysicsNeMoUNet( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + model_depth=model_depth, + feature_map_channels=feature_map_channels, + num_conv_blocks=1, + conv_activation="leaky_relu", + conv_transpose_activation="leaky_relu", + padding=kernel_size // 2, + pooling_type="MaxPool3d", + normalization="batchnorm", + gradient_checkpointing=False, + ) + + def forward( + self, + x: Float[Tensor, "batch channels h w"], + ) -> Float[Tensor, "batch out_channels h w"]: + """Forward through the 3D UNet via a tiled time axis.""" + x = x.unsqueeze(-1).repeat(1, 1, 1, 1, self._t_tile) + x = self.unet(x) + return x.mean(dim=-1) + + +class _UNet3DFromUNet3D(nn.Module): + r"""Thin wrapper exposing :class:`physicsnemo.models.unet.UNet`. 
+ + Exposes the library 3D UNet with a fixed default configuration suitable + for skip-connection reuse inside :class:`SpatialBranch3D`. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + model_depth: int = 3, + feature_map_channels: Optional[List[int]] = None, + ): + super().__init__() + if feature_map_channels is None: + feature_map_channels = [in_channels] * model_depth + self.unet = _PhysicsNeMoUNet( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + model_depth=model_depth, + feature_map_channels=feature_map_channels, + num_conv_blocks=1, + conv_activation="leaky_relu", + conv_transpose_activation="leaky_relu", + padding=kernel_size // 2, + pooling_type="MaxPool3d", + normalization="batchnorm", + gradient_checkpointing=False, + ) + + def forward( + self, + x: Float[Tensor, "batch channels x y z"], + ) -> Float[Tensor, "batch out_channels x y z"]: + """Forward pass through the library 3D UNet.""" + return self.unet(x) + + +# --------------------------------------------------------------------------- +# Trunk and MLP branch +# --------------------------------------------------------------------------- + + +class TrunkNet(nn.Module): + r"""MLP trunk network encoding query coordinates. + + Parameters + ---------- + in_features : int + Dimensionality of each query point (``1`` for time-only, ``3`` for 2D + grid coordinates, ``4`` for 3D grid coordinates). + out_features : int + Output width (matches the DeepONet latent size). + hidden_width : int + Hidden layer width. + num_layers : int + Number of hidden layers. + activation_fn : str + Activation function name (``"sin"``, ``"tanh"``, ``"relu"``, etc.). + output_activation : bool + When ``True`` (default) the final layer is followed by the activation. + Set ``False`` for linear output (e.g. the TNO configuration). + + Forward + ------- + x : torch.Tensor + Query coordinates of shape :math:`(T, D_{in})` where + :math:`D_{in}` equals ``in_features``. + + Outputs + ------- + torch.Tensor + Encoded coordinates of shape :math:`(T, D_{out})` where + :math:`D_{out}` equals ``out_features``. 
+ + Examples + -------- + >>> import torch + >>> from physicsnemo.experimental.models.xdeeponet import TrunkNet + >>> trunk = TrunkNet(in_features=1, out_features=64, hidden_width=64, num_layers=4) + >>> t = torch.linspace(0, 1, 10).unsqueeze(-1) # (10, 1) + >>> phi = trunk(t) # (10, 64) + """ + + def __init__( + self, + in_features: int = 1, + out_features: int = 64, + hidden_width: int = 128, + num_layers: int = 6, + activation_fn: str = "sin", + output_activation: bool = True, + ): + super().__init__() + + self._output_activation = output_activation + + if activation_fn.lower() == "sin": + self.activation_fn = torch.sin + else: + self.activation_fn = get_activation(activation_fn) + + self.layers = nn.ModuleList() + self.layers.append(self._make_linear(in_features, hidden_width)) + for _ in range(num_layers - 1): + self.layers.append(self._make_linear(hidden_width, hidden_width)) + + self.output_layer = self._make_linear(hidden_width, out_features) + + def _make_linear(self, in_dim: int, out_dim: int) -> nn.Linear: + layer = nn.Linear(in_dim, out_dim) + init.xavier_normal_(layer.weight) + init.zeros_(layer.bias) + return layer + + def forward( + self, + x: Float[Tensor, "time in_features"], + ) -> Float[Tensor, "time out_features"]: + """Forward pass of the trunk network.""" + if not torch.compiler.is_compiling(): + if x.ndim != 2: + raise ValueError( + f"Expected 2D input (T, in_features), got {x.ndim}D " + f"tensor with shape {tuple(x.shape)}" + ) + for layer in self.layers: + x = self.activation_fn(layer(x)) + x = self.output_layer(x) + if self._output_activation: + x = self.activation_fn(x) + return x + + +class MLPBranch(nn.Module): + r"""Fully-connected branch for scalar/vector inputs. + + Used for the scalar branch in MIONet-style architectures. Input features + are auto-discovered via :class:`torch.nn.LazyLinear` on the first forward. + + Parameters + ---------- + out_features : int + Output width (matches the DeepONet latent size). + hidden_width : int + Hidden layer width. + num_layers : int + Number of fully-connected layers (including output). Must be ``>= 2``. + activation_fn : str + Activation function name. + + Forward + ------- + x : torch.Tensor + Scalar input of shape :math:`(B, D_{in})` where :math:`D_{in}` is + auto-discovered on the first forward pass. + + Outputs + ------- + torch.Tensor + Encoded features of shape :math:`(B, D_{out})` where + :math:`D_{out}` equals ``out_features``. 
+ + Examples + -------- + >>> import torch + >>> from physicsnemo.experimental.models.xdeeponet import MLPBranch + >>> branch = MLPBranch(out_features=64, hidden_width=64, num_layers=3) + >>> x = torch.randn(2, 128) + >>> out = branch(x) # (2, 64) + """ + + def __init__( + self, + out_features: int, + hidden_width: int = 64, + num_layers: int = 3, + activation_fn: str = "relu", + ): + super().__init__() + + if num_layers < 2: + raise ValueError( + f"MLPBranch requires num_layers >= 2 (input + output), " + f"got num_layers={num_layers}" + ) + + if activation_fn.lower() == "sin": + self.activation_fn = torch.sin + else: + self.activation_fn = get_activation(activation_fn) + + self.layers = nn.ModuleList() + self.layers.append(nn.LazyLinear(hidden_width)) + for _ in range(num_layers - 2): + self.layers.append(self._make_linear(hidden_width, hidden_width)) + + self.output_layer = self._make_linear(hidden_width, out_features) + + def _make_linear(self, in_dim: int, out_dim: int) -> nn.Linear: + layer = nn.Linear(in_dim, out_dim) + init.xavier_normal_(layer.weight) + init.zeros_(layer.bias) + return layer + + def forward( + self, + x: Float[Tensor, "batch in_features"], + ) -> Float[Tensor, "batch out_features"]: + """Forward pass of the MLP branch.""" + if not torch.compiler.is_compiling(): + if x.ndim != 2: + raise ValueError( + f"Expected 2D input (B, in_features), got {x.ndim}D " + f"tensor with shape {tuple(x.shape)}" + ) + for layer in self.layers: + x = self.activation_fn(layer(x)) + return self.activation_fn(self.output_layer(x)) + + +# --------------------------------------------------------------------------- +# 2D spatial branch +# --------------------------------------------------------------------------- + + +class SpatialBranch(nn.Module): + r"""2D spatial branch composable from Fourier, UNet, and Conv layers. + + The branch can be configured to use any combination of spectral, UNet, + and plain convolutional layers. When Fourier layers are present (the + "base" mode) UNet/Conv layers are added alongside the spectral path + (hybrid residual). When no Fourier layers are present UNet/Conv act + as independent sequential layers. + + Parameters + ---------- + in_channels : int + Number of input channels (used only for documentation; the lift is + :class:`torch.nn.LazyLinear`). + width : int + Latent/output width. + num_fourier_layers : int + Number of spectral layers. + num_unet_layers : int + Number of UNet layers (uses :class:`physicsnemo.models.unet.UNet`). + num_conv_layers : int + Number of Conv+BN layers. + modes1, modes2 : int + Fourier modes along H, W. + kernel_size : int + Kernel size for UNet and Conv layers. + dropout : float + Unused; kept for config compatibility. + activation_fn : str + Activation function name. + internal_resolution : list, optional + If set, inputs are adaptively pooled to this resolution before + processing and upsampled back, decoupling model size from grid size. + + Forward + ------- + x : torch.Tensor + Channels-last input of shape :math:`(B, H, W, C)`. + + Outputs + ------- + torch.Tensor + Channels-last output of shape :math:`(B, H, W, D)` where + :math:`D` equals ``width``. + + Examples + -------- + >>> import torch + >>> from physicsnemo.experimental.models.xdeeponet import SpatialBranch + >>> branch = SpatialBranch( + ... in_channels=5, width=64, num_unet_layers=1, kernel_size=3 + ... 
) + >>> x = torch.randn(2, 32, 32, 5) # (B, H, W, C) + >>> out = branch(x) # (2, 32, 32, 64) + """ + + def __init__( + self, + in_channels: int, + width: int, + num_fourier_layers: int = 0, + num_unet_layers: int = 0, + num_conv_layers: int = 0, + modes1: int = 12, + modes2: int = 12, + kernel_size: int = 3, + dropout: float = 0.0, # noqa: ARG002 - kept for config compatibility + activation_fn: str = "gelu", + internal_resolution: Optional[list] = None, + ): + super().__init__() + + self.num_fourier_layers = num_fourier_layers + self.num_unet_layers = num_unet_layers + self.num_conv_layers = num_conv_layers + self.use_fourier_base = num_fourier_layers > 0 + self.internal_resolution = ( + tuple(internal_resolution) if internal_resolution else None + ) + + total_layers = num_fourier_layers + num_unet_layers + num_conv_layers + if total_layers == 0: + raise ValueError("SpatialBranch requires at least one layer type") + + if activation_fn.lower() == "sin": + self.activation_fn = torch.sin + else: + self.activation_fn = get_activation(activation_fn) + + if self.internal_resolution is not None: + self.adaptive_pool = nn.AdaptiveAvgPool2d(self.internal_resolution) + + self.lift = nn.LazyLinear(width) + + num_fourier_components = ( + total_layers if self.use_fourier_base else num_fourier_layers + ) + self.spectral_convs = nn.ModuleList() + self.conv_1x1s = nn.ModuleList() + for _ in range(num_fourier_components): + self.spectral_convs.append(SpectralConv2d(width, width, modes1, modes2)) + self.conv_1x1s.append(nn.Conv2d(width, width, kernel_size=1)) + + self.unet_modules = nn.ModuleList() + for _ in range(num_unet_layers): + self.unet_modules.append( + _UNet2DFromUNet3D(width, width, kernel_size=kernel_size) + ) + + self.conv_modules = nn.ModuleList() + padding = (kernel_size - 1) // 2 + for _ in range(num_conv_layers): + self.conv_modules.append( + nn.Sequential( + nn.Conv2d( + width, + width, + kernel_size=kernel_size, + padding=padding, + bias=False, + ), + nn.BatchNorm2d(width), + ) + ) + + def forward( + self, + x: Float[Tensor, "batch height width channels"], + ) -> Float[Tensor, "batch height width out_channels"]: + """Forward pass of the 2D spatial branch.""" + if not torch.compiler.is_compiling(): + if x.ndim != 4: + raise ValueError( + f"Expected 4D input (B, H, W, C), got {x.ndim}D " + f"tensor with shape {tuple(x.shape)}" + ) + x = self.lift(x) + x = x.permute(0, 3, 1, 2) + + original_size = x.shape[2:] + if self.internal_resolution is not None: + x = self.adaptive_pool(x) + + for i in range(self.num_fourier_layers): + x = self.activation_fn(self.spectral_convs[i](x) + self.conv_1x1s[i](x)) + + if self.use_fourier_base: + for i in range(self.num_unet_layers): + j = self.num_fourier_layers + i + x = self.activation_fn( + self.spectral_convs[j](x) + + self.conv_1x1s[j](x) + + self.unet_modules[i](x) + ) + for i in range(self.num_conv_layers): + j = self.num_fourier_layers + self.num_unet_layers + i + x = self.activation_fn( + self.spectral_convs[j](x) + + self.conv_1x1s[j](x) + + self.conv_modules[i](x) + ) + else: + for unet in self.unet_modules: + x = self.activation_fn(unet(x)) + for conv in self.conv_modules: + x = self.activation_fn(conv(x)) + + if self.internal_resolution is not None and x.shape[2:] != original_size: + x = F.interpolate( + x, size=original_size, mode="bilinear", align_corners=True + ) + + return x.permute(0, 2, 3, 1) + + +# --------------------------------------------------------------------------- +# 3D spatial branch +# 
--------------------------------------------------------------------------- + + +class SpatialBranch3D(nn.Module): + r"""3D spatial branch composable from Fourier, UNet, and Conv layers. + + See :class:`SpatialBranch` for parameter semantics. The 3D variant adds + ``modes3`` for the third spectral axis. + + Forward + ------- + x : torch.Tensor + Channels-last input of shape :math:`(B, X, Y, Z, C)`. + + Outputs + ------- + torch.Tensor + Channels-last output of shape :math:`(B, X, Y, Z, D)` where + :math:`D` equals ``width``. + + Examples + -------- + >>> import torch + >>> from physicsnemo.experimental.models.xdeeponet import SpatialBranch3D + >>> branch = SpatialBranch3D( + ... in_channels=5, width=64, num_unet_layers=1, kernel_size=3 + ... ) + >>> x = torch.randn(1, 16, 16, 16, 5) # (B, X, Y, Z, C) + >>> out = branch(x) # (1, 16, 16, 16, 64) + """ + + def __init__( + self, + in_channels: int, + width: int, + num_fourier_layers: int = 0, + num_unet_layers: int = 0, + num_conv_layers: int = 0, + modes1: int = 10, + modes2: int = 10, + modes3: int = 8, + kernel_size: int = 3, + dropout: float = 0.0, # noqa: ARG002 - kept for config compatibility + activation_fn: str = "gelu", + internal_resolution: Optional[list] = None, + ): + super().__init__() + + self.num_fourier_layers = num_fourier_layers + self.num_unet_layers = num_unet_layers + self.num_conv_layers = num_conv_layers + self.use_fourier_base = num_fourier_layers > 0 + self.internal_resolution = ( + tuple(internal_resolution) if internal_resolution else None + ) + + total_layers = num_fourier_layers + num_unet_layers + num_conv_layers + if total_layers == 0: + raise ValueError("SpatialBranch3D requires at least one layer type") + + if activation_fn.lower() == "sin": + self.activation_fn = torch.sin + else: + self.activation_fn = get_activation(activation_fn) + + if self.internal_resolution is not None: + self.adaptive_pool = nn.AdaptiveAvgPool3d(self.internal_resolution) + + self.lift = nn.LazyLinear(width) + + num_fourier_components = ( + total_layers if self.use_fourier_base else num_fourier_layers + ) + self.spectral_convs = nn.ModuleList() + self.conv_1x1s = nn.ModuleList() + for _ in range(num_fourier_components): + self.spectral_convs.append( + SpectralConv3d(width, width, modes1, modes2, modes3) + ) + self.conv_1x1s.append(nn.Conv3d(width, width, kernel_size=1)) + + self.unet_modules = nn.ModuleList() + for _ in range(num_unet_layers): + self.unet_modules.append( + _UNet3DFromUNet3D(width, width, kernel_size=kernel_size) + ) + + self.conv_modules = nn.ModuleList() + padding = (kernel_size - 1) // 2 + for _ in range(num_conv_layers): + self.conv_modules.append( + nn.Sequential( + nn.Conv3d( + width, + width, + kernel_size=kernel_size, + padding=padding, + bias=False, + ), + nn.BatchNorm3d(width), + ) + ) + + def forward( + self, + x: Float[Tensor, "batch x y z channels"], + ) -> Float[Tensor, "batch x y z out_channels"]: + """Forward pass of the 3D spatial branch.""" + if not torch.compiler.is_compiling(): + if x.ndim != 5: + raise ValueError( + f"Expected 5D input (B, X, Y, Z, C), got {x.ndim}D " + f"tensor with shape {tuple(x.shape)}" + ) + x = self.lift(x) + x = x.permute(0, 4, 1, 2, 3) + + original_size = x.shape[2:] + if self.internal_resolution is not None: + x = self.adaptive_pool(x) + + for i in range(self.num_fourier_layers): + x = self.activation_fn(self.spectral_convs[i](x) + self.conv_1x1s[i](x)) + + if self.use_fourier_base: + for i in range(self.num_unet_layers): + j = self.num_fourier_layers + i + x = 
self.activation_fn( + self.spectral_convs[j](x) + + self.conv_1x1s[j](x) + + self.unet_modules[i](x) + ) + for i in range(self.num_conv_layers): + j = self.num_fourier_layers + self.num_unet_layers + i + x = self.activation_fn( + self.spectral_convs[j](x) + + self.conv_1x1s[j](x) + + self.conv_modules[i](x) + ) + else: + for unet in self.unet_modules: + x = self.activation_fn(unet(x)) + for conv in self.conv_modules: + x = self.activation_fn(conv(x)) + + if self.internal_resolution is not None and x.shape[2:] != original_size: + x = F.interpolate( + x, size=original_size, mode="trilinear", align_corners=True + ) + + return x.permute(0, 2, 3, 4, 1) + + +__all__ = [ + "TrunkNet", + "MLPBranch", + "SpatialBranch", + "SpatialBranch3D", +] diff --git a/physicsnemo/experimental/models/xdeeponet/deeponet.py b/physicsnemo/experimental/models/xdeeponet/deeponet.py new file mode 100644 index 0000000000..d2af0b2196 --- /dev/null +++ b/physicsnemo/experimental/models/xdeeponet/deeponet.py @@ -0,0 +1,1016 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core xDeepONet architectures for 2D and 3D operator learning. + +The xDeepONet family extends the original DeepONet with eight variants +that cover both single-input and multi-input operator learning, including +the Temporal Neural Operator (TNO) for autoregressive temporal bundling: + +- ``deeponet`` — basic DeepONet (MLP branch). +- ``u_deeponet`` — UNet-enhanced spatial branch. +- ``fourier_deeponet`` — spectral (Fourier) spatial branch. +- ``conv_deeponet`` — plain convolutional spatial branch. +- ``hybrid_deeponet`` — Fourier + UNet + Conv spatial branch. +- ``mionet`` — two-branch multi-input operator network. +- ``fourier_mionet`` — MIONet with a Fourier spatial branch. +- ``tno`` — Temporal Neural Operator (branch2 = previous + solution, autoregressive only). + +The core :class:`DeepONet` (2D) and :class:`DeepONet3D` (3D) classes are +dimension-specific but share the same construction pattern: a primary branch +(``branch1``), an optional secondary branch (``branch2`` for MIONet/TNO), +a coordinate trunk, and a decoder. + +References +---------- +- Lu, L. et al. (2021). "Learning nonlinear operators via DeepONet." + *Nature Machine Intelligence*, 3, 218-229. +- Jin, P., Meng, S. & Lu, L. (2022). "MIONet: Learning multiple-input + operators via tensor product." *SIAM J. Sci. Comp.*, 44(6), A3490-A3514. +- Diab, W. & Al Kobaisi, M. (2024). "U-DeepONet: U-Net enhanced deep + operator network for geologic carbon sequestration." + *Scientific Reports*, 14, 21298. +- Zhu, M. et al. (2023). "Fourier-DeepONet: Fourier-enhanced deep operator + networks for full waveform inversion." arXiv:2305.17289. +- Diab, W. & Al Kobaisi, M. (2025). "Temporal neural operator for modeling + time-dependent physical phenomena." *Scientific Reports*, 15. 
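+
+Examples
+--------
+A minimal two-branch (``mionet``) sketch; shapes are illustrative only:
+
+>>> import torch
+>>> from physicsnemo.experimental.models.xdeeponet import DeepONet
+>>> model = DeepONet(
+...     variant="mionet",
+...     width=32,
+...     branch1_config={"layers": {"num_conv_layers": 1}},
+...     branch2_config={"encoder": {"type": "mlp", "num_layers": 2}},
+...     trunk_config={"num_layers": 3},
+... )
+>>> x1 = torch.randn(2, 16, 16, 4)             # spatial input (B, H, W, C)
+>>> x2 = torch.randn(2, 6)                     # scalar input (B, D_in)
+>>> t = torch.linspace(0, 1, 5).unsqueeze(-1)  # query times (T, 1)
+>>> out = model(x1, t, x_branch2=x2)           # (2, 16, 16, 5)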
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Literal, Optional, get_args + +import torch +import torch.nn as nn +from jaxtyping import Float +from torch import Tensor + +from physicsnemo.core.meta import ModelMetaData +from physicsnemo.core.module import Module +from physicsnemo.experimental.models.xdeeponet.branches import ( + MLPBranch, + SpatialBranch, + SpatialBranch3D, + TrunkNet, +) +from physicsnemo.models.mlp import FullyConnected +from physicsnemo.nn import Conv2dFCLayer, Conv3dFCLayer, get_activation + +# Type aliases for the enumerated string parameters at the public API +# surface. Annotating ``variant`` and ``decoder_type`` with ``Literal`` +# rather than bare ``str`` lets static type checkers and IDEs flag +# unknown values at the call site; the runtime ``.lower()`` +# normalization and ``ValueError`` guards below remain in place so +# mixed-case strings still flow through (Python does not enforce +# ``Literal`` at runtime). +_VariantStr = Literal[ + "deeponet", + "u_deeponet", + "fourier_deeponet", + "conv_deeponet", + "hybrid_deeponet", + "mionet", + "fourier_mionet", + "tno", +] + +_DecoderTypeStr = Literal["mlp", "conv", "temporal_projection"] + +# Runtime sets, derived from the ``Literal`` aliases via ``typing.get_args`` +# so the two views (static type and runtime validator) cannot drift. Each +# class still exposes ``_VALID_VARIANTS`` as the ``VALID_VARIANTS`` class +# attribute for a stable public API. +_VALID_VARIANTS = get_args(_VariantStr) + +# Variants that require a secondary branch (branch2). Used by the core +# DeepONet / DeepONet3D __init__ to validate branch2_config up-front so +# multi-branch variants cannot silently degrade to single-branch models. +_DUAL_BRANCH_VARIANTS = frozenset({"mionet", "fourier_mionet", "tno"}) + +# Supported decoder types -- runtime view of the ``_DecoderTypeStr`` +# alias. Used by the core DeepONet / DeepONet3D __init__ to reject +# unknown decoder types at the API boundary instead of deferring to +# ``_build_decoder`` and raising cryptically from deep inside construction. +_VALID_DECODER_TYPES = frozenset(get_args(_DecoderTypeStr)) + + +@dataclass +class _DeepONetMetaData(ModelMetaData): + """PhysicsNeMo model metadata for :class:`DeepONet`.""" + + +@dataclass +class _DeepONet3DMetaData(ModelMetaData): + """PhysicsNeMo model metadata for :class:`DeepONet3D`.""" + + +# --------------------------------------------------------------------------- +# Branch config helpers +# --------------------------------------------------------------------------- + + +def _normalize_branch_config(config: dict) -> dict: + """Normalize a branch config to the nested encoder/layers format. + + Supports two input formats: + + **New (nested)** format:: + + { + "encoder": {"type": "linear", "activation_fn": "tanh", ...}, + "layers": {"num_fourier_layers": 1, "num_unet_layers": 1, ...}, + "internal_resolution": [H, W], + } + + **Old (flat)** format (auto-converted for backward compatibility):: + + { + "encoder": "spatial", # or "mlp" + "num_fourier_layers": 1, + "num_unet_layers": 1, + "activation_fn": "tanh", + ... + } + + Returns a dict in the new nested format. 
+ """ + if "encoder" not in config: + return config + + enc = config["encoder"] + if not isinstance(enc, str): + return config + + enc_type_str = str(enc).lower() + cfg = dict(config) + cfg.pop("encoder") + + encoder_keys = {"hidden_width", "num_layers"} + layer_keys = { + "num_fourier_layers", + "num_unet_layers", + "num_conv_layers", + "modes1", + "modes2", + "modes3", + "kernel_size", + "dropout", + } + + activation = cfg.pop("activation_fn", "sin") + internal_res = cfg.pop("internal_resolution", None) + in_channels = cfg.pop("in_channels", None) + # The legacy 'unet_impl' key is silently dropped: only the library UNet + # (physicsnemo.models.unet.UNet) is supported in the experimental package. + cfg.pop("unet_impl", None) + + encoder_dict = { + "type": "mlp" if enc_type_str == "mlp" else "linear", + "activation_fn": activation, + } + for k in encoder_keys: + if k in cfg: + encoder_dict[k] = cfg.pop(k) + + layers_dict = {"activation_fn": activation} + for k in layer_keys: + if k in cfg: + layers_dict[k] = cfg.pop(k) + + result = {"encoder": encoder_dict, "layers": layers_dict} + if internal_res is not None: + result["internal_resolution"] = internal_res + if in_channels is not None: + result["in_channels"] = in_channels + + return result + + +class _SinActivation(nn.Module): + """Module wrapper around :func:`torch.sin` for use inside ``nn.Sequential``. + + ``physicsnemo.nn.get_activation`` does not register ``"sin"`` in its + activation table; branch modules in ``branches.py`` work around this + by storing ``torch.sin`` as a bare callable and invoking it directly + in ``forward``. That pattern does not compose with ``nn.Sequential`` + (which requires ``nn.Module`` instances), so this thin wrapper is used + whenever a sin activation needs to slot into a ``Sequential`` pipeline. + """ + + def forward(self, x: Tensor) -> Tensor: + """Apply elementwise sine.""" + return torch.sin(x) + + +def _build_conv_encoder(width: int, enc_config: dict) -> nn.Module: + """Build a multi-layer pointwise encoder replacing the default LazyLinear lift. + + Operates in channels-last format ``(B, *spatial, C)``. Each layer is a + :class:`torch.nn.Linear` with activation — equivalent to a 1x1 convolution + applied independently at every spatial point. + """ + num_layers = enc_config.get("num_layers", 1) + activation_fn = enc_config.get("activation_fn", "relu") + + # ``get_activation`` does not know about ``"sin"``; use the module + # wrapper defined above when the user explicitly requests it, so + # config parity with the branch encoders is preserved. + if activation_fn.lower() == "sin": + act = _SinActivation() + else: + act = get_activation(activation_fn) + + if num_layers <= 1: + return nn.LazyLinear(width) + + hidden_width = enc_config.get("hidden_width", width // 2) + layers_list = [nn.LazyLinear(hidden_width), act] + for _ in range(num_layers - 2): + layers_list.extend([nn.Linear(hidden_width, hidden_width), act]) + layers_list.append(nn.Linear(hidden_width, width)) + return nn.Sequential(*layers_list) + + +# --------------------------------------------------------------------------- +# 2D DeepONet +# --------------------------------------------------------------------------- + + +class DeepONet(Module): + r"""2D xDeepONet core architecture for operator learning. + + Combines a primary spatial/MLP branch, an optional secondary branch + (for MIONet/TNO variants), a coordinate trunk, and a decoder. 
The + branch outputs and trunk are combined via Hadamard product and then + projected to the output by the decoder. + + Parameters + ---------- + variant : Literal["deeponet", "u_deeponet", "fourier_deeponet", "conv_deeponet", "hybrid_deeponet", "mionet", "fourier_mionet", "tno"] + One of the eight supported variants (see :data:`VALID_VARIANTS`). + Mixed-case strings are accepted at runtime and lowercased. + width : int + Latent width. + branch1_config : dict, optional + Primary branch configuration. See module docstring for schema. + branch2_config : dict, optional + Secondary branch configuration, required for the ``"mionet"``, + ``"fourier_mionet"``, and ``"tno"`` variants. + trunk_config : dict, optional + Trunk network configuration. + decoder_type : Literal["mlp", "conv", "temporal_projection"], optional + Decoder choice: ``"mlp"`` queries the trunk at each target + timestep and applies an MLP decoder; ``"conv"`` uses a + convolutional decoder; ``"temporal_projection"`` queries the + trunk once and projects the combined latent to K timesteps via a + learned linear head for fast autoregressive bundling. + Mixed-case strings are accepted at runtime and lowercased. + decoder_width : int, optional + Decoder hidden width. + decoder_layers : int, optional + Decoder layer count. + decoder_activation_fn : str, optional + Activation function name for the decoder. + output_window : int, optional + Output window length K for the ``"temporal_projection"`` decoder. + When supplied the temporal head is constructed at ``__init__``, which + produces a deterministic ``state_dict`` and makes checkpoint + round-tripping straightforward. When omitted, + :meth:`set_output_window` must be called before the first forward + pass. + + Forward + ------- + x_branch1 : torch.Tensor + Primary input of shape :math:`(B, H, W, C)` for spatial branches or + :math:`(B, D_{in})` for MLP branches. + x_time : torch.Tensor + Query coordinates of shape :math:`(T,)` or + :math:`(T, D_{\text{trunk}})`. + x_branch2 : torch.Tensor, optional + Secondary branch input for MIONet/TNO variants. + + Outputs + ------- + torch.Tensor + Operator output of shape :math:`(B, H, W, T)` for spatial branches + or :math:`(B, T)` for MLP branches. + + Examples + -------- + >>> import torch + >>> from physicsnemo.experimental.models.xdeeponet import DeepONet + >>> model = DeepONet( + ... variant="u_deeponet", + ... width=64, + ... branch1_config={ + ... "encoder": {"type": "linear"}, + ... "layers": {"num_unet_layers": 1, "kernel_size": 3}, + ... }, + ... trunk_config={"hidden_width": 64, "num_layers": 4}, + ... ) + >>> x_branch = torch.randn(2, 32, 32, 5) # (B, H, W, C) + >>> x_time = torch.linspace(0, 1, 3).unsqueeze(-1) # (T, 1) + >>> out = model(x_branch, x_time) # (2, 32, 32, 3) + """ + + VALID_VARIANTS = _VALID_VARIANTS + + def __init__( + self, + variant: _VariantStr = "u_deeponet", + width: int = 64, + branch1_config: Dict[str, Any] = None, + branch2_config: Dict[str, Any] = None, + trunk_config: Dict[str, Any] = None, + decoder_type: _DecoderTypeStr = "mlp", + decoder_width: int = 128, + decoder_layers: int = 2, + decoder_activation_fn: str = "relu", + output_window: Optional[int] = None, + ): + super().__init__(meta=_DeepONetMetaData()) + + self.variant = variant.lower() + self.width = width + self.decoder_type = decoder_type.lower() + self.decoder_activation_fn = decoder_activation_fn + + if self.variant not in self.VALID_VARIANTS: + raise ValueError( + f"Unknown variant: {variant}. 
Valid: {self.VALID_VARIANTS}" + ) + + if self.decoder_type not in _VALID_DECODER_TYPES: + raise ValueError( + f"Unknown decoder_type: {decoder_type!r}. Valid: " + f"{sorted(_VALID_DECODER_TYPES)}." + ) + + if self.variant in _DUAL_BRANCH_VARIANTS and branch2_config is None: + raise ValueError( + f"variant='{self.variant}' requires branch2_config to be " + f"provided. Dual-branch variants: " + f"{sorted(_DUAL_BRANCH_VARIANTS)}." + ) + + branch1_config = branch1_config or {} + trunk_config = trunk_config or {} + + self.branch1 = self._build_branch(branch1_config, width) + + # Reject MLP-branch configurations paired with a decoder that + # needs a spatial (4D / 5D) ``combined`` tensor. The MLP-branch + # forward path produces a 3D tensor of shape (B, T, width) and: + # * ``temporal_projection`` silently drops the temporal head + # (wrong shape, no error); + # * ``conv`` crashes inside the decoder's ``Conv2d`` / + # ``Conv3d`` with PyTorch's generic "Expected 3D or 4D + # input" message, with no hint that the real cause is a + # config mismatch. + # Fail fast here instead. + if isinstance(self.branch1, MLPBranch) and self.decoder_type in ( + "temporal_projection", + "conv", + ): + raise ValueError( + f"decoder_type={self.decoder_type!r} is not supported with " + "MLP branches. Use decoder_type='mlp', or configure a " + "SpatialBranch for branch1 (set num_unet_layers, " + "num_fourier_layers, or num_conv_layers > 0 in " + "branch1_config)." + ) + + self.has_branch2 = branch2_config is not None + if self.has_branch2: + self.branch2 = self._build_branch(branch2_config, width) + + # Forward assumes branch2's output has the same rank as + # branch1's. Mixing an MLPBranch (2D output (B, width)) with + # a SpatialBranch (4D / 5D output) would either broadcast + # nonsensically or raise a cryptic dim-mismatch error in the + # Hadamard product. Reject the mixed configuration here. + if isinstance(self.branch1, MLPBranch) and not isinstance( + self.branch2, MLPBranch + ): + raise ValueError( + "When branch1 is an MLPBranch, branch2 must also be " + "an MLPBranch (i.e. produce a 2D (B, width) output). " + "Swap branch1 and branch2, or configure branch1 as " + "a SpatialBranch." + ) + + self.trunk = TrunkNet( + in_features=trunk_config.get("in_features", 1), + out_features=width, + hidden_width=trunk_config.get("hidden_width", 128), + num_layers=trunk_config.get("num_layers", 6), + activation_fn=trunk_config.get("activation_fn", "sin"), + output_activation=trunk_config.get("output_activation", True), + ) + + if self.decoder_type == "temporal_projection": + self._temporal_projection = True + self.decoder = self._build_decoder( + width, + width, + decoder_layers, + decoder_width, + "mlp", + decoder_activation_fn, + ) + # Preferred path: construct the temporal head at __init__ so + # state_dict keys are deterministic and checkpointing just works. + # If ``output_window`` is not provided the user must call + # :meth:`set_output_window` before the first forward pass; this + # path is kept for backwards compatibility but produces a + # state_dict whose structure depends on when the method is called. 
+ if output_window is not None: + if output_window < 1: + raise ValueError( + f"output_window must be a positive integer, got {output_window}" + ) + self.temporal_head = nn.Linear(self.width, output_window) + else: + self.temporal_head = None + else: + self._temporal_projection = False + self.decoder = self._build_decoder( + width, + 1, + decoder_layers, + decoder_width, + self.decoder_type, + decoder_activation_fn, + ) + + def set_output_window(self, K: int): + """Create the temporal-projection head for K output timesteps. + + Only effective when ``decoder_type="temporal_projection"``. + """ + if self._temporal_projection: + device = next(self.parameters()).device + self.temporal_head = nn.Linear(self.width, K).to(device) + + def _build_branch(self, config: dict, width: int) -> nn.Module: + config = _normalize_branch_config(config) + enc = config.get("encoder", {}) + layers = config.get("layers", {}) + + enc_type = enc.get("type", "linear") + enc_activation = enc.get("activation_fn", "sin") + + has_layers = ( + layers.get("num_fourier_layers", 0) + + layers.get("num_unet_layers", 0) + + layers.get("num_conv_layers", 0) + ) > 0 + + if enc_type == "mlp" and not has_layers: + return MLPBranch( + out_features=width, + hidden_width=enc.get("hidden_width", 64), + num_layers=enc.get("num_layers", 3), + activation_fn=enc_activation, + ) + + layer_activation = layers.get("activation_fn", enc_activation) + branch = SpatialBranch( + in_channels=config.get("in_channels", 12), + width=width, + num_fourier_layers=layers.get("num_fourier_layers", 0), + num_unet_layers=layers.get("num_unet_layers", 0), + num_conv_layers=layers.get("num_conv_layers", 0), + modes1=layers.get("modes1", 12), + modes2=layers.get("modes2", 12), + kernel_size=layers.get("kernel_size", 3), + dropout=layers.get("dropout", 0.0), + activation_fn=layer_activation, + internal_resolution=config.get("internal_resolution", None), + ) + if enc_type == "conv": + branch.lift = _build_conv_encoder(width, enc) + return branch + + def _build_decoder( + self, + width: int, + out_channels: int, + num_layers: int, + hidden_width: int, + decoder_type: str, + activation_fn: str, + ) -> nn.Module: + if decoder_type == "mlp": + if num_layers == 0: + return nn.Linear(width, out_channels) + return FullyConnected( + width, hidden_width, out_channels, num_layers, activation_fn + ) + + elif decoder_type == "conv": + if num_layers == 0: + return Conv2dFCLayer(width, out_channels) + + layers = [] + in_ch = width + for _ in range(num_layers): + layers.extend( + [Conv2dFCLayer(in_ch, hidden_width), get_activation(activation_fn)] + ) + in_ch = hidden_width + layers.append(Conv2dFCLayer(hidden_width, out_channels)) + return nn.Sequential(*layers) + + else: + raise ValueError(f"Unknown decoder_type: {decoder_type}") + + def forward( + self, + x_branch1: Float[Tensor, "..."], + x_time: Float[Tensor, "..."], + x_branch2: Optional[Float[Tensor, "..."]] = None, + ) -> Float[Tensor, "..."]: + """Forward pass through the DeepONet. + + See class docstring for input/output shapes. ``x_branch1`` accepts + either 2D ``(B, D_in)`` (MLP branches) or 4D ``(B, H, W, C)`` + (spatial branches); ``x_time`` accepts 1D ``(T,)`` or 2D + ``(T, D_trunk)``, so the jaxtyping annotation is the unconstrained + ``"..."`` shape. Strict shape validation is performed at the top + of this method under a :func:`torch.compiler.is_compiling` guard. 
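+
+        A temporal-projection sketch (``tno`` variant; shapes illustrative).
+        The trunk is queried once and the combined latent is projected to
+        ``output_window`` timesteps:
+
+        >>> import torch
+        >>> from physicsnemo.experimental.models.xdeeponet import DeepONet
+        >>> model = DeepONet(
+        ...     variant="tno",
+        ...     width=32,
+        ...     branch1_config={"layers": {"num_conv_layers": 1}},
+        ...     branch2_config={"layers": {"num_conv_layers": 1}},
+        ...     trunk_config={"num_layers": 3, "output_activation": False},
+        ...     decoder_type="temporal_projection",
+        ...     output_window=4,
+        ... )
+        >>> x1 = torch.randn(2, 16, 16, 3)    # current fields (B, H, W, C)
+        >>> x2 = torch.randn(2, 16, 16, 1)    # previous solution
+        >>> t = torch.zeros(1, 1)             # trunk queried once
+        >>> out = model(x1, t, x_branch2=x2)  # (2, 16, 16, 4)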
+ """ + if not torch.compiler.is_compiling(): + if x_branch1.ndim not in (2, 4): + raise ValueError( + f"Expected x_branch1 to be 2D (B, D_in) for MLP branches " + f"or 4D (B, H, W, C) for spatial branches, got " + f"{x_branch1.ndim}D tensor with shape " + f"{tuple(x_branch1.shape)}" + ) + if x_time.ndim not in (1, 2): + raise ValueError( + f"Expected x_time to be 1D (T,) or 2D (T, D), got " + f"{x_time.ndim}D tensor with shape {tuple(x_time.shape)}" + ) + if self.has_branch2 and x_branch2 is None: + raise ValueError( + f"variant='{self.variant}' requires x_branch2 but got None" + ) + + if x_time.dim() == 1: + x_time = x_time.unsqueeze(-1) + + b1_out = self.branch1(x_branch1) + + if self.has_branch2: + if x_branch2 is None: + raise ValueError("x_branch2 required for mionet/tno variants") + b2_out = self.branch2(x_branch2) + + trunk_out = self.trunk(x_time) + + if b1_out.dim() == 4: # Spatial branch + if self._temporal_projection: + trunk_single = trunk_out[0:1] + trunk_exp = trunk_single.unsqueeze(1).unsqueeze(2) + combined = b1_out * trunk_exp + if self.has_branch2: + if b2_out.dim() == 4: + combined = combined * b2_out + else: + combined = combined * b2_out.unsqueeze(1).unsqueeze(2) + combined = self.decoder(combined) + if self.temporal_head is None: + raise RuntimeError( + "decoder_type='temporal_projection' requires either " + "output_window to be provided at construction time, " + "or set_output_window(K) to be called before forward." + ) + combined = self.temporal_head(combined) + return combined + + b1_out = b1_out.unsqueeze(1) + trunk_out = trunk_out.unsqueeze(0).unsqueeze(2).unsqueeze(3) + + if self.has_branch2: + if b2_out.dim() == 4: + b2_out = b2_out.unsqueeze(1) + else: + b2_out = b2_out.unsqueeze(1).unsqueeze(2).unsqueeze(3) + combined = b1_out * b2_out * trunk_out + else: + combined = b1_out * trunk_out + + if self.decoder_type == "mlp": + return self.decoder(combined).squeeze(-1).permute(0, 2, 3, 1) + + B, T, H, W, C = combined.shape + combined = combined.permute(0, 1, 4, 2, 3).reshape(B * T, C, H, W) + return self.decoder(combined).reshape(B, T, H, W).permute(0, 2, 3, 1) + + else: # MLP branch + b1_out = b1_out.unsqueeze(1) + trunk_out = trunk_out.unsqueeze(0) + + if self.has_branch2: + combined = b1_out * b2_out.unsqueeze(1) * trunk_out + else: + combined = b1_out * trunk_out + + return self.decoder(combined).squeeze(-1) + + def count_params(self) -> int: + """Return the number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +# --------------------------------------------------------------------------- +# 3D DeepONet +# --------------------------------------------------------------------------- + + +class DeepONet3D(Module): + r"""3D xDeepONet core architecture for volumetric operator learning. + + See :class:`DeepONet` for parameter semantics. The 3D variant operates + on volumetric inputs and uses :class:`SpatialBranch3D` for spatial + branches. + + Parameters + ---------- + variant : Literal["deeponet", "u_deeponet", "fourier_deeponet", "conv_deeponet", "hybrid_deeponet", "mionet", "fourier_mionet", "tno"] + One of the eight supported variants (see :data:`VALID_VARIANTS`). + Mixed-case strings are accepted at runtime and lowercased. + decoder_type : Literal["mlp", "conv", "temporal_projection"], optional + Decoder choice; see :class:`DeepONet` for semantics. 
+ + Forward + ------- + x_branch1 : torch.Tensor + Primary input of shape :math:`(B, X, Y, Z, C)` for spatial branches + or :math:`(B, D_{in})` for MLP branches. + x_time : torch.Tensor + Query coordinates of shape :math:`(T,)` or + :math:`(T, D_{\text{trunk}})`. + x_branch2 : torch.Tensor, optional + Secondary branch input for MIONet/TNO variants. + + Outputs + ------- + torch.Tensor + Operator output of shape :math:`(B, X, Y, Z, T)` for spatial + branches or :math:`(B, T)` for MLP branches. + + Examples + -------- + >>> import torch + >>> from physicsnemo.experimental.models.xdeeponet import DeepONet3D + >>> model = DeepONet3D( + ... variant="u_deeponet", + ... width=64, + ... branch1_config={ + ... "encoder": {"type": "linear"}, + ... "layers": {"num_unet_layers": 1, "kernel_size": 3}, + ... }, + ... trunk_config={"hidden_width": 64, "num_layers": 4}, + ... ) + >>> x_branch = torch.randn(1, 16, 16, 16, 5) # (B, X, Y, Z, C) + >>> x_time = torch.linspace(0, 1, 2).unsqueeze(-1) + >>> out = model(x_branch, x_time) # (1, 16, 16, 16, 2) + """ + + VALID_VARIANTS = _VALID_VARIANTS + + def __init__( + self, + variant: _VariantStr = "u_deeponet", + width: int = 64, + branch1_config: Dict[str, Any] = None, + branch2_config: Dict[str, Any] = None, + trunk_config: Dict[str, Any] = None, + decoder_type: _DecoderTypeStr = "mlp", + decoder_width: int = 128, + decoder_layers: int = 2, + decoder_activation_fn: str = "relu", + output_window: Optional[int] = None, + ): + super().__init__(meta=_DeepONet3DMetaData()) + + self.variant = variant.lower() + self.width = width + self.decoder_type = decoder_type.lower() + self.decoder_activation_fn = decoder_activation_fn + + if self.variant not in self.VALID_VARIANTS: + raise ValueError( + f"Unknown variant: {variant}. Valid: {self.VALID_VARIANTS}" + ) + + if self.decoder_type not in _VALID_DECODER_TYPES: + raise ValueError( + f"Unknown decoder_type: {decoder_type!r}. Valid: " + f"{sorted(_VALID_DECODER_TYPES)}." + ) + + if self.variant in _DUAL_BRANCH_VARIANTS and branch2_config is None: + raise ValueError( + f"variant='{self.variant}' requires branch2_config to be " + f"provided. Dual-branch variants: " + f"{sorted(_DUAL_BRANCH_VARIANTS)}." + ) + + branch1_config = branch1_config or {} + trunk_config = trunk_config or {} + + self.branch1 = self._build_branch(branch1_config, width) + + # Reject MLP-branch configurations paired with a decoder that + # needs a spatial (4D / 5D) ``combined`` tensor. The MLP-branch + # forward path produces a 3D tensor of shape (B, T, width) and: + # * ``temporal_projection`` silently drops the temporal head + # (wrong shape, no error); + # * ``conv`` crashes inside the decoder's ``Conv2d`` / + # ``Conv3d`` with PyTorch's generic "Expected 3D or 4D + # input" message, with no hint that the real cause is a + # config mismatch. + # Fail fast here instead. + if isinstance(self.branch1, MLPBranch) and self.decoder_type in ( + "temporal_projection", + "conv", + ): + raise ValueError( + f"decoder_type={self.decoder_type!r} is not supported with " + "MLP branches. Use decoder_type='mlp', or configure a " + "SpatialBranch for branch1 (set num_unet_layers, " + "num_fourier_layers, or num_conv_layers > 0 in " + "branch1_config)." + ) + + self.has_branch2 = branch2_config is not None + if self.has_branch2: + self.branch2 = self._build_branch(branch2_config, width) + + # Forward assumes branch2's output has the same rank as + # branch1's. 
Mixing an MLPBranch (2D output (B, width)) with + # a SpatialBranch (4D / 5D output) would either broadcast + # nonsensically or raise a cryptic dim-mismatch error in the + # Hadamard product. Reject the mixed configuration here. + if isinstance(self.branch1, MLPBranch) and not isinstance( + self.branch2, MLPBranch + ): + raise ValueError( + "When branch1 is an MLPBranch, branch2 must also be " + "an MLPBranch (i.e. produce a 2D (B, width) output). " + "Swap branch1 and branch2, or configure branch1 as " + "a SpatialBranch." + ) + + self.trunk = TrunkNet( + in_features=trunk_config.get("in_features", 1), + out_features=width, + hidden_width=trunk_config.get("hidden_width", 128), + num_layers=trunk_config.get("num_layers", 6), + activation_fn=trunk_config.get("activation_fn", "sin"), + output_activation=trunk_config.get("output_activation", True), + ) + + if self.decoder_type == "temporal_projection": + self._temporal_projection = True + self.decoder = self._build_decoder( + width, + width, + decoder_layers, + decoder_width, + "mlp", + decoder_activation_fn, + ) + # Preferred path: construct the temporal head at __init__ so + # state_dict keys are deterministic and checkpointing just works. + # If ``output_window`` is not provided the user must call + # :meth:`set_output_window` before the first forward pass; this + # path is kept for backwards compatibility but produces a + # state_dict whose structure depends on when the method is called. + if output_window is not None: + if output_window < 1: + raise ValueError( + f"output_window must be a positive integer, got {output_window}" + ) + self.temporal_head = nn.Linear(self.width, output_window) + else: + self.temporal_head = None + else: + self._temporal_projection = False + self.decoder = self._build_decoder( + width, + 1, + decoder_layers, + decoder_width, + self.decoder_type, + decoder_activation_fn, + ) + + def set_output_window(self, K: int): + """Create the temporal-projection head for K output timesteps. + + Only effective when ``decoder_type="temporal_projection"``. 
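+
+        Example (sketch; a no-op for other decoder types)::
+
+            model.set_output_window(4)        # head: width -> 4 timesteps
+            out = model(x_branch1, x_time)    # (..., 4) output bundle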
+ """ + if self._temporal_projection: + device = next(self.parameters()).device + self.temporal_head = nn.Linear(self.width, K).to(device) + + def _build_branch(self, config: dict, width: int) -> nn.Module: + config = _normalize_branch_config(config) + enc = config.get("encoder", {}) + layers = config.get("layers", {}) + + enc_type = enc.get("type", "linear") + enc_activation = enc.get("activation_fn", "sin") + + has_layers = ( + layers.get("num_fourier_layers", 0) + + layers.get("num_unet_layers", 0) + + layers.get("num_conv_layers", 0) + ) > 0 + + if enc_type == "mlp" and not has_layers: + return MLPBranch( + out_features=width, + hidden_width=enc.get("hidden_width", 64), + num_layers=enc.get("num_layers", 3), + activation_fn=enc_activation, + ) + + layer_activation = layers.get("activation_fn", enc_activation) + branch = SpatialBranch3D( + in_channels=config.get("in_channels", 11), + width=width, + num_fourier_layers=layers.get("num_fourier_layers", 0), + num_unet_layers=layers.get("num_unet_layers", 0), + num_conv_layers=layers.get("num_conv_layers", 0), + modes1=layers.get("modes1", 10), + modes2=layers.get("modes2", 10), + modes3=layers.get("modes3", 8), + kernel_size=layers.get("kernel_size", 3), + dropout=layers.get("dropout", 0.0), + activation_fn=layer_activation, + internal_resolution=config.get("internal_resolution", None), + ) + if enc_type == "conv": + branch.lift = _build_conv_encoder(width, enc) + return branch + + def _build_decoder( + self, + width: int, + out_channels: int, + num_layers: int, + hidden_width: int, + decoder_type: str, + activation_fn: str, + ) -> nn.Module: + if decoder_type == "mlp": + if num_layers == 0: + return nn.Linear(width, out_channels) + return FullyConnected( + width, hidden_width, out_channels, num_layers, activation_fn + ) + + elif decoder_type == "conv": + if num_layers == 0: + return Conv3dFCLayer(width, out_channels) + + layers = [] + in_ch = width + for _ in range(num_layers): + layers.extend( + [Conv3dFCLayer(in_ch, hidden_width), get_activation(activation_fn)] + ) + in_ch = hidden_width + layers.append(Conv3dFCLayer(hidden_width, out_channels)) + return nn.Sequential(*layers) + + else: + raise ValueError(f"Unknown decoder_type: {decoder_type}") + + def forward( + self, + x_branch1: Float[Tensor, "..."], + x_time: Float[Tensor, "..."], + x_branch2: Optional[Float[Tensor, "..."]] = None, + ) -> Float[Tensor, "..."]: + """Forward pass through the 3D DeepONet. + + See class docstring for input/output shapes. ``x_branch1`` accepts + either 2D ``(B, D_in)`` (MLP branches) or 5D ``(B, X, Y, Z, C)`` + (spatial branches); ``x_time`` accepts 1D ``(T,)`` or 2D + ``(T, D_trunk)``. Strict shape validation is performed at the top + of this method under a :func:`torch.compiler.is_compiling` guard. 
+ """ + if not torch.compiler.is_compiling(): + if x_branch1.ndim not in (2, 5): + raise ValueError( + f"Expected x_branch1 to be 2D (B, D_in) for MLP branches " + f"or 5D (B, X, Y, Z, C) for spatial branches, got " + f"{x_branch1.ndim}D tensor with shape " + f"{tuple(x_branch1.shape)}" + ) + if x_time.ndim not in (1, 2): + raise ValueError( + f"Expected x_time to be 1D (T,) or 2D (T, D), got " + f"{x_time.ndim}D tensor with shape {tuple(x_time.shape)}" + ) + if self.has_branch2 and x_branch2 is None: + raise ValueError( + f"variant='{self.variant}' requires x_branch2 but got None" + ) + + if x_time.dim() == 1: + x_time = x_time.unsqueeze(-1) + + b1_out = self.branch1(x_branch1) + + if self.has_branch2: + if x_branch2 is None: + raise ValueError("x_branch2 required for mionet/tno variants") + b2_out = self.branch2(x_branch2) + + trunk_out = self.trunk(x_time) + + if b1_out.dim() == 5: # Spatial branch + if self._temporal_projection: + trunk_single = trunk_out[0:1] + trunk_exp = trunk_single.unsqueeze(1).unsqueeze(2).unsqueeze(3) + combined = b1_out * trunk_exp + if self.has_branch2: + if b2_out.dim() == 5: + combined = combined * b2_out + else: + combined = combined * b2_out.unsqueeze(1).unsqueeze( + 2 + ).unsqueeze(3) + combined = self.decoder(combined) + if self.temporal_head is None: + raise RuntimeError( + "decoder_type='temporal_projection' requires either " + "output_window to be provided at construction time, " + "or set_output_window(K) to be called before forward." + ) + combined = self.temporal_head(combined) + return combined + + b1_out = b1_out.unsqueeze(1) + trunk_out = trunk_out.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4) + + if self.has_branch2: + if b2_out.dim() == 5: + b2_out = b2_out.unsqueeze(1) + else: + b2_out = b2_out.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4) + combined = b1_out * b2_out * trunk_out + else: + combined = b1_out * trunk_out + + if self.decoder_type == "mlp": + return self.decoder(combined).squeeze(-1).permute(0, 2, 3, 4, 1) + + B, T, X, Y, Z, C = combined.shape + combined = combined.permute(0, 1, 5, 2, 3, 4).reshape(B * T, C, X, Y, Z) + return self.decoder(combined).reshape(B, T, X, Y, Z).permute(0, 2, 3, 4, 1) + + else: # MLP branch + b1_out = b1_out.unsqueeze(1) + trunk_out = trunk_out.unsqueeze(0) + + if self.has_branch2: + combined = b1_out * b2_out.unsqueeze(1) * trunk_out + else: + combined = b1_out * trunk_out + + return self.decoder(combined).squeeze(-1) + + def count_params(self) -> int: + """Return the number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +__all__ = [ + "DeepONet", + "DeepONet3D", +] diff --git a/physicsnemo/experimental/models/xdeeponet/padding.py b/physicsnemo/experimental/models/xdeeponet/padding.py new file mode 100644 index 0000000000..b7f38e89af --- /dev/null +++ b/physicsnemo/experimental/models/xdeeponet/padding.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Right-side spatial padding helpers used by the xDeepONet wrappers. + +The wrappers align spatial dimensions to a multiple (typically 8) so that +spectral and convolutional sub-branches operate on compatible shapes. These +helpers are dimension-agnostic and support 2D, 3D, or 4D spatial layouts. + +Tensor layouts used here: +- 2D spatial samples: ``(B, H, W, T, C)`` +- 3D spatial samples: ``(B, X, Y, Z, T, C)`` +""" + +from __future__ import annotations + +from typing import Sequence, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor + + +def compute_right_pad_to_multiple( + spatial_shape: Sequence[int], + *, + multiple: int = 8, + min_right_pad: int = 0, +) -> Tuple[int, ...]: + """Compute right-side padding to reach a multiple of *multiple*. + + Parameters + ---------- + spatial_shape : Sequence[int] + Current spatial dimension sizes. + multiple : int, optional + Target alignment (default ``8``). + min_right_pad : int, optional + Minimum right-side padding applied per dimension (default ``0``). + + Returns + ------- + Tuple[int, ...] + Right-side padding per dimension such that ``(d + pad)`` is a multiple + of *multiple* and ``pad >= min_right_pad``. + """ + if multiple <= 0: + raise ValueError(f"multiple must be > 0, got {multiple}") + if min_right_pad < 0: + raise ValueError(f"min_right_pad must be >= 0, got {min_right_pad}") + + pads = [] + for d in spatial_shape: + if d <= 0: + raise ValueError( + f"spatial dimensions must be positive, got {spatial_shape}" + ) + to_mult = (multiple - (d % multiple)) % multiple + if to_mult >= min_right_pad: + pad = to_mult + else: + deficit = min_right_pad - to_mult + k = (deficit + multiple - 1) // multiple + pad = to_mult + k * multiple + pads.append(int(pad)) + return tuple(pads) + + +def pad_right_nd( + x: Tensor, + *, + dims: Sequence[int], + right_pad: Sequence[int], + mode: str = "replicate", + constant_value: float = 0.0, +) -> Tensor: + """Right-pad arbitrary dimensions of an N-D tensor. + + Implemented manually so it works for ``mode="replicate"`` even when + :func:`torch.nn.functional.pad` does not support the tensor rank + (e.g. 6D tensors in the 3D-spatial case). + """ + if len(dims) != len(right_pad): + raise ValueError("dims and right_pad must have the same length") + if not dims: + return x + + for dim, pad in zip(dims, right_pad): + pad = int(pad) + if pad <= 0: + continue + if dim < 0: + dim = x.dim() + dim + if dim < 0 or dim >= x.dim(): + raise ValueError(f"invalid dim {dim} for x.dim()={x.dim()}") + + if mode == "constant": + pad_shape = list(x.shape) + pad_shape[dim] = pad + pad_tensor = torch.full( + pad_shape, float(constant_value), dtype=x.dtype, device=x.device + ) + x = torch.cat([x, pad_tensor], dim=dim) + continue + + if mode != "replicate": + raise ValueError( + f"pad_right_nd supports mode='replicate' or 'constant', got {mode}" + ) + + last = x.select(dim, x.size(dim) - 1).unsqueeze(dim) + expand_shape = list(x.shape) + expand_shape[dim] = pad + pad_tensor = last.expand(*expand_shape) + x = torch.cat([x, pad_tensor], dim=dim) + + return x + + +def pad_spatial_right( + x: Tensor, + *, + spatial_ndim: int, + right_pad: Sequence[int], + mode: str = "replicate", + constant_value: float = 0.0, +) -> Tensor: + """Right-pad the first *spatial_ndim* dimensions after the batch dim. + + Assumes ``x`` is shaped ``(B, *spatial, *rest)``. 
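+
+    Examples
+    --------
+    A minimal sketch: right-pad a 2D-spatial batch from ``30 x 30`` to
+    ``32 x 32`` (the wrappers later crop back to the original extent).
+
+    >>> import torch
+    >>> from physicsnemo.experimental.models.xdeeponet.padding import (
+    ...     pad_spatial_right,
+    ... )
+    >>> x = torch.randn(1, 30, 30, 2, 3)  # (B, H, W, T, C)
+    >>> pad_spatial_right(x, spatial_ndim=2, right_pad=(2, 2)).shape
+    torch.Size([1, 32, 32, 2, 3])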
+ """ + if spatial_ndim not in (2, 3, 4): + raise ValueError(f"spatial_ndim must be 2, 3, or 4, got {spatial_ndim}") + if len(right_pad) != spatial_ndim: + raise ValueError( + f"right_pad must have length {spatial_ndim}, got {len(right_pad)}" + ) + if x.dim() < 1 + spatial_ndim: + raise ValueError( + f"expected x.dim() >= {1 + spatial_ndim}, got x.dim()={x.dim()}" + ) + if all(int(p) == 0 for p in right_pad): + return x + + # For 4 spatial dims fall back to the generic implementation (works for 6D+). + if spatial_ndim == 4: + dims = [1, 2, 3, 4] + return pad_right_nd( + x, + dims=dims, + right_pad=right_pad, + mode=mode, + constant_value=constant_value, + ) + + # For 2D/3D spatial, use a reshape trick so F.pad(replicate) applies. + b = x.shape[0] + spatial_shape = x.shape[1 : 1 + spatial_ndim] + rest_shape = x.shape[1 + spatial_ndim :] + rest_prod = ( + 1 if len(rest_shape) == 0 else int(torch.tensor(rest_shape).prod().item()) + ) + + x_reshaped = x.reshape(b, *spatial_shape, rest_prod).permute( + 0, spatial_ndim + 1, *range(1, 1 + spatial_ndim) + ) + + if spatial_ndim == 2: + pad_h, pad_w = (int(p) for p in right_pad) + pad = (0, pad_w, 0, pad_h) + else: + pad_x, pad_y, pad_z = (int(p) for p in right_pad) + pad = (0, pad_z, 0, pad_y, 0, pad_x) + + if mode == "constant": + x_padded = F.pad(x_reshaped, pad, mode="constant", value=float(constant_value)) + else: + x_padded = F.pad(x_reshaped, pad, mode=mode) + + padded_spatial = x_padded.shape[2 : 2 + spatial_ndim] + return x_padded.permute(0, *range(2, 2 + spatial_ndim), 1).reshape( + b, *padded_spatial, *rest_shape + ) + + +__all__ = [ + "compute_right_pad_to_multiple", + "pad_right_nd", + "pad_spatial_right", +] diff --git a/physicsnemo/experimental/models/xdeeponet/wrappers.py b/physicsnemo/experimental/models/xdeeponet/wrappers.py new file mode 100644 index 0000000000..a1a71698ca --- /dev/null +++ b/physicsnemo/experimental/models/xdeeponet/wrappers.py @@ -0,0 +1,434 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convenience wrappers for xDeepONet. + +These wrappers add two ergonomic features on top of the core +:class:`~physicsnemo.experimental.models.xdeeponet.deeponet.DeepONet` and +:class:`~physicsnemo.experimental.models.xdeeponet.deeponet.DeepONet3D`: + +1. **Automatic spatial padding** to align the input to a multiple (default + 8), which makes the Fourier, UNet, and Conv sub-branches compatible + across arbitrary grid sizes. Outputs are cropped back to the original + spatial shape before return. +2. **Automatic trunk input extraction** from the full spatiotemporal input + tensor. Given ``(B, H, W, T, C)`` (2D) or ``(B, X, Y, Z, T, C)`` (3D) + and a ``target_times`` kwarg (optional), the wrapper assembles the + trunk query coordinates according to the ``trunk.input_type`` setting + (``"time"`` or ``"grid"``). 
+ +These wrappers are the recommended public entry points for xDeepONet. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +from jaxtyping import Float +from torch import Tensor + +from physicsnemo.core.meta import ModelMetaData +from physicsnemo.core.module import Module +from physicsnemo.experimental.models.xdeeponet.deeponet import ( + DeepONet, + DeepONet3D, + _DecoderTypeStr, + _VariantStr, +) +from physicsnemo.experimental.models.xdeeponet.padding import ( + compute_right_pad_to_multiple, + pad_spatial_right, +) + + +@dataclass +class _DeepONetWrapperMetaData(ModelMetaData): + """PhysicsNeMo model metadata for :class:`DeepONetWrapper`.""" + + +@dataclass +class _DeepONet3DWrapperMetaData(ModelMetaData): + """PhysicsNeMo model metadata for :class:`DeepONet3DWrapper`.""" + + +class DeepONetWrapper(Module): + r"""2D xDeepONet wrapper with automatic padding and input extraction. + + Extracts the spatial channels and trunk coordinates from a packed 5D + input tensor, pads spatial dimensions to a multiple of 8, runs the + core :class:`~physicsnemo.experimental.models.xdeeponet.deeponet.DeepONet`, + and unpads. + + Parameters + ---------- + padding : int, optional + Minimum right-side padding; the wrapper rounds up to the next + multiple of 8. + variant : Literal["deeponet", "u_deeponet", "fourier_deeponet", "conv_deeponet", "hybrid_deeponet", "mionet", "fourier_mionet", "tno"], optional + xDeepONet variant (see + :attr:`~physicsnemo.experimental.models.xdeeponet.deeponet.DeepONet.VALID_VARIANTS`). + Mixed-case strings are accepted at runtime and lowercased. + width : int, optional + Latent width. + branch1_config : dict, optional + Primary branch configuration (see core class docstring). + branch2_config : dict, optional + Secondary branch configuration for MIONet/TNO variants. + trunk_config : dict, optional + Trunk configuration. May specify ``input_type`` as ``"time"`` + (uses the last input channel as the time coordinate) or + ``"grid"`` (uses the last three channels + ``(grid_x, grid_y, grid_t)``). + decoder_type : Literal["mlp", "conv", "temporal_projection"], optional + One of ``"mlp"``, ``"conv"``, or ``"temporal_projection"``; + mixed-case strings are accepted and lowercased. + decoder_width : int, optional + Decoder hidden width. + decoder_layers : int, optional + Decoder layer count. + decoder_activation_fn : str, optional + Activation function name for the decoder. + output_window : int, optional + Output window for the ``"temporal_projection"`` decoder (forwarded + to :class:`DeepONet`). + + Forward + ------- + x : torch.Tensor + Packed input of shape :math:`(B, H, W, T, C)` where the last + channel axis holds features plus time/grid coordinates. + x_branch2 : torch.Tensor, optional + Secondary branch input for MIONet/TNO variants. + target_times : torch.Tensor, optional + Explicit trunk query coordinates of shape :math:`(K,)` or + :math:`(K, 1)`. When provided the trunk evaluates at these + :math:`K` points instead of the times extracted from ``x``, + enabling autoregressive temporal bundling where :math:`K \neq T`. + + Outputs + ------- + torch.Tensor + Operator output of shape :math:`(B, H, W, T_{out})` where + :math:`T_{out} = K` when ``target_times`` is given and + :math:`T_{out} = T` otherwise. + + Examples + -------- + >>> import torch + >>> from physicsnemo.experimental.models.xdeeponet import DeepONetWrapper + >>> model = DeepONetWrapper( + ... variant="u_deeponet", + ... 
width=64, + ... branch1_config={ + ... "encoder": {"type": "linear", "activation_fn": "tanh"}, + ... "layers": {"num_unet_layers": 1, "kernel_size": 3}, + ... }, + ... trunk_config={"input_type": "time", "hidden_width": 64, "num_layers": 4}, + ... ) + >>> x = torch.randn(2, 32, 32, 3, 5) # (B, H, W, T, C) + >>> out = model(x) # (2, 32, 32, 3) + """ + + def __init__( + self, + padding: int = 8, + variant: _VariantStr = "u_deeponet", + width: int = 64, + branch1_config: Dict[str, Any] = None, + branch2_config: Dict[str, Any] = None, + trunk_config: Dict[str, Any] = None, + decoder_type: _DecoderTypeStr = "mlp", + decoder_width: int = 128, + decoder_layers: int = 2, + decoder_activation_fn: str = "relu", + output_window: Optional[int] = None, + ): + super().__init__(meta=_DeepONetWrapperMetaData()) + + self.padding = ((padding + 7) // 8) * 8 if padding % 8 != 0 else padding + self.variant = variant + + trunk_config = dict(trunk_config or {}) + self.trunk_input = trunk_config.get("input_type", "time").lower() + + if self.trunk_input not in ["time", "grid"]: + raise ValueError("trunk input_type must be 'time' or 'grid'") + + if self.trunk_input == "grid": + trunk_config["in_features"] = 3 # (x, y, t) + else: + trunk_config["in_features"] = trunk_config.get("in_features", 1) + + self.model = DeepONet( + variant=variant, + width=width, + branch1_config=branch1_config, + branch2_config=branch2_config, + trunk_config=trunk_config, + decoder_type=decoder_type, + decoder_width=decoder_width, + decoder_layers=decoder_layers, + decoder_activation_fn=decoder_activation_fn, + output_window=output_window, + ) + self._temporal_projection = self.model._temporal_projection + + def set_output_window(self, K: int): + """Delegate to the inner :class:`DeepONet` model.""" + self.model.set_output_window(K) + + def forward( + self, + x: Float[Tensor, "batch height width time channels"], + x_branch2: Optional[Float[Tensor, "..."]] = None, + target_times: Optional[Float[Tensor, "..."]] = None, + ) -> Float[Tensor, "batch height width time_out"]: + """Forward pass through the 2D wrapper. + + See class docstring for input/output shapes. ``x_branch2`` and + ``target_times`` accept multiple ranks (see Forward section); their + strict shapes are validated at the top of this method under the + :func:`torch.compiler.is_compiling` guard, so the jaxtyping + annotation uses the unconstrained ``"..."`` shape for those. 
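+
+        As an illustrative example of temporal bundling: with ``x`` of shape
+        ``(2, 32, 32, 3, 5)`` and ``target_times = torch.linspace(0, 1, 7)``,
+        the output has shape ``(2, 32, 32, 7)``, i.e. :math:`T_{out} = K = 7`.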
+ """ + if not torch.compiler.is_compiling(): + if x.ndim != 5: + raise ValueError( + f"Expected 5D input (B, H, W, T, C), got {x.ndim}D " + f"tensor with shape {tuple(x.shape)}" + ) + if target_times is not None and target_times.ndim not in (1, 2): + raise ValueError( + f"Expected target_times to be 1D (K,) or 2D (K, 1), " + f"got {target_times.ndim}D tensor with shape " + f"{tuple(target_times.shape)}" + ) + + H, W = x.shape[1], x.shape[2] + + pad_h, pad_w = compute_right_pad_to_multiple( + (H, W), multiple=8, min_right_pad=self.padding + ) + x = pad_spatial_right( + x, spatial_ndim=2, right_pad=(pad_h, pad_w), mode="replicate" + ) + + if x_branch2 is not None and x_branch2.dim() > 2: + x_branch2 = pad_spatial_right( + x_branch2, + spatial_ndim=2, + right_pad=(pad_h, pad_w), + mode="replicate", + ) + + x_spatial = x.permute(0, 4, 1, 2, 3)[..., 0].permute(0, 2, 3, 1) + + if target_times is not None: + if self.trunk_input == "grid": + t_vals = ( + target_times + if target_times.dim() == 1 + else target_times.squeeze(-1) + ) + spatial = x[0, 0, 0, 0, -3:-1] + spatial_exp = spatial.unsqueeze(0).expand(t_vals.shape[0], -1) + x_trunk = torch.cat([spatial_exp, t_vals.unsqueeze(-1)], dim=-1) + else: + x_trunk = ( + target_times + if target_times.dim() == 2 + else target_times.unsqueeze(-1) + ) + elif self.trunk_input == "grid": + x_trunk = x[0, 0, 0, :, -3:] + else: + x_trunk = x[0, 0, 0, :, -1].unsqueeze(-1) + + return self.model(x_spatial, x_trunk, x_branch2)[:, :H, :W, :] + + def count_params(self) -> int: + """Return the number of trainable parameters.""" + return self.model.count_params() + + +class DeepONet3DWrapper(Module): + r"""3D xDeepONet wrapper with automatic padding and input extraction. + + See :class:`DeepONetWrapper` for parameter semantics. The 3D trunk + ``input_type="grid"`` uses the last four input channels + ``(grid_x, grid_y, grid_z, grid_t)``. + + Forward + ------- + x : torch.Tensor + Packed input of shape :math:`(B, X, Y, Z, T, C)`. + x_branch2 : torch.Tensor, optional + Secondary branch input for MIONet/TNO variants. + target_times : torch.Tensor, optional + Explicit trunk query coordinates of shape :math:`(K,)` or + :math:`(K, 1)`. + + Outputs + ------- + torch.Tensor + Operator output of shape :math:`(B, X, Y, Z, T_{out})`. + + Examples + -------- + >>> import torch + >>> from physicsnemo.experimental.models.xdeeponet import DeepONet3DWrapper + >>> model = DeepONet3DWrapper( + ... variant="tno", + ... width=64, + ... branch1_config={ + ... "encoder": {"type": "linear", "activation_fn": "tanh"}, + ... "layers": {"num_unet_layers": 1, "kernel_size": 3}, + ... }, + ... branch2_config={ + ... "encoder": {"type": "linear", "activation_fn": "tanh"}, + ... "layers": {"num_unet_layers": 1, "kernel_size": 3}, + ... }, + ... trunk_config={"input_type": "time", "hidden_width": 64, "num_layers": 4}, + ... 
) + >>> x = torch.randn(1, 16, 16, 16, 2, 5) # (B, X, Y, Z, T, C) + >>> prev = torch.randn(1, 16, 16, 16, 1) # previous solution (TNO branch2) + >>> out = model(x, x_branch2=prev) # (1, 16, 16, 16, 2) + """ + + def __init__( + self, + padding: int = 8, + variant: _VariantStr = "u_deeponet", + width: int = 64, + branch1_config: Dict[str, Any] = None, + branch2_config: Dict[str, Any] = None, + trunk_config: Dict[str, Any] = None, + decoder_type: _DecoderTypeStr = "mlp", + decoder_width: int = 128, + decoder_layers: int = 2, + decoder_activation_fn: str = "relu", + output_window: Optional[int] = None, + ): + super().__init__(meta=_DeepONet3DWrapperMetaData()) + + self.padding = ((padding + 7) // 8) * 8 if padding % 8 != 0 else padding + self.variant = variant + + trunk_config = dict(trunk_config or {}) + self.trunk_input = trunk_config.get("input_type", "time").lower() + + if self.trunk_input not in ["time", "grid"]: + raise ValueError("trunk input_type must be 'time' or 'grid'") + + if self.trunk_input == "grid": + trunk_config["in_features"] = 4 # (x, y, z, t) + else: + trunk_config["in_features"] = trunk_config.get("in_features", 1) + + self.model = DeepONet3D( + variant=variant, + width=width, + branch1_config=branch1_config, + branch2_config=branch2_config, + trunk_config=trunk_config, + decoder_type=decoder_type, + decoder_width=decoder_width, + decoder_layers=decoder_layers, + decoder_activation_fn=decoder_activation_fn, + output_window=output_window, + ) + self._temporal_projection = self.model._temporal_projection + + def set_output_window(self, K: int): + """Delegate to the inner :class:`DeepONet3D` model.""" + self.model.set_output_window(K) + + def forward( + self, + x: Float[Tensor, "batch X Y Z time channels"], + x_branch2: Optional[Float[Tensor, "..."]] = None, + target_times: Optional[Float[Tensor, "..."]] = None, + ) -> Float[Tensor, "batch X Y Z time_out"]: + """Forward pass through the 3D wrapper. + + See class docstring for input/output shapes. ``x_branch2`` and + ``target_times`` accept multiple ranks; their strict shapes are + validated at the top of this method under the + :func:`torch.compiler.is_compiling` guard. 
+ """ + if not torch.compiler.is_compiling(): + if x.ndim != 6: + raise ValueError( + f"Expected 6D input (B, X, Y, Z, T, C), got {x.ndim}D " + f"tensor with shape {tuple(x.shape)}" + ) + if target_times is not None and target_times.ndim not in (1, 2): + raise ValueError( + f"Expected target_times to be 1D (K,) or 2D (K, 1), " + f"got {target_times.ndim}D tensor with shape " + f"{tuple(target_times.shape)}" + ) + + X, Y, Z = x.shape[1], x.shape[2], x.shape[3] + + pad_x, pad_y, pad_z = compute_right_pad_to_multiple( + (X, Y, Z), multiple=8, min_right_pad=self.padding + ) + x = pad_spatial_right( + x, spatial_ndim=3, right_pad=(pad_x, pad_y, pad_z), mode="replicate" + ) + + if x_branch2 is not None and x_branch2.dim() > 2: + x_branch2 = pad_spatial_right( + x_branch2, + spatial_ndim=3, + right_pad=(pad_x, pad_y, pad_z), + mode="replicate", + ) + + x_spatial = x[:, :, :, :, 0, :] + + if target_times is not None: + if self.trunk_input == "grid": + t_vals = ( + target_times + if target_times.dim() == 1 + else target_times.squeeze(-1) + ) + spatial = x[0, 0, 0, 0, 0, -4:-1] + spatial_exp = spatial.unsqueeze(0).expand(t_vals.shape[0], -1) + x_trunk = torch.cat([spatial_exp, t_vals.unsqueeze(-1)], dim=-1) + else: + x_trunk = ( + target_times + if target_times.dim() == 2 + else target_times.unsqueeze(-1) + ) + elif self.trunk_input == "grid": + x_trunk = x[0, 0, 0, 0, :, -4:] + else: + x_trunk = x[0, 0, 0, 0, :, -1].unsqueeze(-1) + + return self.model(x_spatial, x_trunk, x_branch2)[:, :X, :Y, :Z, :] + + def count_params(self) -> int: + """Return the number of trainable parameters.""" + return self.model.count_params() + + +__all__ = ["DeepONetWrapper", "DeepONet3DWrapper"] diff --git a/test/experimental/models/xdeeponet/data/_generate_xdeeponet_goldens.py b/test/experimental/models/xdeeponet/data/_generate_xdeeponet_goldens.py new file mode 100644 index 0000000000..70a7b4b055 --- /dev/null +++ b/test/experimental/models/xdeeponet/data/_generate_xdeeponet_goldens.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Regenerate the xDeepONet golden ``.pth`` fixtures. + +Run from the repository root:: + + python test/experimental/models/data/_generate_xdeeponet_goldens.py + +Overwrites the committed fixtures with freshly-seeded model outputs. +Invoke this deliberately whenever model numerics intentionally change +(architecture edit, default-argument change, etc.) and commit the +resulting ``.pth`` files. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import torch + +_REPO_ROOT = Path(__file__).resolve().parents[5] +# Repo root: so ``import physicsnemo...`` resolves. +# xdeeponet test dir: so ``import test_xdeeponet`` resolves. 
+sys.path.insert(0, str(_REPO_ROOT)) +sys.path.insert( + 0, str(_REPO_ROOT / "test" / "experimental" / "models" / "xdeeponet") +) + +from test_xdeeponet import ( # noqa: E402 + _GOLDEN_2D, + _GOLDEN_3D, + _init_lazy, + _wrapper_2d, + _wrapper_3d, +) + + +def _write(path: Path, builder) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + model, x = builder() + _init_lazy(model, x) + with torch.no_grad(): + y = model(x) + torch.save({"x": x, "y": y, "state_dict": model.state_dict()}, path) + print( + f"wrote {path.relative_to(_REPO_ROOT)} " + f"x={tuple(x.shape)} y={tuple(y.shape)} " + f"size={path.stat().st_size}B" + ) + + +if __name__ == "__main__": + _write(_GOLDEN_2D, _wrapper_2d) + _write(_GOLDEN_3D, _wrapper_3d) diff --git a/test/experimental/models/xdeeponet/data/xdeeponet_wrapper_2d_v1.pth b/test/experimental/models/xdeeponet/data/xdeeponet_wrapper_2d_v1.pth new file mode 100644 index 0000000000..2ea8332cc9 Binary files /dev/null and b/test/experimental/models/xdeeponet/data/xdeeponet_wrapper_2d_v1.pth differ diff --git a/test/experimental/models/xdeeponet/data/xdeeponet_wrapper_3d_v1.pth b/test/experimental/models/xdeeponet/data/xdeeponet_wrapper_3d_v1.pth new file mode 100644 index 0000000000..b060249ccf Binary files /dev/null and b/test/experimental/models/xdeeponet/data/xdeeponet_wrapper_3d_v1.pth differ diff --git a/test/experimental/models/xdeeponet/test_xdeeponet.py b/test/experimental/models/xdeeponet/test_xdeeponet.py new file mode 100644 index 0000000000..1c1ad4f549 --- /dev/null +++ b/test/experimental/models/xdeeponet/test_xdeeponet.py @@ -0,0 +1,381 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal, reviewer-preferred test suite for the xDeepONet family. + +Covers, per `MOD-008a/b/c <../../CODING_STANDARDS/MODELS_IMPLEMENTATION.md>`_ +and the PR #1576 human review: + +- **Constructor + public attributes** (MOD-008a) — default and custom configs. +- **Forward non-regression** (MOD-008b) — compare a single forward pass + against committed golden ``.pth`` fixtures. +- **Checkpoint round-trip** (MOD-008c) — ``save`` to ``.mdlus``, reload via + :meth:`physicsnemo.Module.from_checkpoint`, and verify the loaded model + reproduces the same output as the in-memory model. +- **Gradient flow** — backward pass produces non-None gradients on input + and parameters. +- **torch.compile smoke** — wrapping the model in :func:`torch.compile` + succeeds and produces shape-compatible output. + +Broader shape / variant / error-path coverage (all 8 variants, both +decoder types, construction-time guards, Fourier code path, adaptive +pooling, etc.) lives on the ``pr/neural-operator-factory`` branch as +``examples/reservoir_simulation/neural_operator_factory/tests/test_xdeeponet_upstream.py`` +for local regression when preparing future NOF upstream PRs. 
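+
+Run locally with, for example::
+
+    pytest test/experimental/models/xdeeponet -v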
+""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +import torch + +from physicsnemo import Module +from physicsnemo.experimental.models.xdeeponet import ( + DeepONet, + DeepONet3D, + DeepONet3DWrapper, + DeepONetWrapper, +) + +_DATA_DIR = Path(__file__).parent / "data" +_GOLDEN_2D = _DATA_DIR / "xdeeponet_wrapper_2d_v1.pth" +_GOLDEN_3D = _DATA_DIR / "xdeeponet_wrapper_3d_v1.pth" +_SEED = 0 + +# Minimal branch/trunk configs chosen for (a) small tensor shapes so the +# golden files stay tiny and the tests run in well under a second each, +# and (b) exercising the spatial-branch + UNet + MLP-decoder path, which +# is the most common user-facing configuration. +_BRANCH_SPATIAL = { + "encoder": {"type": "linear", "activation_fn": "relu"}, + "layers": { + "num_fourier_layers": 0, + "num_unet_layers": 1, + "num_conv_layers": 0, + "kernel_size": 3, + "dropout": 0.0, + "activation_fn": "relu", + }, +} +_TRUNK = { + "input_type": "time", + "hidden_width": 16, + "num_layers": 2, + "activation_fn": "tanh", +} + + +def _wrapper_2d() -> tuple[DeepONetWrapper, torch.Tensor]: + """Build a deterministic 2D wrapper + matching input. + + Uses a fixed RNG seed so the materialised :class:`torch.nn.LazyLinear` + weights (created on the first forward pass) are reproducible across + runs. + """ + torch.manual_seed(_SEED) + model = DeepONetWrapper( + padding=8, + variant="u_deeponet", + width=8, + branch1_config=_BRANCH_SPATIAL, + trunk_config=_TRUNK, + decoder_type="mlp", + decoder_width=8, + decoder_layers=1, + ) + x = torch.randn(1, 8, 8, 2, 2) + return model, x + + +def _wrapper_3d() -> tuple[DeepONet3DWrapper, torch.Tensor]: + """Build a deterministic 3D wrapper + matching input.""" + torch.manual_seed(_SEED) + model = DeepONet3DWrapper( + padding=8, + variant="u_deeponet", + width=8, + branch1_config=_BRANCH_SPATIAL, + trunk_config=_TRUNK, + decoder_type="mlp", + decoder_width=8, + decoder_layers=1, + ) + x = torch.randn(1, 8, 8, 8, 2, 2) + return model, x + + +def _init_lazy(model, x) -> None: + """Run one forward pass to materialise ``nn.LazyLinear`` parameters.""" + with torch.no_grad(): + model(x) + + +def _load_golden(path: Path) -> dict[str, torch.Tensor | dict]: + """Load a golden fixture; fail loudly with a regeneration hint if missing. + + Fixtures under ``test/experimental/models/data/`` are committed + alongside this file and updated deliberately when model numerics + change (via the generator script + ``test/experimental/models/data/_generate_xdeeponet_goldens.py``). + """ + if not path.exists(): + raise FileNotFoundError( + f"Golden fixture {path} is missing. Regenerate with " + f"``python test/experimental/models/xdeeponet/data/" + f"_generate_xdeeponet_goldens.py`` and commit the " + f"resulting ``.pth`` files." + ) + # Golden payload is {str -> Tensor | dict[str, Tensor]} so + # ``weights_only=True`` is the safer default and avoids PyTorch 2.6's + # FutureWarning on the legacy permissive path. 
+ return torch.load(path, weights_only=True) + + +# ---------------------------------------------------------------------- +# Constructor + public attributes (MOD-008a) +# ---------------------------------------------------------------------- + + +class TestDeepONetConstructor: + """Constructor instantiates and exposes the documented public attributes.""" + + @pytest.mark.parametrize( + "config", + [ + {"variant": "u_deeponet", "width": 8, "decoder_type": "mlp"}, + {"variant": "deeponet", "width": 16, "decoder_type": "conv"}, + ], + ids=["default-ish", "custom"], + ) + def test_deeponet_2d_core(self, config): + """``DeepONet`` stores the constructor arguments on public attrs.""" + model = DeepONet( + variant=config["variant"], + width=config["width"], + branch1_config=_BRANCH_SPATIAL, + trunk_config=_TRUNK, + decoder_type=config["decoder_type"], + decoder_width=config["width"], + decoder_layers=1, + ) + assert model.variant == config["variant"] + assert model.width == config["width"] + assert model.decoder_type == config["decoder_type"] + assert model.decoder_activation_fn == "relu" + assert model.trunk is not None + + @pytest.mark.parametrize( + "config", + [ + {"variant": "u_deeponet", "width": 8, "decoder_type": "mlp"}, + {"variant": "deeponet", "width": 16, "decoder_type": "conv"}, + ], + ids=["default-ish", "custom"], + ) + def test_deeponet_3d_core(self, config): + """``DeepONet3D`` stores the constructor arguments on public attrs.""" + model = DeepONet3D( + variant=config["variant"], + width=config["width"], + branch1_config=_BRANCH_SPATIAL, + trunk_config=_TRUNK, + decoder_type=config["decoder_type"], + decoder_width=config["width"], + decoder_layers=1, + ) + assert model.variant == config["variant"] + assert model.width == config["width"] + assert model.decoder_type == config["decoder_type"] + assert model.decoder_activation_fn == "relu" + assert model.trunk is not None + + @pytest.mark.parametrize( + "config", + [ + {"padding": 8, "variant": "u_deeponet", "trunk_input": "time"}, + {"padding": 16, "variant": "deeponet", "trunk_input": "grid"}, + ], + ids=["default-ish", "custom"], + ) + def test_wrapper_2d(self, config): + """``DeepONetWrapper`` exposes padding / variant / trunk_input.""" + model = DeepONetWrapper( + padding=config["padding"], + variant=config["variant"], + width=8, + branch1_config=_BRANCH_SPATIAL, + trunk_config={**_TRUNK, "input_type": config["trunk_input"]}, + decoder_type="mlp", + decoder_width=8, + decoder_layers=1, + ) + assert model.padding == config["padding"] + assert model.variant == config["variant"] + assert model.trunk_input == config["trunk_input"] + assert isinstance(model.model, DeepONet) + + @pytest.mark.parametrize( + "config", + [ + {"padding": 8, "variant": "u_deeponet", "trunk_input": "time"}, + {"padding": 16, "variant": "deeponet", "trunk_input": "grid"}, + ], + ids=["default-ish", "custom"], + ) + def test_wrapper_3d(self, config): + """``DeepONet3DWrapper`` exposes padding / variant / trunk_input.""" + model = DeepONet3DWrapper( + padding=config["padding"], + variant=config["variant"], + width=8, + branch1_config=_BRANCH_SPATIAL, + trunk_config={**_TRUNK, "input_type": config["trunk_input"]}, + decoder_type="mlp", + decoder_width=8, + decoder_layers=1, + ) + assert model.padding == config["padding"] + assert model.variant == config["variant"] + assert model.trunk_input == config["trunk_input"] + assert isinstance(model.model, DeepONet3D) + + +# ---------------------------------------------------------------------- +# Forward 
non-regression against committed golden files (MOD-008b) +# ---------------------------------------------------------------------- + + +class TestDeepONetNonRegression: + """Forward output matches the committed golden fixture.""" + + def test_wrapper_2d_matches_golden(self): + """2D wrapper: loading fixed state_dict reproduces the stored output.""" + golden = _load_golden(_GOLDEN_2D) + model, _ = _wrapper_2d() + _init_lazy(model, golden["x"]) + model.load_state_dict(golden["state_dict"]) + with torch.no_grad(): + y = model(golden["x"]) + torch.testing.assert_close(y, golden["y"], rtol=1e-5, atol=1e-6) + + def test_wrapper_3d_matches_golden(self): + """3D wrapper: loading fixed state_dict reproduces the stored output.""" + golden = _load_golden(_GOLDEN_3D) + model, _ = _wrapper_3d() + _init_lazy(model, golden["x"]) + model.load_state_dict(golden["state_dict"]) + with torch.no_grad(): + y = model(golden["x"]) + torch.testing.assert_close(y, golden["y"], rtol=1e-5, atol=1e-6) + + +# ---------------------------------------------------------------------- +# Checkpoint (.mdlus) round-trip (MOD-008c) +# ---------------------------------------------------------------------- + + +class TestDeepONetCheckpoint: + """``Module.save`` + ``Module.from_checkpoint`` round-trip. + + Verifies that :meth:`physicsnemo.Module.from_checkpoint` reconstructs a + byte-identical model. The loaded model's forward output is compared + **against the committed golden fixture** — not against a second forward + pass on the in-memory model — so the test fails if the serialized + state is incomplete, corrupted, or silently re-initialised. + + PyTorch's :meth:`torch.nn.Module.load_state_dict` natively materialises + :class:`torch.nn.LazyLinear` parameters from the saved tensors, so no + ``_init_lazy`` call is needed on the reloaded model. 
+ """ + + def _roundtrip(self, model, x, tmp_path): + _init_lazy(model, x) + ckpt = tmp_path / "model.mdlus" + model.save(str(ckpt)) + loaded = Module.from_checkpoint(str(ckpt)) + with torch.no_grad(): + y_loaded = loaded(x) + return loaded, y_loaded + + def test_wrapper_2d_roundtrip(self, tmp_path): + """2D wrapper: reloaded output matches the committed golden.""" + golden = _load_golden(_GOLDEN_2D) + model, _ = _wrapper_2d() + loaded, y_loaded = self._roundtrip(model, golden["x"], tmp_path) + assert type(loaded).__name__ == type(model).__name__ + assert loaded.padding == model.padding + assert loaded.variant == model.variant + assert loaded.trunk_input == model.trunk_input + torch.testing.assert_close(y_loaded, golden["y"], rtol=1e-5, atol=1e-6) + + def test_wrapper_3d_roundtrip(self, tmp_path): + """3D wrapper: reloaded output matches the committed golden.""" + golden = _load_golden(_GOLDEN_3D) + model, _ = _wrapper_3d() + loaded, y_loaded = self._roundtrip(model, golden["x"], tmp_path) + assert type(loaded).__name__ == type(model).__name__ + assert loaded.padding == model.padding + assert loaded.variant == model.variant + assert loaded.trunk_input == model.trunk_input + torch.testing.assert_close(y_loaded, golden["y"], rtol=1e-5, atol=1e-6) + + +# ---------------------------------------------------------------------- +# Gradient flow +# ---------------------------------------------------------------------- + + +class TestDeepONetGradientFlow: + """Backward pass produces non-None gradients on input and parameters.""" + + def test_wrapper_2d_gradients(self): + """Gradients flow through the 2D wrapper.""" + model, x = _wrapper_2d() + _init_lazy(model, x) + x = x.detach().requires_grad_(True) + y = model(x) + y.sum().backward() + assert x.grad is not None + trainable = [p for p in model.parameters() if p.requires_grad] + assert trainable, "model has no trainable parameters" + assert any(p.grad is not None for p in trainable) + + +# ---------------------------------------------------------------------- +# torch.compile smoke test +# ---------------------------------------------------------------------- + + +class TestDeepONetCompile: + """``torch.compile`` wraps the model without raising.""" + + def test_wrapper_2d_compile(self): + """Compiled model produces shape-compatible output vs eager.""" + model, x = _wrapper_2d() + _init_lazy(model, x) + with torch.no_grad(): + y_eager = model(x) + compiled = torch.compile(model, fullgraph=False) + with torch.no_grad(): + y_compiled = compiled(x) + assert y_compiled.shape == y_eager.shape + torch.testing.assert_close(y_compiled, y_eager, rtol=1e-4, atol=1e-5) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])