From a0cd8b0a865c9d53547d405269c0e8a5b3f4dcba Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Fri, 13 Mar 2026 22:36:05 -0500 Subject: [PATCH 1/8] Update geotransolver for 2d and 3d use cases --- .../models/geotransolver/__init__.py | 11 +- .../models/geotransolver/context_projector.py | 402 ++++++++++++++---- .../experimental/models/geotransolver/gale.py | 389 ++++++++++++----- .../models/geotransolver/geotransolver.py | 127 +++++- .../geotransolver/test_geotransolver.py | 92 ++++ 5 files changed, 824 insertions(+), 197 deletions(-) diff --git a/physicsnemo/experimental/models/geotransolver/__init__.py b/physicsnemo/experimental/models/geotransolver/__init__.py index 109496b618..596f5da0a9 100644 --- a/physicsnemo/experimental/models/geotransolver/__init__.py +++ b/physicsnemo/experimental/models/geotransolver/__init__.py @@ -51,8 +51,12 @@ torch.Size([2, 1000, 3]) """ -from .context_projector import ContextProjector, GlobalContextBuilder -from .gale import GALE, GALE_block +from .context_projector import ( + ContextProjector, + GlobalContextBuilder, + StructuredContextProjector, +) +from .gale import GALE, GALE_block, GALEStructuredMesh2D, GALEStructuredMesh3D from .geotransolver import GeoTransolver, GeoTransolverMetaData __all__ = [ @@ -60,6 +64,9 @@ "GeoTransolverMetaData", "GALE", "GALE_block", + "GALEStructuredMesh2D", + "GALEStructuredMesh3D", "ContextProjector", "GlobalContextBuilder", + "StructuredContextProjector", ] \ No newline at end of file diff --git a/physicsnemo/experimental/models/geotransolver/context_projector.py b/physicsnemo/experimental/models/geotransolver/context_projector.py index 36031ec36c..547fa95642 100644 --- a/physicsnemo/experimental/models/geotransolver/context_projector.py +++ b/physicsnemo/experimental/models/geotransolver/context_projector.py @@ -49,7 +49,200 @@ import transformer_engine.pytorch as te -class ContextProjector(nn.Module): +def _compute_slices_from_projections_impl( 
+ slice_projections: Float[torch.Tensor, "batch heads tokens slices"], + fx: Float[torch.Tensor, "batch heads tokens dim"], + temperature: torch.Tensor, + plus: bool, + proj_temperature: nn.Module | None = None, +) -> tuple[ + Float[torch.Tensor, "batch heads tokens slices"], + Float[torch.Tensor, "batch heads slices dim"], +]: + r"""Shared slice aggregation: temperature-weighted softmax then weighted sum over tokens. + + Used by both :class:`ContextProjector` and :class:`StructuredContextProjector` + to avoid duplicating the slice-weight and slice-token computation. + + Parameters + ---------- + slice_projections : torch.Tensor + Projection of each token onto each slice, shape :math:`(B, H, N, S)`. + fx : torch.Tensor + Latent features to aggregate per slice, shape :math:`(B, H, N, D)`. + temperature : torch.Tensor + Scalar temperature for softmax/gumbel, shape broadcastable to projections. + plus : bool + If ``True``, use Gumbel softmax with optional adaptive temperature. + proj_temperature : nn.Module or None, optional + If ``plus`` is ``True``, module mapping :math:`(B, H, N, D)` to adaptive + temperature; ignored otherwise. Default is ``None``. + + Returns + ------- + slice_weights : torch.Tensor + Normalized weights per token and slice, shape :math:`(B, H, N, S)`. + slice_token : torch.Tensor + Aggregated features per slice, shape :math:`(B, H, S, D)`. 
+ """ + if plus and proj_temperature is not None: + temp = temperature + proj_temperature(fx) + clamped_temp = torch.clamp(temp, min=0.01).to(slice_projections.dtype) + slice_weights = gumbel_softmax(slice_projections, clamped_temp) + else: + clamped_temp = torch.clamp(temperature, min=0.5, max=5).to( + slice_projections.dtype + ) + slice_weights = nn.functional.softmax( + slice_projections / clamped_temp, dim=-1 + ) + slice_weights = slice_weights.to(slice_projections.dtype) + slice_norm = slice_weights.sum(2) + normed_weights = slice_weights / (slice_norm[:, :, None, :] + 1e-2) + slice_token = torch.matmul(normed_weights.transpose(2, 3), fx) + return slice_weights, slice_token + + +def _structured_grid_to_conv_input( + x: Float[torch.Tensor, "batch tokens channels"], + batch: int, + tokens: int, + channels: int, + ndim: int, + spatial_shape: tuple[int, ...], +) -> Float[torch.Tensor, "batch channels ..."]: + r"""Reshape flat token tensor to spatial layout for Conv2d/Conv3d. + + Converts :math:`(B, N, C)` to :math:`(B, C, H, W)` for 2D or + :math:`(B, C, H, W, D)` for 3D so that structured context projectors + can apply spatial convolutions. Validates that :math:`N` matches the + grid size. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, C)` (batch, tokens, channels). + batch : int + Batch size :math:`B`. + tokens : int + Number of tokens :math:`N` (must equal :math:`H \\times W` or + :math:`H \\times W \\times D`). + channels : int + Channel dimension :math:`C`. + ndim : int + Number of spatial dimensions; must be 2 or 3. + spatial_shape : tuple[int, ...] + :math:`(H, W)` for 2D or :math:`(H, W, D)` for 3D. + + Returns + ------- + torch.Tensor + Reshaped tensor of shape :math:`(B, C, H, W)` or + :math:`(B, C, H, W, D)` for use as conv input. + + Raises + ------ + ValueError + If ``tokens`` does not match the product of ``spatial_shape``. 
+ """ + if ndim == 2: + H, W = spatial_shape + if tokens != H * W: + raise ValueError( + f"Expected N={H * W} tokens for 2D grid, got N={tokens}" + ) + return x.view(batch, H, W, channels).permute(0, 3, 1, 2) + H, W, D = spatial_shape + if tokens != H * W * D: + raise ValueError( + f"Expected N={H * W * D} tokens for 3D grid, got N={tokens}" + ) + return x.view(batch, H, W, D, channels).permute(0, 4, 1, 2, 3) + + +class _SliceToContextMixin: + r"""Internal mixin providing shared slice-to-context init and slice aggregation. + + Used by :class:`ContextProjector` and :class:`StructuredContextProjector` to + avoid duplicating in_project_slice, temperature, proj_temperature, and + compute_slices_from_projections. + """ + + def _init_slice_components( + self, + dim_head: int, + slice_num: int, + heads: int, + use_te: bool, + plus: bool, + ) -> None: + r"""Initialize slice projection, temperature, and optional adaptive temperature. + + Sets ``in_project_slice``, ``temperature``, and (when ``plus`` is True) + ``proj_temperature`` on this instance. Uses Transformer Engine linear + when ``use_te`` is True and TE is available. + + Parameters + ---------- + dim_head : int + Head dimension for the slice projection input. + slice_num : int + Number of slices (output dimension of ``in_project_slice``). + heads : int + Number of heads (used for temperature shape). + use_te : bool + Whether to prefer Transformer Engine for linear layers. + plus : bool + If True, add ``proj_temperature`` for Transolver++. 
+ """ + linear_layer = te.Linear if (use_te and TE_AVAILABLE) else nn.Linear + self.in_project_slice = linear_layer(dim_head, slice_num) + self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) + if plus: + self.proj_temperature = nn.Sequential( + linear_layer(dim_head, slice_num), + nn.GELU(), + linear_layer(slice_num, 1), + nn.GELU(), + ) + + def compute_slices_from_projections( + self, + slice_projections: Float[torch.Tensor, "batch heads tokens slices"], + fx: Float[torch.Tensor, "batch heads tokens dim"], + ) -> tuple[ + Float[torch.Tensor, "batch heads tokens slices"], + Float[torch.Tensor, "batch heads slices dim"], + ]: + r"""Compute slice weights and slice tokens from projections and latent features. + + Delegates to :func:`_compute_slices_from_projections_impl` using this + instance's ``temperature``, ``plus``, and (when plus) ``proj_temperature``. + + Parameters + ---------- + slice_projections : torch.Tensor + Shape :math:`(B, H, N, S)`. + fx : torch.Tensor + Shape :math:`(B, H, N, D)`. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(slice_weights, slice_token)`` with shapes :math:`(B, H, N, S)` + and :math:`(B, H, S, D)`. + """ + proj_temp = getattr(self, "proj_temperature", None) if self.plus else None + return _compute_slices_from_projections_impl( + slice_projections, + fx, + self.temperature, + self.plus, + proj_temperature=proj_temp, + ) + + +class ContextProjector(_SliceToContextMixin, nn.Module): r"""Projects context features onto physical state space. This context projector is conceptually similar to half of a GALE attention layer. 
@@ -136,19 +329,8 @@ def __init__( # Attention components self.softmax = nn.Softmax(dim=-1) self.dropout = nn.Dropout(dropout) - self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) - - # Transolver++ adaptive temperature projection - if plus: - self.proj_temperature = nn.Sequential( - linear_layer(self.dim_head, slice_num), - nn.GELU(), - linear_layer(slice_num, 1), - nn.GELU(), - ) - # Slice projection layer maps from head dimension to slice space - self.in_project_slice = linear_layer(dim_head, slice_num) + self._init_slice_components(dim_head, slice_num, heads, use_te, plus) def project_input_onto_slices( self, x: Float[torch.Tensor, "batch tokens channels"] @@ -193,72 +375,6 @@ def project_input_onto_slices( ) return projected_x, feature_projection - def compute_slices_from_projections( - self, - slice_projections: Float[torch.Tensor, "batch heads tokens slices"], - fx: Float[torch.Tensor, "batch heads tokens dim"], - ) -> tuple[ - Float[torch.Tensor, "batch heads tokens slices"], - Float[torch.Tensor, "batch heads slices dim"], - ]: - r"""Compute slice weights and slice tokens from input projections and latent features. - - Parameters - ---------- - slice_projections : torch.Tensor - Projected input tensor of shape :math:`(B, H, N, S)` where :math:`B` is batch size, - :math:`H` is number of heads, :math:`N` is number of tokens, and :math:`S` is number of - slices, representing the projection of each token onto each slice for each - attention head. - fx : torch.Tensor - Latent feature tensor of shape :math:`(B, H, N, D)` where :math:`D` is head dimension, - representing the learned states to be aggregated by the slice weights. - - Returns - ------- - tuple[torch.Tensor, torch.Tensor] - - ``slice_weights``: Tensor of shape :math:`(B, H, N, S)`, normalized weights for - each slice per token and head. - - ``slice_token``: Tensor of shape :math:`(B, H, S, D)`, aggregated latent features - for each slice, head, and batch. 
- - Notes - ----- - The function computes a temperature-scaled softmax over the slice projections to - obtain slice weights, then aggregates the latent features for each slice using - these weights. The aggregated features are normalized by the sum of weights for - numerical stability. - """ - # Compute temperature-adjusted softmax weights - if self.plus: - # Transolver++ uses adaptive temperature with Gumbel softmax - temperature = self.temperature + self.proj_temperature(fx) - clamped_temp = torch.clamp(temperature, min=0.01).to( - slice_projections.dtype - ) - slice_weights = gumbel_softmax(slice_projections, clamped_temp) - else: - # Standard Transolver uses fixed temperature with regular softmax - clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to( - slice_projections.dtype - ) - slice_weights = nn.functional.softmax( - slice_projections / clamped_temp, dim=-1 - ) - - # Ensure weights match the computation dtype - slice_weights = slice_weights.to(slice_projections.dtype) - - # Aggregate features by slice weights with normalization - # Normalize first to prevent overflow in reduced precision - slice_norm = slice_weights.sum(2) # Sum over tokens: (B, H, S) - normed_weights = slice_weights / (slice_norm[:, :, None, :] + 1e-2) - - # Weighted aggregation: (B, H, S, N) @ (B, H, N, D) -> (B, H, S, D) - slice_token = torch.matmul(normed_weights.transpose(2, 3), fx) - - return slice_weights, slice_token - def forward( self, x: Float[torch.Tensor, "batch tokens channels"] ) -> Float[torch.Tensor, "batch heads slices dim"]: @@ -313,6 +429,101 @@ def forward( return slice_tokens +class StructuredContextProjector(_SliceToContextMixin, nn.Module): + r"""Context projector with Conv2d/Conv3d geometry encoding on structured grids. + + Same output interface as :class:`ContextProjector`—slice tokens + :math:`(B, H, S, D)`—but projects per-cell geometry via spatial convolutions + aligned with structured GALE attention. 
+ """ + + def __init__( + self, + dim: int, + spatial_shape: tuple[int, ...], + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + kernel: int = 3, + use_te: bool = True, + plus: bool = False, + ) -> None: + super().__init__() + if len(spatial_shape) not in (2, 3): + raise ValueError( + f"StructuredContextProjector expects spatial_shape of length 2 or 3, got {spatial_shape!r}" + ) + inner_dim = dim_head * heads + self.dim_head = dim_head + self.heads = heads + self.plus = plus + self.use_te = use_te + self.spatial_shape = tuple(int(s) for s in spatial_shape) + self._nd = len(self.spatial_shape) + pad = kernel // 2 + if self._nd == 2: + H, W = self.spatial_shape + self.H, self.W = H, W + self.in_project_x = nn.Conv2d(dim, inner_dim, kernel, 1, pad) + if not plus: + self.in_project_fx = nn.Conv2d(dim, inner_dim, kernel, 1, pad) + else: + H, W, D_ = self.spatial_shape + self.H, self.W, self.D = H, W, D_ + self.in_project_x = nn.Conv3d(dim, inner_dim, kernel, 1, pad) + if not plus: + self.in_project_fx = nn.Conv3d(dim, inner_dim, kernel, 1, pad) + + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + self._init_slice_components(dim_head, slice_num, heads, use_te, plus) + + def _grid_project( + self, x: Float[torch.Tensor, "batch tokens channels"] + ) -> tuple[ + Float[torch.Tensor, "batch heads tokens dim"], + Float[torch.Tensor, "batch heads tokens dim"], + ]: + B, N, C = x.shape + grid = _structured_grid_to_conv_input( + x, B, N, C, self._nd, self.spatial_shape + ) + pattern = ( + "B (H D) h w -> B H (h w) D" + if self._nd == 2 + else "B (H D) h w d -> B H (h w d) D" + ) + px = rearrange( + self.in_project_x(grid), pattern, H=self.heads, D=self.dim_head + ) + if self.plus: + return px, px + pfx = rearrange( + self.in_project_fx(grid), pattern, H=self.heads, D=self.dim_head + ) + return px, pfx + + def forward( + self, x: Float[torch.Tensor, "batch tokens channels"] + ) -> Float[torch.Tensor, "batch heads slices 
dim"]: + if not torch.compiler.is_compiling(): + if x.ndim != 3: + raise ValueError( + f"Expected 3D input (B, N, C), got {x.ndim}D shape {tuple(x.shape)}" + ) + if self.plus: + projected_x = self._grid_project(x)[0] + feature_projection = projected_x + else: + projected_x, feature_projection = self._grid_project(x) + slice_projections = self.in_project_slice(projected_x) + _, slice_tokens = self.compute_slices_from_projections( + slice_projections, feature_projection + ) + return slice_tokens + + class GeometricFeatureProcessor(nn.Module): r"""Processes geometric features at a single spatial scale using BQWarp. @@ -616,6 +827,10 @@ class GlobalContextBuilder(nn.Module): Whether to use Transolver++ features. Default is ``False``. include_local_features : bool, optional Enable local feature extraction. Default is ``False``. + structured_shape : tuple[int, ...] | None, optional + If set, disables ball-query extractors and uses + :class:`StructuredContextProjector` for geometry when ``geometry_dim`` + is set. Default is ``None``. Forward ------- @@ -663,6 +878,7 @@ def __init__( use_te: bool = True, plus: bool = False, include_local_features: bool = False, + structured_shape: tuple[int, ...] 
| None = None, ) -> None: super().__init__() @@ -674,9 +890,17 @@ def __init__( dim_head = n_hidden // n_head context_dim = 0 + self.structured_shape = structured_shape + + # Ball-query local features are not used on structured grids + use_local_bq = ( + geometry_dim is not None + and include_local_features + and structured_shape is None + ) # Multi-scale extractors for local features (one per functional dim) - if geometry_dim is not None and include_local_features: + if use_local_bq: self.local_extractors = nn.ModuleList( [ MultiScaleFeatureExtractor( @@ -700,9 +924,21 @@ def __init__( # Geometry tokenizer for global geometry context if geometry_dim is not None: - self.geometry_tokenizer = ContextProjector( - geometry_dim, n_head, dim_head, dropout, slice_num, use_te, plus - ) + if structured_shape is not None: + self.geometry_tokenizer = StructuredContextProjector( + geometry_dim, + structured_shape, + n_head, + dim_head, + dropout, + slice_num, + use_te=use_te, + plus=plus, + ) + else: + self.geometry_tokenizer = ContextProjector( + geometry_dim, n_head, dim_head, dropout, slice_num, use_te, plus + ) context_dim += dim_head else: self.geometry_tokenizer = None diff --git a/physicsnemo/experimental/models/geotransolver/gale.py b/physicsnemo/experimental/models/geotransolver/gale.py index b64f8e9df5..d8b4870be5 100644 --- a/physicsnemo/experimental/models/geotransolver/gale.py +++ b/physicsnemo/experimental/models/geotransolver/gale.py @@ -33,6 +33,8 @@ from physicsnemo.nn import Mlp from physicsnemo.nn.module.physics_attention import ( PhysicsAttentionIrregularMesh, + PhysicsAttentionStructuredMesh2D, + PhysicsAttentionStructuredMesh3D, ) # Check optional dependency availability @@ -41,6 +43,148 @@ import transformer_engine.pytorch as te +def _gale_compute_slice_attention_cross( + module: nn.Module, + slice_tokens: list[Float[torch.Tensor, "batch heads slices dim"]], + context: Float[torch.Tensor, "batch heads context_slices context_dim"], +) -> 
list[Float[torch.Tensor, "batch heads slices dim"]]: + r"""Shared cross-attention between slice tokens and context. + + Used by :class:`GALE` and :class:`_GALEStructuredForwardMixin` so the + cross-attention implementation lives in one place. Projects queries from + concatenated slice tokens, keys and values from context; runs Transformer + Engine or SDPA attention; splits the result back to one tensor per input. + + Parameters + ---------- + module : nn.Module + Module with ``cross_q``, ``cross_k``, ``cross_v``, ``use_te``, + ``heads``, ``dim_head``, and (if ``use_te``) ``attn_fn``. + slice_tokens : list[torch.Tensor] + One tensor per input, each of shape :math:`(B, H, S, D)`. + context : torch.Tensor + Context tensor of shape :math:`(B, H, S_c, D_c)`. + + Returns + ------- + list[torch.Tensor] + One cross-attention output per element of ``slice_tokens``, each + of shape :math:`(B, H, S, D)`. + """ + q_input = torch.cat(slice_tokens, dim=-2) + q = module.cross_q(q_input) + k = module.cross_k(context) + v = module.cross_v(context) + if module.use_te: + q = rearrange(q, "b h s d -> b s h d") + k = rearrange(k, "b h s d -> b s h d") + v = rearrange(v, "b h s d -> b s h d") + cross_attention = module.attn_fn(q, k, v) + cross_attention = rearrange( + cross_attention, + "b s (h d) -> b h s d", + h=module.heads, + d=module.dim_head, + ) + else: + cross_attention = torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=False + ) + cross_attention = torch.split( + cross_attention, slice_tokens[0].shape[-2], dim=-2 + ) + return list(cross_attention) + + +def _gale_forward_impl( + module: nn.Module, + x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], + context: Float[torch.Tensor, "batch heads context_slices context_dim"] + | None, +) -> list[Float[torch.Tensor, "batch tokens channels"]]: + r"""Single implementation of the GALE forward pipeline. + + Shared by :class:`GALE` and :class:`_GALEStructuredForwardMixin`. 
Steps: + validate inputs; project onto slices; compute slice weights and tokens; + apply self-attention on slices; optionally cross-attend to context and + mix with ``state_mixing``; project attention outputs back to token space. + + Parameters + ---------- + module : nn.Module + GALE-like module with ``project_input_onto_slices``, + ``in_project_slice``, ``_compute_slices_from_projections``, + ``_compute_slice_attention_te``, ``_compute_slice_attention_sdpa``, + ``compute_slice_attention_cross``, ``_project_attention_outputs``, + plus attributes ``use_te``, ``plus``, ``state_mixing``. + x : tuple[torch.Tensor, ...] + Input tensors, each of shape :math:`(B, N, C)`; must be non-empty. + context : torch.Tensor or None + Optional context of shape :math:`(B, H, S_c, D_c)` for cross-attention. + If ``None``, only self-attention is applied. + + Returns + ------- + list[torch.Tensor] + One output tensor per input, each of shape :math:`(B, N, C)`. + + Raises + ------ + ValueError + If ``x`` is empty or any element is not 3D. 
+ """ + if not torch.compiler.is_compiling(): + if len(x) == 0: + raise ValueError("Expected non-empty tuple of input tensors") + for i, tensor in enumerate(x): + if tensor.ndim != 3: + raise ValueError( + f"Expected 3D input tensor (B, N, C) at index {i}, " + f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}" + ) + if module.plus: + x_mid = [module.project_input_onto_slices(_x) for _x in x] + fx_mid = [_x_mid for _x_mid in x_mid] + else: + x_mid, fx_mid = zip( + *[module.project_input_onto_slices(_x) for _x in x] + ) + slice_projections = [module.in_project_slice(_x_mid) for _x_mid in x_mid] + slice_weights, slice_tokens = zip( + *[ + module._compute_slices_from_projections(proj, _fx_mid) + for proj, _fx_mid in zip(slice_projections, fx_mid) + ] + ) + if module.use_te: + self_slice_token = [ + module._compute_slice_attention_te(_slice_token) + for _slice_token in slice_tokens + ] + else: + self_slice_token = [ + module._compute_slice_attention_sdpa(_slice_token) + for _slice_token in slice_tokens + ] + if context is not None: + cross_slice_token = [ + module.compute_slice_attention_cross([_slice_token], context)[0] + for _slice_token in slice_tokens + ] + mixing_weight = torch.sigmoid(module.state_mixing) + out_slice_token = [ + mixing_weight * sst + (1 - mixing_weight) * cst + for sst, cst in zip(self_slice_token, cross_slice_token) + ] + else: + out_slice_token = self_slice_token + outputs = [ + module._project_attention_outputs(ost, sw) + for ost, sw in zip(out_slice_token, slice_weights) + ] + return outputs + + class GALE(PhysicsAttentionIrregularMesh): r"""Geometry-Aware Latent Embeddings (GALE) attention layer. 
@@ -121,7 +265,7 @@ def __init__( ) -> None: super().__init__(dim, heads, dim_head, dropout, slice_num, use_te, plus) - linear_layer = te.Linear if self.use_te else nn.Linear + linear_layer = te.Linear if (self.use_te and TE_AVAILABLE) else nn.Linear # Cross-attention projection layers for context integration self.cross_q = linear_layer(dim_head, dim_head) @@ -154,39 +298,10 @@ def compute_slice_attention_cross( list[torch.Tensor] List of cross-attention outputs, each of shape :math:`(B, H, S, D)`. """ - # Concatenate all slice tokens for batched projection - q_input = torch.cat(slice_tokens, dim=-2) # (B, H, total_slices, D) - - # Project queries from slice tokens - q = self.cross_q(q_input) # (B, H, total_slices, D) - - # Project keys and values from context - k = self.cross_k(context) # (B, H, S_c, D) - v = self.cross_v(context) # (B, H, S_c, D) - - # Compute cross-attention using appropriate backend - if self.use_te: - # Transformer Engine expects (B, S, H, D) format - q = rearrange(q, "b h s d -> b s h d") - k = rearrange(k, "b h s d -> b s h d") - v = rearrange(v, "b h s d -> b s h d") - cross_attention = self.attn_fn(q, k, v) - cross_attention = rearrange( - cross_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head - ) - else: - # Use PyTorch's scaled dot-product attention - cross_attention = torch.nn.functional.scaled_dot_product_attention( - q, k, v, is_causal=False - ) - - # Split back into individual slice token outputs - cross_attention = torch.split( - cross_attention, slice_tokens[0].shape[-2], dim=-2 + return _gale_compute_slice_attention_cross( + self, slice_tokens, context ) - return list(cross_attention) - def forward( self, x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], @@ -216,74 +331,102 @@ def forward( List of output tensors, each of shape :math:`(B, N, C)``, same shape as inputs. 
""" - ### Input validation - if not torch.compiler.is_compiling(): - if len(x) == 0: - raise ValueError("Expected non-empty tuple of input tensors") - for i, tensor in enumerate(x): - if tensor.ndim != 3: - raise ValueError( - f"Expected 3D input tensor (B, N, C) at index {i}, " - f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}" - ) + return _gale_forward_impl(self, x, context) - # Project inputs onto learned latent spaces - if self.plus: - x_mid = [self.project_input_onto_slices(_x) for _x in x] - # In Transolver++, x_mid is reused for both projections - fx_mid = [_x_mid for _x_mid in x_mid] - else: - x_mid, fx_mid = zip( - *[self.project_input_onto_slices(_x) for _x in x] - ) - # Project latent representations onto physical state slices - slice_projections = [self.in_project_slice(_x_mid) for _x_mid in x_mid] +def _gale_cross_init( + self: nn.Module, + dim_head: int, + context_dim: int, + use_te: bool, +) -> None: + # Match GALE: TE linear only when TE is installed (GALE_block already errors if use_te without TE) + linear_layer = te.Linear if (use_te and TE_AVAILABLE) else nn.Linear + self.cross_q = linear_layer(dim_head, dim_head) + self.cross_k = linear_layer(context_dim, dim_head) + self.cross_v = linear_layer(context_dim, dim_head) + self.state_mixing = nn.Parameter(torch.tensor(0.0)) + - # Compute slice weights and aggregated slice tokens - slice_weights, slice_tokens = zip( - *[ - self._compute_slices_from_projections(proj, _fx_mid) - for proj, _fx_mid in zip(slice_projections, fx_mid) - ] +class _GALEStructuredForwardMixin: + """Shared cross-attention and forward for structured GALE (2D/3D conv projection).""" + + def compute_slice_attention_cross( + self, + slice_tokens: list[Float[torch.Tensor, "batch heads slices dim"]], + context: Float[torch.Tensor, "batch heads context_slices context_dim"], + ) -> list[Float[torch.Tensor, "batch heads slices dim"]]: + return _gale_compute_slice_attention_cross( + self, slice_tokens, context ) - # Apply 
self-attention to slice tokens - if self.use_te: - self_slice_token = [ - self._compute_slice_attention_te(_slice_token) - for _slice_token in slice_tokens - ] - else: - self_slice_token = [ - self._compute_slice_attention_sdpa(_slice_token) - for _slice_token in slice_tokens - ] - - # Apply cross-attention with context if provided - if context is not None: - cross_slice_token = [ - self.compute_slice_attention_cross([_slice_token], context)[0] - for _slice_token in slice_tokens - ] - - # Blend self-attention and cross-attention with learnable mixing weight - mixing_weight = torch.sigmoid(self.state_mixing) - out_slice_token = [ - mixing_weight * sst + (1 - mixing_weight) * cst - for sst, cst in zip(self_slice_token, cross_slice_token) - ] - else: - # Use only self-attention when no context is provided - out_slice_token = self_slice_token + def forward( + self, + x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], + context: Float[torch.Tensor, "batch heads context_slices context_dim"] + | None = None, + ) -> list[Float[torch.Tensor, "batch tokens channels"]]: + return _gale_forward_impl(self, x, context) - # Project attention outputs back to original space using slice weights - outputs = [ - self._project_attention_outputs(ost, sw) - for ost, sw in zip(out_slice_token, slice_weights) - ] - return outputs +class GALEStructuredMesh2D(_GALEStructuredForwardMixin, PhysicsAttentionStructuredMesh2D): + r"""GALE with Conv2d slice projection for 2D structured grids (see :class:`GALE`).""" + + def __init__( + self, + dim: int, + spatial_shape: tuple[int, int], + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + kernel: int = 3, + use_te: bool = True, + plus: bool = False, + context_dim: int = 0, + ) -> None: + super().__init__( + dim, + spatial_shape, + heads, + dim_head, + dropout, + slice_num, + kernel, + use_te, + plus, + ) + _gale_cross_init(self, dim_head, context_dim, use_te) + + +class 
GALEStructuredMesh3D(_GALEStructuredForwardMixin, PhysicsAttentionStructuredMesh3D): + r"""GALE with Conv3d slice projection for 3D structured grids (see :class:`GALE`).""" + + def __init__( + self, + dim: int, + spatial_shape: tuple[int, int, int], + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + kernel: int = 3, + use_te: bool = True, + plus: bool = False, + context_dim: int = 0, + ) -> None: + super().__init__( + dim, + spatial_shape, + heads, + dim_head, + dropout, + slice_num, + kernel, + use_te, + plus, + ) + _gale_cross_init(self, dim_head, context_dim, use_te) class GALE_block(nn.Module): @@ -317,6 +460,10 @@ class GALE_block(nn.Module): Whether to use Transolver++ features. Default is ``False``. context_dim : int, optional Dimension of the context vector for cross-attention. Default is 0. + spatial_shape : tuple[int, ...] | None, optional + If ``None``, uses irregular-mesh GALE. Length-2 tuple enables 2D Conv2d + projection; length-3 tuple enables 3D Conv3d projection (flattened + :math:`N = H \times W` or :math:`H \times W \times D`). Default is ``None``. Forward ------- @@ -369,6 +516,7 @@ def __init__( use_te: bool = True, plus: bool = False, context_dim: int = 0, + spatial_shape: tuple[int, ...] 
| None = None, ) -> None: super().__init__() @@ -386,17 +534,50 @@ def __init__( else: self.ln_1 = nn.LayerNorm(hidden_dim) - # GALE attention layer - self.Attn = GALE( - hidden_dim, - heads=num_heads, - dim_head=hidden_dim // num_heads, - dropout=dropout, - slice_num=slice_num, - use_te=use_te, - plus=plus, - context_dim=context_dim, - ) + dim_head = hidden_dim // num_heads + if spatial_shape is None: + self.Attn = GALE( + hidden_dim, + heads=num_heads, + dim_head=dim_head, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + context_dim=context_dim, + ) + elif len(spatial_shape) == 2: + self.Attn = GALEStructuredMesh2D( + hidden_dim, + spatial_shape=(int(spatial_shape[0]), int(spatial_shape[1])), + heads=num_heads, + dim_head=dim_head, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + context_dim=context_dim, + ) + elif len(spatial_shape) == 3: + self.Attn = GALEStructuredMesh3D( + hidden_dim, + spatial_shape=( + int(spatial_shape[0]), + int(spatial_shape[1]), + int(spatial_shape[2]), + ), + heads=num_heads, + dim_head=dim_head, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + context_dim=context_dim, + ) + else: + raise ValueError( + f"spatial_shape must be None, length-2, or length-3; got {spatial_shape!r}" + ) # Feed-forward network with layer normalization if use_te: diff --git a/physicsnemo/experimental/models/geotransolver/geotransolver.py b/physicsnemo/experimental/models/geotransolver/geotransolver.py index 5b106810df..b743b6c966 100644 --- a/physicsnemo/experimental/models/geotransolver/geotransolver.py +++ b/physicsnemo/experimental/models/geotransolver/geotransolver.py @@ -23,6 +23,7 @@ from __future__ import annotations +import math from collections.abc import Sequence from dataclasses import dataclass @@ -143,6 +144,47 @@ def _normalize_tensor( raise TypeError(f"Invalid tensor structure") +def _structured_num_tokens(spatial_shape: tuple[int, ...]) -> int: + return 
int(math.prod(spatial_shape)) + + +def _flatten_for_structured( + t: torch.Tensor, + spatial_shape: tuple[int, ...], + name: str, +) -> torch.Tensor: + """Flatten (B,H,W,C) or (B,H,W,D,C) to (B,N,C); pass through (B,N,C) if N matches. + + Mirrors Transolver's structured flatten/unflatten behavior so the rest of + GeoTransolver can assume a single token layout (B, N, C). + """ + n = _structured_num_tokens(spatial_shape) + if t.ndim == 3: + if not torch.compiler.is_compiling() and t.shape[1] != n: + raise ValueError( + f"{name} token count {t.shape[1]} != structured grid size {n}" + ) + return t + if len(spatial_shape) == 2 and t.ndim == 4: + B, H, W, C = t.shape + if (H, W) != spatial_shape: + raise ValueError( + f"{name} spatial dims {(H, W)} != structured_shape {spatial_shape}" + ) + return t.reshape(B, n, C) + if len(spatial_shape) == 3 and t.ndim == 5: + B, H, W, D, C = t.shape + if (H, W, D) != spatial_shape: + raise ValueError( + f"{name} spatial dims {(H, W, D)} != structured_shape {spatial_shape}" + ) + return t.reshape(B, n, C) + raise ValueError( + f"{name}: expected (B,N,C) with N={n}, or spatial layout matching " + f"structured_shape {spatial_shape}; got shape {tuple(t.shape)}" + ) + + class GeoTransolver(Module): r"""GeoTransolver: Geometry-Aware Physics Attention Transformer. @@ -204,13 +246,18 @@ class GeoTransolver(Module): Neighbors in radius for the local features. Default is ``[8, 32]``. n_hidden_local : int, optional Hidden dimension for the local features. Default is 32. + structured_shape : tuple[int, ...] | None, optional + If set to ``(H, W)`` or ``(H, W, D)``, enables structured 2D/3D paths + (Conv2d/Conv3d GALE; no ball-query local features). Inputs may be + flattened :math:`(B, N, C)` with :math:`N = H W` or :math:`H W D`, or + spatial :math:`(B, H, W, C)` / :math:`(B, H, W, D, C)`. Default is ``None``. Forward ------- local_embedding : torch.Tensor | tuple[torch.Tensor, ...] 
- Local embedding of the input data of shape :math:`(B, N, C)` where :math:`B` - is batch size, :math:`N` is number of nodes/tokens, and :math:`C` is - ``functional_dim``. Can be a single tensor or tuple for multiple input types. + Local embedding: unstructured :math:`(B, N, C)`; structured 2D + :math:`(B, H, W, C)` or flattened :math:`(B, H W, C)`; structured 3D + :math:`(B, H, W, D, C)` or flattened. Can be a tuple for multiple input types. local_positions : torch.Tensor | tuple[torch.Tensor, ...] | None, optional Local positions for each input, each of shape :math:`(B, N, 3)`. Required if ``include_local_features=True``. Default is ``None``. @@ -228,9 +275,9 @@ class GeoTransolver(Module): Outputs ------- torch.Tensor | tuple[torch.Tensor, ...] - Output tensor of shape :math:`(B, N, C_{out})` where :math:`C_{out}` is - ``out_dim``. Returns a single tensor if input was a single tensor, or a - tuple if input was a tuple. + Unstructured: :math:`(B, N, C_{out})`. Structured: same as input layout— + flattened :math:`(B, N, C_{out})` or spatial :math:`(B, H, W, C_{out})` / + :math:`(B, H, W, D, C_{out})` when inputs were 4D/5D. Tuple if tuple in. Raises ------ @@ -244,8 +291,9 @@ class GeoTransolver(Module): Notes ----- - GeoTransolver currently supports unstructured mesh input only. Enhancements for - image-based and voxel-based inputs may be available in the future. + Unstructured mesh uses linear GALE projection; structured ``structured_shape`` + uses the same Conv2d/Conv3d slice projection as :class:`~physicsnemo.models.transolver.Transolver`. + Ball-query local features are disabled when ``structured_shape`` is set. For more details on Transolver, see: @@ -293,6 +341,21 @@ class GeoTransolver(Module): >>> output = model(local_emb, global_embedding=global_emb, geometry=geometry) >>> output.shape torch.Size([2, 1000, 3]) + + Structured 2D grid: + + >>> model = GeoTransolver( + ... functional_dim=3, + ... out_dim=1, + ... structured_shape=(8, 8), + ... 
n_hidden=64, + ... n_head=4, + ... n_layers=2, + ... use_te=False, + ... ) + >>> y = model(torch.randn(2, 8, 8, 3)) + >>> y.shape + torch.Size([2, 8, 8, 1]) """ def __init__( @@ -315,6 +378,7 @@ def __init__( radii: list[float] | None = None, neighbors_in_radius: list[int] | None = None, n_hidden_local: int = 32, + structured_shape: tuple[int, ...] | None = None, ) -> None: super().__init__(meta=GeoTransolverMetaData()) self.__name__ = "GeoTransolver" @@ -325,8 +389,22 @@ def __init__( if neighbors_in_radius is None: neighbors_in_radius = [8, 32] + if structured_shape is not None: + if include_local_features: + raise ValueError( + "include_local_features=True is not supported with structured_shape " + "(ball-query path is mesh-only)." + ) + if len(structured_shape) not in (2, 3): + raise ValueError( + f"structured_shape must have length 2 or 3, got {structured_shape!r}" + ) + if not all(int(s) > 0 for s in structured_shape): + raise ValueError(f"structured_shape must be positive ints, got {structured_shape!r}") + self.include_local_features = include_local_features self.use_te = use_te + self.structured_shape = structured_shape # Validate head dimension compatibility if not n_hidden % n_head == 0: @@ -357,6 +435,7 @@ def __init__( use_te=use_te, plus=plus, include_local_features=self.include_local_features, + structured_shape=structured_shape, ) context_dim = self.context_builder.get_context_dim() @@ -404,6 +483,7 @@ def __init__( use_te=use_te, plus=plus, context_dim=context_dim, + spatial_shape=structured_shape, ) for layer_idx in range(n_layers) ] @@ -507,6 +587,27 @@ def forward( if local_positions is not None: local_positions = _normalize_tensor(local_positions) + unflatten_output = False + if self.structured_shape is not None: + unflatten_output = any(le.ndim in (4, 5) for le in local_embedding) + local_embedding = tuple( + _flatten_for_structured( + le, self.structured_shape, f"local_embedding[{i}]" + ) + for i, le in enumerate(local_embedding) + ) + if 
geometry is not None: + geometry = _flatten_for_structured( + geometry, self.structured_shape, "geometry" + ) + n_tok = _structured_num_tokens(self.structured_shape) + for i, le in enumerate(local_embedding): + if le.shape[1] != n_tok: + raise ValueError( + f"structured GeoTransolver: all streams must have N={n_tok} tokens; " + f"local_embedding[{i}] has N={le.shape[1]}" + ) + ### Input validation if not torch.compiler.is_compiling(): if len(local_embedding) == 0: @@ -550,6 +651,16 @@ def forward( # Project to output dimensions: (B, N, n_hidden) -> (B, N, out_dim) x = [self.ln_mlp_out[i](x[i]) for i in range(len(x))] + if self.structured_shape is not None and unflatten_output: + B = x[0].shape[0] + for i in range(len(x)): + if len(self.structured_shape) == 2: + H, W = self.structured_shape + x[i] = x[i].reshape(B, H, W, -1) + else: + H, W, D_ = self.structured_shape + x[i] = x[i].reshape(B, H, W, D_, -1) + # Return same format as input (single tensor or tuple) if single_input: x = x[0] diff --git a/test/models/geotransolver/test_geotransolver.py b/test/models/geotransolver/test_geotransolver.py index c0a8d33968..82691fc527 100644 --- a/test/models/geotransolver/test_geotransolver.py +++ b/test/models/geotransolver/test_geotransolver.py @@ -556,6 +556,98 @@ def test_geotransolver_mismatched_functional_out_dims(): ) +def test_geotransolver_structured_rejects_local_features(): + """Ball-query local features are incompatible with structured_shape.""" + with pytest.raises(ValueError, match="include_local_features=True"): + GeoTransolver( + functional_dim=8, + out_dim=1, + structured_shape=(4, 4), + include_local_features=True, + geometry_dim=2, + use_te=False, + ) + + +def test_geotransolver_structured_2d_forward(device): + """Structured 2D: spatial input (B,H,W,C) and flattened (B,N,C); optional geometry.""" + torch.manual_seed(0) + H, W = 4, 4 + model = GeoTransolver( + functional_dim=3, + out_dim=2, + structured_shape=(H, W), + geometry_dim=2, + global_dim=None, + 
n_layers=2, + n_hidden=32, + n_head=4, + slice_num=8, + mlp_ratio=2, + use_te=False, + ).to(device) + B = 2 + x4 = torch.randn(B, H, W, 3, device=device) + g = torch.randn(B, H, W, 2, device=device) + y4 = model(x4, geometry=g) + assert y4.shape == (B, H, W, 2) + assert not torch.isnan(y4).any() + + x3 = x4.reshape(B, H * W, 3) + g3 = g.reshape(B, H * W, 2) + y3 = model(x3, geometry=g3) + assert y3.shape == (B, H * W, 2) + + y_none = model(x4) + assert y_none.shape == (B, H, W, 2) + + +def test_geotransolver_structured_3d_forward(device): + """Structured 3D voxel input (B,H,W,D,C).""" + torch.manual_seed(1) + H, W, Dg = 2, 2, 2 + model = GeoTransolver( + functional_dim=4, + out_dim=1, + structured_shape=(H, W, Dg), + n_layers=1, + n_hidden=32, + n_head=4, + slice_num=4, + mlp_ratio=2, + use_te=False, + ).to(device) + B = 1 + x = torch.randn(B, H, W, Dg, 4, device=device) + y = model(x) + assert y.shape == (B, H, W, Dg, 1) + + +def test_geotransolver_structured_global_context(device): + """Structured grid with global embedding context.""" + torch.manual_seed(2) + H, W = 4, 4 + model = GeoTransolver( + functional_dim=2, + out_dim=1, + structured_shape=(H, W), + geometry_dim=2, + global_dim=8, + n_layers=2, + n_hidden=32, + n_head=4, + slice_num=8, + mlp_ratio=2, + use_te=False, + ).to(device) + B = 2 + x = torch.randn(B, H, W, 2, device=device) + geo = torch.randn(B, H, W, 2, device=device) + glob = torch.randn(B, 3, 8, device=device) + y = model(x, geometry=geo, global_embedding=glob) + assert y.shape == (B, H, W, 1) + + # ============================================================================= # Activation Function Tests # ============================================================================= From f13862b5511ea03f18303cbe9b844314d2d9847b Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:19:12 -0500 Subject: [PATCH 2/8] Fix issues with the merge of context projector. 
--- .../models/geotransolver/context_projector.py | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/physicsnemo/experimental/models/geotransolver/context_projector.py b/physicsnemo/experimental/models/geotransolver/context_projector.py index b43b1e9b27..aa4454e362 100644 --- a/physicsnemo/experimental/models/geotransolver/context_projector.py +++ b/physicsnemo/experimental/models/geotransolver/context_projector.py @@ -463,6 +463,7 @@ def __init__( kernel: int = 3, use_te: bool = True, plus: bool = False, + concrete_dropout: bool = False, ) -> None: super().__init__() if len(spatial_shape) not in (2, 3): @@ -494,6 +495,15 @@ def __init__( self.dropout = nn.Dropout(dropout) self._init_slice_components(dim_head, slice_num, heads, use_te, plus) + # Concrete dropout on the output slice tokens + if concrete_dropout: + self.output_dropout = ConcreteDropout( + in_features=dim_head, + init_p=max(dropout, 0.05), + ) + else: + self.output_dropout = None + def _grid_project( self, x: Float[torch.Tensor, "batch tokens channels"] ) -> tuple[ @@ -536,6 +546,11 @@ def forward( _, slice_tokens = self.compute_slices_from_projections( slice_projections, feature_projection ) + + # Apply concrete dropout to output slice tokens + if self.output_dropout is not None: + slice_tokens = self.output_dropout(slice_tokens) + return slice_tokens @@ -967,16 +982,6 @@ def __init__( geometry_dim, n_head, dim_head, dropout, slice_num, use_te, plus=plus, concrete_dropout=concrete_dropout, ) - self.geometry_tokenizer = ContextProjector( - geometry_dim, - n_head, - dim_head, - dropout, - slice_num, - use_te, - plus, - concrete_dropout=concrete_dropout, - ) context_dim += dim_head else: self.geometry_tokenizer = None From 57981f5d66a1aa7ecdd66d2de3226848126a5768 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Thu, 30 Apr 2026 16:48:10 -0500 Subject: [PATCH 3/8] Refactor physics attention and gale to consolidate 
implementations. Refactor flare to reduce from three duplicate implementations to one. --- .../models/geotransolver/context_projector.py | 152 +++------- .../experimental/models/geotransolver/gale.py | 268 ++++++++++++++-- .../models/geotransolver/gale_fa.py | 121 +++----- .../experimental/nn/flare_attention.py | 72 +++-- physicsnemo/nn/module/physics_attention.py | 287 +++++++++++------- 5 files changed, 561 insertions(+), 339 deletions(-) diff --git a/physicsnemo/experimental/models/geotransolver/context_projector.py b/physicsnemo/experimental/models/geotransolver/context_projector.py index aa4454e362..fb25b9a330 100644 --- a/physicsnemo/experimental/models/geotransolver/context_projector.py +++ b/physicsnemo/experimental/models/geotransolver/context_projector.py @@ -39,9 +39,12 @@ from jaxtyping import Float from physicsnemo.core.version_check import check_version_spec -from physicsnemo.nn import gumbel_softmax from physicsnemo.nn import BQWarp from physicsnemo.nn import Mlp +from physicsnemo.nn.module.physics_attention import ( + _compute_slices_from_projections, + _project_input, +) from physicsnemo.nn import ConcreteDropout @@ -51,60 +54,6 @@ import transformer_engine.pytorch as te -def _compute_slices_from_projections_impl( - slice_projections: Float[torch.Tensor, "batch heads tokens slices"], - fx: Float[torch.Tensor, "batch heads tokens dim"], - temperature: torch.Tensor, - plus: bool, - proj_temperature: nn.Module | None = None, -) -> tuple[ - Float[torch.Tensor, "batch heads tokens slices"], - Float[torch.Tensor, "batch heads slices dim"], -]: - r"""Shared slice aggregation: temperature-weighted softmax then weighted sum over tokens. - - Used by both :class:`ContextProjector` and :class:`StructuredContextProjector` - to avoid duplicating the slice-weight and slice-token computation. - - Parameters - ---------- - slice_projections : torch.Tensor - Projection of each token onto each slice, shape :math:`(B, H, N, S)`. 
- fx : torch.Tensor - Latent features to aggregate per slice, shape :math:`(B, H, N, D)`. - temperature : torch.Tensor - Scalar temperature for softmax/gumbel, shape broadcastable to projections. - plus : bool - If ``True``, use Gumbel softmax with optional adaptive temperature. - proj_temperature : nn.Module or None, optional - If ``plus`` is ``True``, module mapping :math:`(B, H, N, D)` to adaptive - temperature; ignored otherwise. Default is ``None``. - - Returns - ------- - slice_weights : torch.Tensor - Normalized weights per token and slice, shape :math:`(B, H, N, S)`. - slice_token : torch.Tensor - Aggregated features per slice, shape :math:`(B, H, S, D)`. - """ - if plus and proj_temperature is not None: - temp = temperature + proj_temperature(fx) - clamped_temp = torch.clamp(temp, min=0.01).to(slice_projections.dtype) - slice_weights = gumbel_softmax(slice_projections, clamped_temp) - else: - clamped_temp = torch.clamp(temperature, min=0.5, max=5).to( - slice_projections.dtype - ) - slice_weights = nn.functional.softmax( - slice_projections / clamped_temp, dim=-1 - ) - slice_weights = slice_weights.to(slice_projections.dtype) - slice_norm = slice_weights.sum(2) - normed_weights = slice_weights / (slice_norm[:, :, None, :] + 1e-2) - slice_token = torch.matmul(normed_weights.transpose(2, 3), fx) - return slice_weights, slice_token - - def _structured_grid_to_conv_input( x: Float[torch.Tensor, "batch tokens channels"], batch: int, @@ -166,8 +115,9 @@ class _SliceToContextMixin: r"""Internal mixin providing shared slice-to-context init and slice aggregation. Used by :class:`ContextProjector` and :class:`StructuredContextProjector` to - avoid duplicating in_project_slice, temperature, proj_temperature, and - compute_slices_from_projections. + avoid duplicating ``in_project_slice``, ``temperature``, ``proj_temperature``, + and the call to + :func:`~physicsnemo.nn.module.physics_attention._compute_slices_from_projections`. 
""" def _init_slice_components( @@ -199,7 +149,7 @@ def _init_slice_components( """ linear_layer = te.Linear if (use_te and TE_AVAILABLE) else nn.Linear self.in_project_slice = linear_layer(dim_head, slice_num) - self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) + self.temperature = nn.Parameter(torch.ones([1, 1, heads, 1]) * 0.5) if plus: self.proj_temperature = nn.Sequential( linear_layer(dim_head, slice_num), @@ -208,34 +158,35 @@ def _init_slice_components( nn.GELU(), ) - def compute_slices_from_projections( + def _compute_slices( self, - slice_projections: Float[torch.Tensor, "batch heads tokens slices"], - fx: Float[torch.Tensor, "batch heads tokens dim"], + slice_projections: Float[torch.Tensor, "batch tokens heads slices"], + fx: Float[torch.Tensor, "batch tokens heads dim"], ) -> tuple[ - Float[torch.Tensor, "batch heads tokens slices"], + Float[torch.Tensor, "batch tokens heads slices"], Float[torch.Tensor, "batch heads slices dim"], ]: r"""Compute slice weights and slice tokens from projections and latent features. - Delegates to :func:`_compute_slices_from_projections_impl` using this - instance's ``temperature``, ``plus``, and (when plus) ``proj_temperature``. + Delegates to :func:`~physicsnemo.nn.module.physics_attention._compute_slices_from_projections`, + the shared free function that also backs + :meth:`~physicsnemo.nn.module.physics_attention.PhysicsAttentionBase._compute_slices_from_projections`. Parameters ---------- slice_projections : torch.Tensor - Shape :math:`(B, H, N, S)`. + Shape :math:`(B, N, H, S)`. fx : torch.Tensor - Shape :math:`(B, H, N, D)`. + Shape :math:`(B, N, H, D)`. Returns ------- tuple[torch.Tensor, torch.Tensor] - ``(slice_weights, slice_token)`` with shapes :math:`(B, H, N, S)` + ``(slice_weights, slice_token)`` with shapes :math:`(B, N, H, S)` and :math:`(B, H, S, D)`. 
""" proj_temp = getattr(self, "proj_temperature", None) if self.plus else None - return _compute_slices_from_projections_impl( + return _compute_slices_from_projections( slice_projections, fx, self.temperature, @@ -346,10 +297,10 @@ def __init__( def project_input_onto_slices( self, x: Float[torch.Tensor, "batch tokens channels"] ) -> ( - Float[torch.Tensor, "batch heads tokens dim"] + Float[torch.Tensor, "batch tokens heads dim"] | tuple[ - Float[torch.Tensor, "batch heads tokens dim"], - Float[torch.Tensor, "batch heads tokens dim"], + Float[torch.Tensor, "batch tokens heads dim"], + Float[torch.Tensor, "batch tokens heads dim"], ] ): r"""Project the input onto the slice space. @@ -363,29 +314,17 @@ def project_input_onto_slices( Returns ------- torch.Tensor or tuple[torch.Tensor, torch.Tensor] - If ``plus=True``, returns single tensor of shape :math:`(B, H, N, D)` where + If ``plus=True``, returns single tensor of shape :math:`(B, N, H, D)` where :math:`H` is number of heads and :math:`D` is head dimension. If ``plus=False``, - returns tuple of two tensors both of shape :math:`(B, H, N, D)`, representing + returns tuple of two tensors both of shape :math:`(B, N, H, D)`, representing the query and key projections respectively. 
""" - # Project input to multi-head representation: (B, N, C) -> (B, H, N, D) - projected_x = rearrange( - self.in_project_x(x), "B N (h d) -> B h N d", h=self.heads, d=self.dim_head + fx = None if self.plus else self.in_project_fx + return _project_input( + x, self.in_project_x, self.heads, self.dim_head, + "B N (H D) -> B N H D", project_fx=fx, ) - if self.plus: - # Transolver++ uses single projection for both paths - return projected_x - else: - # Standard Transolver uses separate query and key projections - feature_projection = rearrange( - self.in_project_fx(x), - "B N (h d) -> B h N d", - h=self.heads, - d=self.dim_head, - ) - return projected_x, feature_projection - def forward( self, x: Float[torch.Tensor, "batch tokens channels"] ) -> Float[torch.Tensor, "batch heads slices dim"]: @@ -429,11 +368,11 @@ def forward( else: projected_x, feature_projection = self.project_input_onto_slices(x) - # Project latent representations onto physical state slices: (B, H, N, D) -> (B, H, N, S) + # Project latent representations onto physical state slices: (B, N, H, D) -> (B, N, H, S) slice_projections = self.in_project_slice(projected_x) # Compute weighted aggregation of features into slice tokens - _, slice_tokens = self.compute_slices_from_projections( + _, slice_tokens = self._compute_slices( slice_projections, feature_projection ) @@ -506,28 +445,27 @@ def __init__( def _grid_project( self, x: Float[torch.Tensor, "batch tokens channels"] - ) -> tuple[ - Float[torch.Tensor, "batch heads tokens dim"], - Float[torch.Tensor, "batch heads tokens dim"], - ]: + ) -> ( + Float[torch.Tensor, "batch tokens heads dim"] + | tuple[ + Float[torch.Tensor, "batch tokens heads dim"], + Float[torch.Tensor, "batch tokens heads dim"], + ] + ): B, N, C = x.shape grid = _structured_grid_to_conv_input( x, B, N, C, self._nd, self.spatial_shape ) pattern = ( - "B (H D) h w -> B H (h w) D" + "B (H D) h w -> B (h w) H D" if self._nd == 2 - else "B (H D) h w d -> B H (h w d) D" - ) - px = 
rearrange( - self.in_project_x(grid), pattern, H=self.heads, D=self.dim_head + else "B (H D) h w d -> B (h w d) H D" ) - if self.plus: - return px, px - pfx = rearrange( - self.in_project_fx(grid), pattern, H=self.heads, D=self.dim_head + fx = None if self.plus else self.in_project_fx + return _project_input( + grid, self.in_project_x, self.heads, self.dim_head, + pattern, project_fx=fx, ) - return px, pfx def forward( self, x: Float[torch.Tensor, "batch tokens channels"] @@ -538,12 +476,12 @@ def forward( f"Expected 3D input (B, N, C), got {x.ndim}D shape {tuple(x.shape)}" ) if self.plus: - projected_x = self._grid_project(x)[0] + projected_x = self._grid_project(x) feature_projection = projected_x else: projected_x, feature_projection = self._grid_project(x) slice_projections = self.in_project_slice(projected_x) - _, slice_tokens = self.compute_slices_from_projections( + _, slice_tokens = self._compute_slices( slice_projections, feature_projection ) diff --git a/physicsnemo/experimental/models/geotransolver/gale.py b/physicsnemo/experimental/models/geotransolver/gale.py index e1d1473276..8074c958eb 100644 --- a/physicsnemo/experimental/models/geotransolver/gale.py +++ b/physicsnemo/experimental/models/geotransolver/gale.py @@ -25,6 +25,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from einops import rearrange from jaxtyping import Float @@ -35,9 +36,10 @@ PhysicsAttentionIrregularMesh, PhysicsAttentionStructuredMesh2D, PhysicsAttentionStructuredMesh3D, + _project_input, ) +from physicsnemo.experimental.nn.flare_attention import _flare_self_attention -from physicsnemo.experimental.models.geotransolver.gale_fa import GALE_FA from physicsnemo.nn import ConcreteDropout # Check optional dependency availability @@ -45,6 +47,44 @@ te = OptionalImport("transformer_engine.pytorch", "0.1.0") +def _mix_self_and_cross( + self_attn: torch.Tensor, + cross_attn: torch.Tensor, + mode: str, + state_mixing: nn.Parameter | None = None, + concat_project: 
nn.Module | None = None, +) -> torch.Tensor: + r"""Blend self-attention and cross-attention outputs. + + Parameters + ---------- + self_attn : torch.Tensor + Self-attention output. + cross_attn : torch.Tensor + Cross-attention output (same shape as ``self_attn``). + mode : str + ``"weighted"`` for sigmoid-gated sum, ``"concat_project"`` for + concatenation followed by a learned projection. + state_mixing : nn.Parameter or None + Learnable scalar for ``"weighted"`` mode. + concat_project : nn.Module or None + Projection module for ``"concat_project"`` mode. + + Returns + ------- + torch.Tensor + Blended output, same shape as inputs. + """ + match mode: + case "weighted": + w = torch.sigmoid(state_mixing) + return w * self_attn + (1 - w) * cross_attn + case "concat_project": + return concat_project(torch.cat([self_attn, cross_attn], dim=-1)) + case _: + raise ValueError(f"Invalid state_mixing_mode: {mode!r}") + + def _gale_compute_slice_attention_cross( module: nn.Module, slice_tokens: list[Float[torch.Tensor, "batch heads slices dim"]], @@ -174,23 +214,14 @@ def _gale_forward_impl( module.compute_slice_attention_cross([_slice_token], context)[0] for _slice_token in slice_tokens ] - # Blend self-attention and cross-attention - match module.state_mixing_mode: - case "weighted": - mixing_weight = torch.sigmoid(module.state_mixing) - out_slice_token = [ - mixing_weight * sst + (1 - mixing_weight) * cst - for sst, cst in zip(self_slice_token, cross_slice_token) - ] - case "concat_project": - out_slice_token = [ - module.concat_project(torch.cat([sst, cst], dim=-1)) - for sst, cst in zip(self_slice_token, cross_slice_token) - ] - case _: - raise ValueError( - f"Invalid state_mixing_mode: {module.state_mixing_mode!r}" - ) + out_slice_token = [ + _mix_self_and_cross( + sst, cst, module.state_mixing_mode, + state_mixing=getattr(module, "state_mixing", None), + concat_project=getattr(module, "concat_project", None), + ) + for sst, cst in zip(self_slice_token, 
cross_slice_token) + ] else: # Use only self-attention when no context is provided out_slice_token = self_slice_token @@ -472,6 +503,207 @@ def __init__( _gale_cross_init(self, dim_head, context_dim, use_te, state_mixing_mode) +class GALE_FA(nn.Module): + r"""GALE_FA: Geometry-Aware Latent Embeddings with FLARE self-Attention attention layer. + + Adopted: + + - FLARE attention: Fast Low-rank Attention Routing Engine + paper: https://arxiv.org/abs/2508.12594 + - GeoTransolver context: + paper: https://arxiv.org/abs/2512.20399 + + GALE_FA is an alternative to the GALE attention mechanism of the GeoTransolver. + It supports cross-attention with a context vector, built from geometry and global embeddings. + GALE_FA combines FLARE self-attention on learned physical state slices with cross-attention + to geometry-aware context, using a learnable mixing weight to blend the two. + + Parameters + ---------- + dim : int + Input dimension of the features. + heads : int, optional + Number of attention heads. Default is 8. + dim_head : int, optional + Dimension of each attention head. Default is 64. + dropout : float, optional + Dropout rate. Default is 0.0. + n_global_queries : int, optional + Number of learned global queries. Default is 64. + use_te : bool, optional + Whether to use Transformer Engine backend when available. Default is False. + context_dim : int, optional + Dimension of the context vector for cross-attention. Default is 0. + concrete_dropout : bool, optional + Whether to use learned concrete dropout instead of standard dropout. + Default is ``False``. + state_mixing_mode : str, optional + How to blend self-attention and cross-attention outputs. ``"weighted"`` uses + a learnable sigmoid-gated weighted sum. ``"concat_project"`` + concatenates the two along the head dimension and projects back with a + linear layer. Default is ``"weighted"``. + + Forward + ------- + x : tuple[torch.Tensor, ...] 
+ Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` is + batch size, :math:`N` is number of tokens, and :math:`C` is number of channels. + context : tuple[torch.Tensor, ...] | None, optional + Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` where + :math:`H` is number of heads, :math:`S_c` is number of context slices, and + :math:`D_c` is context dimension. If ``None``, only self-attention is applied. + Default is ``None``. + + Outputs + ------- + list[torch.Tensor] + List of output tensors, each of shape :math:`(B, N, C)`, same shape as inputs. + + Notes + ----- + The mixing between self-attention and cross-attention is controlled by a learnable + parameter ``state_mixing`` which is passed through a sigmoid function to ensure + the mixing weight stays in :math:`[0, 1]`. + + See Also + -------- + :class:`GALE` : Original GeoTransolver GALE attention class. + :class:`GALE_block` : Transformer block that calls GALE or GALE_FA attention. + + Examples + -------- + >>> import torch + >>> gale_fa = GALE_FA(dim=256, heads=8, dim_head=32, context_dim=32) + >>> x = (torch.randn(2, 100, 256),) # Single input tensor in tuple + >>> context = torch.randn(2, 8, 64, 32) # Context for cross-attention + >>> outputs = gale_fa(x, context) + >>> len(outputs) + 1 + >>> outputs[0].shape + torch.Size([2, 100, 256]) + """ + + def __init__( + self, + dim, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + n_global_queries: int = 64, + use_te: bool = True, + context_dim: int = 0, + concrete_dropout: bool = False, + state_mixing_mode: str = "weighted", + ): + if use_te: + raise ValueError( + "GALE_FA does not support Transformer Engine backend. " + "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention." 
+ ) + super().__init__() + self.use_te = use_te + self.heads = heads + self.dim_head = dim_head + self.scale = 1.0 + # It is recommended by the FLARE authors to use self.scale = 1 if self.dim_head <= 8 else (self.dim_head ** -0.5) + # but we use self.scale = 1.0 because the recommended scaling is not tested yet. + inner_dim = dim_head * heads + + linear_layer = te.Linear if self.use_te else nn.Linear + + # Global queries for FLARE self-attention + self.q_global = nn.Parameter(torch.randn(1, heads, n_global_queries, dim_head)) + + # Linear projections for self-attention + self.in_project_x = linear_layer(dim, inner_dim) + self.self_k = linear_layer(dim_head, dim_head) + self.self_v = linear_layer(dim_head, dim_head) + + if context_dim > 0: + _gale_cross_init(self, dim_head, context_dim, use_te, state_mixing_mode) + + # Linear projection for output + self.out_linear = linear_layer(inner_dim, dim) + if concrete_dropout: + self.out_dropout = ConcreteDropout( + in_features=dim, + init_p=max(dropout, 0.05), + ) + else: + self.out_dropout = nn.Dropout(dropout) + + def forward( + self, + x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], + context: Float[torch.Tensor, "batch heads context_slices context_dim"] + | None = None, + ) -> list[Float[torch.Tensor, "batch tokens channels"]]: + r"""Forward pass of the GALE_FA module. + + Applies GALE_FA attention to the input features. + + Parameters + ---------- + x : tuple[torch.Tensor, ...] + Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` + is batch size, :math:`N` is number of tokens, and :math:`C` is number + of channels. + context : torch.Tensor | None, optional + Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` + where :math:`H` is number of heads, :math:`S_c` is number of context + slices, and :math:`D_c` is context dimension. If ``None``, only + self-attention is applied. Default is ``None``. 
+ + Returns + ------- + list[torch.Tensor] + List of output tensors, each of shape :math:`(B, N, C)``, same shape + as inputs. + """ + # Input projection: (B, N, C) -> (B, N, H, D) -> (B, H, N, D) + x_mid = [ + _project_input( + _x, self.in_project_x, self.heads, self.dim_head, + "B N (H D) -> B N H D", + ).permute(0, 2, 1, 3) + for _x in x + ] + + # FLARE self-attention per input + self_attention = [ + _flare_self_attention( + _x_mid, self.q_global, self.self_k, self.self_v, self.scale, + ) + for _x_mid in x_mid + ] + + # Cross-attention with context and state mixing + if context is not None: + q = [self.cross_q(_x_mid) for _x_mid in x_mid] + k = self.cross_k(context) + v = self.cross_v(context) + cross_attention = [ + F.scaled_dot_product_attention(_q, k, v, scale=self.scale) + for _q in q + ] + outputs = [ + _mix_self_and_cross( + sa, ca, self.state_mixing_mode, + state_mixing=getattr(self, "state_mixing", None), + concat_project=getattr(self, "concat_project", None), + ) + for sa, ca in zip(self_attention, cross_attention) + ] + else: + outputs = self_attention + + # Back to token layout: (B, H, N, D) -> (B, N, H, D) + outputs = [_y.permute(0, 2, 1, 3) for _y in outputs] + outputs = [rearrange(_out, "b n h d -> b n (h d)") for _out in outputs] + outputs = [self.out_linear(_out) for _out in outputs] + return [self.out_dropout(_out) for _out in outputs] + + class GALE_block(nn.Module): r"""Transformer encoder block using GALE attention. 
diff --git a/physicsnemo/experimental/models/geotransolver/gale_fa.py b/physicsnemo/experimental/models/geotransolver/gale_fa.py index eb9f8a3478..145f8c80e8 100644 --- a/physicsnemo/experimental/models/geotransolver/gale_fa.py +++ b/physicsnemo/experimental/models/geotransolver/gale_fa.py @@ -28,11 +28,15 @@ from einops import rearrange from jaxtyping import Float -from physicsnemo.core.version_check import check_version_spec, OptionalImport +from physicsnemo.core.version_check import OptionalImport from physicsnemo.nn import ConcreteDropout +from physicsnemo.nn.module.physics_attention import _project_input +from physicsnemo.experimental.nn.flare_attention import _flare_self_attention +from physicsnemo.experimental.models.geotransolver.gale import ( + _gale_cross_init, + _mix_self_and_cross, +) -# Check optional dependency availability -TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) te = OptionalImport("transformer_engine.pytorch", "0.1.0") @@ -132,12 +136,6 @@ def __init__( "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention." ) super().__init__() - if state_mixing_mode not in ("weighted", "concat_project"): - raise ValueError( - f"Invalid state_mixing_mode: {state_mixing_mode!r}. " - f"Expected 'weighted' or 'concat_project'." 
- ) - self.state_mixing_mode = state_mixing_mode self.use_te = use_te self.heads = heads self.dim_head = dim_head @@ -157,31 +155,7 @@ def __init__( self.self_v = linear_layer(dim_head, dim_head) if context_dim > 0: - # Linear projections for cross-attention - self.cross_q = linear_layer(dim_head, dim_head) - self.cross_k = linear_layer(context_dim, dim_head) - self.cross_v = linear_layer(context_dim, dim_head) - - # Mixing layers for blending self-attention and cross-attention - if state_mixing_mode == "weighted": - # Learnable mixing weight between self and cross attention - self.state_mixing = nn.Parameter(torch.tensor(0.0)) - else: - # Concatenate self and cross attention and project back to dim_head - self.concat_project = nn.Sequential( - linear_layer(2 * dim_head, dim_head), - nn.GELU(), - ) - - # te attention - if self.use_te: - self.attn_fn = te.DotProductAttention( - num_attention_heads=self.heads, - kv_channels=self.dim_head, - attention_dropout=dropout, - qkv_format="bshd", - softmax_scale=self.scale - ) + _gale_cross_init(self, dim_head, context_dim, use_te, state_mixing_mode) # Linear projection for output self.out_linear = linear_layer(inner_dim, dim) @@ -223,62 +197,45 @@ def forward( as inputs. 
""" - # with record_function("forward"): - x_mid = [self.in_project_x(_x) for _x in x] - x_mid = [rearrange( - _x_mid, "B N (h d) -> B N h d", h=self.heads, d=self.dim_head - ) for _x_mid in x_mid] - x_mid = [_x_mid.permute(0, 2, 1, 3) for _x_mid in x_mid] # [B, H, N, D] - G = [self.q_global.to(dtype=x_mid[0].dtype).expand(x_mid[0].shape[0], -1, -1, -1)] * len(x) - k = [self.self_k(_x_mid) for _x_mid in x_mid] - v = [self.self_v(_x_mid) for _x_mid in x_mid] - - # FLARE: Self Attention - if self.use_te: - # Transformer Engine expects (B, S, H, D) format - G = [rearrange(_G, "b h s d -> b s h d") for _G in G] - k = [rearrange(_k, "b h s d -> b s h d") for _k in k] - v = [rearrange(_v, "b h s d -> b s h d") for _v in v] - z = [self.attn_fn(_G, _k, _v) for _G, _k, _v in zip(G, k, v)] - z = [rearrange( - _z, "b s (h d) -> b s h d", h=self.heads, d=self.dim_head - ) for _z in z] - self_attention = [self.attn_fn(_k, _G, _z) for _k, _G, _z in zip(k, G, z)] - self_attention = [rearrange( - _self_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head - ) for _self_attention in self_attention] - else: - # Use PyTorch's scaled dot-product attention - z = [F.scaled_dot_product_attention(_G, _k, _v, scale=self.scale) for _G, _k, _v in zip(G, k, v)] - self_attention = [F.scaled_dot_product_attention(_k, _G, _z, scale=self.scale) for _k, _G, _z in zip(k, G, z)] + # Input projection: (B, N, C) -> (B, N, H, D) -> (B, H, N, D) + x_mid = [ + _project_input( + _x, self.in_project_x, self.heads, self.dim_head, + "B N (H D) -> B N H D", + ).permute(0, 2, 1, 3) + for _x in x + ] + + # FLARE self-attention per input + self_attention = [ + _flare_self_attention( + _x_mid, self.q_global, self.self_k, self.self_v, self.scale, + ) + for _x_mid in x_mid + ] - # apply cross-attention with physical states: + # Cross-attention with context and state mixing if context is not None: q = [self.cross_q(_x_mid) for _x_mid in x_mid] k = self.cross_k(context) v = self.cross_v(context) - - if 
self.use_te: - q = [rearrange(_q, "b h s d -> b s h d") for _q in q] - k = rearrange(k, "b h s d -> b s h d") - v = rearrange(v, "b h s d -> b s h d") - cross_attention = [self.attn_fn(_q, k, v) for _q in q] - cross_attention = [rearrange( - _cross_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head - ) for _cross_attention in cross_attention] - else: - cross_attention = [F.scaled_dot_product_attention(_q, k, v, scale=self.scale) for _q in q] - - # Blend self-attention and cross-attention - if self.state_mixing_mode == "weighted": - mixing_weight = torch.sigmoid(self.state_mixing) - outputs = [mixing_weight * _ys + (1 - mixing_weight) * _yc for _ys, _yc in zip(self_attention, cross_attention)] - else: - outputs = [self.concat_project(torch.cat([_ys, _yc], dim=-1)) for _ys, _yc in zip(self_attention, cross_attention)] + cross_attention = [ + F.scaled_dot_product_attention(_q, k, v, scale=self.scale) + for _q in q + ] + outputs = [ + _mix_self_and_cross( + sa, ca, self.state_mixing_mode, + state_mixing=getattr(self, "state_mixing", None), + concat_project=getattr(self, "concat_project", None), + ) + for sa, ca in zip(self_attention, cross_attention) + ] else: outputs = self_attention - outputs = [_y.permute(0, 2, 1, 3) for _y in outputs] # [B, N, H, D] + # Back to token layout: (B, H, N, D) -> (B, N, H, D) + outputs = [_y.permute(0, 2, 1, 3) for _y in outputs] outputs = [rearrange(_out, "b n h d -> b n (h d)") for _out in outputs] outputs = [self.out_linear(_out) for _out in outputs] return [self.out_dropout(_out) for _out in outputs] diff --git a/physicsnemo/experimental/nn/flare_attention.py b/physicsnemo/experimental/nn/flare_attention.py index 70a6c14f57..4ecc68a927 100644 --- a/physicsnemo/experimental/nn/flare_attention.py +++ b/physicsnemo/experimental/nn/flare_attention.py @@ -29,12 +29,50 @@ from jaxtyping import Float from physicsnemo.core.version_check import check_version_spec, OptionalImport +from physicsnemo.nn.module.physics_attention 
import _project_input # Check optional dependency availability TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) te = OptionalImport("transformer_engine.pytorch", "0.1.0") +def _flare_self_attention( + x_mid: Float[torch.Tensor, "B H N D"], + q_global: nn.Parameter, + self_k: nn.Module, + self_v: nn.Module, + scale: float, +) -> Float[torch.Tensor, "B H N D"]: + r"""FLARE two-pass self-attention kernel. + + Computes low-rank attention via learned global queries: first aggregate + token values into global slots, then distribute back to tokens. + + Parameters + ---------- + x_mid : torch.Tensor + Projected input of shape :math:`(B, H, N, D)`. + q_global : nn.Parameter + Learned global queries of shape :math:`(1, H, S, D)`. + self_k : nn.Module + Key projection applied to ``x_mid``. + self_v : nn.Module + Value projection applied to ``x_mid``. + scale : float + Attention scale factor. + + Returns + ------- + torch.Tensor + Self-attended output of shape :math:`(B, H, N, D)`. + """ + G = q_global.to(dtype=x_mid.dtype).expand(x_mid.shape[0], -1, -1, -1) + k = self_k(x_mid) + v = self_v(x_mid) + z = F.scaled_dot_product_attention(G, k, v, scale=scale) + return F.scaled_dot_product_attention(k, G, z, scale=scale) + + class FLARE(nn.Module): r"""FLARE: Fast Low-rank Attention Routing Engine attention layer. Adopted: @@ -137,35 +175,17 @@ def forward(self, x: Float[torch.Tensor, "B N C"]) -> Float[torch.Tensor, "B N C Output tensor of shape :math:`(B, N, C)`, same shape as inputs. 
""" - x_mid = self.in_project_x(x) - x_mid = rearrange( - x_mid, "B N (h d) -> B N h d", h=self.heads, d=self.dim_head + x_mid = _project_input( + x, self.in_project_x, self.heads, self.dim_head, + "B N (H D) -> B N H D", ) - x_mid = x_mid.permute(0, 2, 1, 3) # [B, H, N, D] - G = self.q_global.to(dtype=x_mid.dtype).expand(x_mid.shape[0], -1, -1, -1) - k = self.self_k(x_mid) - v = self.self_v(x_mid) + x_mid = x_mid.permute(0, 2, 1, 3) # (B, N, H, D) -> (B, H, N, D) - # FLARE: Fast Low-rank Attention Routing Engine - if self.use_te: - # Transformer Engine expects (B, S, H, D) format - G = rearrange(G, "b h s d -> b s h d") - k = rearrange(k, "b h s d -> b s h d") - v = rearrange(v, "b h s d -> b s h d") - z = self.attn_fn(G, k, v) - z = rearrange( - z, "b s (h d) -> b s h d", h=self.heads, d=self.dim_head - ) - self_attention = self.attn_fn(k, G, z) - y = rearrange( - self_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head - ) - else: - # Use PyTorch's scaled dot-product attention - z = F.scaled_dot_product_attention(G, k, v, scale=self.scale) - y = F.scaled_dot_product_attention(k, G, z, scale=self.scale) + y = _flare_self_attention( + x_mid, self.q_global, self.self_k, self.self_v, self.scale, + ) - out_x = y.permute(0, 2, 1, 3) # [B, N, H, D] + out_x = y.permute(0, 2, 1, 3) # (B, H, N, D) -> (B, N, H, D) out_x = rearrange(out_x, "b n h d -> b n (h d)") out_x = self.out_linear(out_x) return self.out_dropout(out_x) diff --git a/physicsnemo/nn/module/physics_attention.py b/physicsnemo/nn/module/physics_attention.py index 2110a91331..22e0ff3ba8 100644 --- a/physicsnemo/nn/module/physics_attention.py +++ b/physicsnemo/nn/module/physics_attention.py @@ -35,7 +35,6 @@ SOFTWARE. 
""" -import importlib from abc import ABC, abstractmethod import torch @@ -45,19 +44,165 @@ from torch.autograd.profiler import record_function from torch.distributed.tensor.placement_types import Replicate -from physicsnemo.core.version_check import check_version_spec +from physicsnemo.core.version_check import OptionalImport from physicsnemo.nn import gumbel_softmax # Note: We use duck typing to check for ShardTensor instead of importing it # directly to avoid circular imports (domain_parallel imports from nn). # ShardTensor has a `redistribute` method that we check for. -TE_AVAILABLE = check_version_spec("transformer_engine", hard_fail=False) +te = OptionalImport("transformer_engine.pytorch") + + +def _project_input( + x: torch.Tensor, + project_x: nn.Module, + heads: int, + dim_head: int, + pattern: str, + project_fx: nn.Module | None = None, +) -> ( + Float[torch.Tensor, "B N H D"] + | tuple[ + Float[torch.Tensor, "B N H D"], + Float[torch.Tensor, "B N H D"], + ] +): + r"""Project input through one or two learned layers and rearrange to multi-head format. + + Universal building block for the ``(B, N, C) -> (B, N, H, D)`` projection + used by low attention mechanism in the repo: slice-based (Transolver / + GALE), FLARE, and future linear-attention variants. + + For structured grids the caller reshapes to spatial layout *before* calling + this function and passes the appropriate einops ``pattern``; the output is + always ``(B, N, H, D)`` with tokens in position 1, preserving + domain-parallel sharding over N. -if TE_AVAILABLE: - te = importlib.import_module("transformer_engine.pytorch") -else: - te = None + Parameters + ---------- + x : torch.Tensor + Input tensor. For irregular meshes this is ``(B, N, C)``; for + structured grids it has already been reshaped to conv-friendly layout + (e.g. ``(B, C, H_s, W_s)``). + project_x : nn.Module + Primary projection layer (``nn.Linear``, ``nn.Conv2d``, etc.). + heads : int + Number of attention heads ``H``. 
+ dim_head : int + Dimension per head ``D``. + pattern : str + Einops rearrange pattern that maps the projection output to + ``(B, N, H, D)``. Examples: + + - ``"B N (H D) -> B N H D"`` for linear projection + - ``"B (H D) h w -> B (h w) H D"`` for 2-D convolution + - ``"B (H D) h w d -> B (h w d) H D"`` for 3-D convolution + project_fx : nn.Module or None, optional + If provided, a second projection is applied to ``x`` and the function + returns a ``(px, pfx)`` tuple. Omit for single-projection callers + (Transolver++ ``plus=True``, FLARE, etc.). Default is ``None``. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + Single tensor of shape :math:`(B, N, H, D)` when ``project_fx`` is + ``None``; otherwise a tuple ``(px, pfx)`` both of shape + :math:`(B, N, H, D)`. + """ + px = rearrange(project_x(x), pattern, H=heads, D=dim_head) + if project_fx is None: + return px + return px, rearrange(project_fx(x), pattern, H=heads, D=dim_head) + + +def _compute_slices_from_projections( + slice_projections: Float[torch.Tensor, "B N H S"], + fx: Float[torch.Tensor, "B N H D"], + temperature: torch.Tensor, + plus: bool, + proj_temperature: nn.Module | None = None, +) -> tuple[ + Float[torch.Tensor, "B N H S"], + Float[torch.Tensor, "B H S D"], +]: + r"""Compute slice weights and slice tokens from input projections. + + Standalone implementation of the temperature-scaled softmax slice + aggregation used by :class:`PhysicsAttentionBase` and reusable by any + module that needs the same project-to-slices-then-aggregate pattern + (e.g. :class:`~physicsnemo.experimental.models.geotransolver.context_projector.ContextProjector`). + + In domain-parallel settings, this performs an implicit allreduce when + summing over the sharded token dimension. + + Parameters + ---------- + slice_projections : torch.Tensor + Projected input of shape :math:`(B, N, H, S)` where :math:`H` is + number of attention heads and :math:`S` is number of physics slices. 
+ fx : torch.Tensor + Latent features of shape :math:`(B, N, H, D)` where :math:`D` is + dimension per head. + temperature : torch.Tensor + Scalar temperature for softmax/gumbel, shape broadcastable to + ``slice_projections`` (typically :math:`(1, 1, H, 1)`). + plus : bool + If ``True``, use Gumbel softmax with optional adaptive temperature. + proj_temperature : nn.Module or None, optional + If ``plus`` is ``True``, module mapping :math:`(B, N, H, D)` to + adaptive temperature; ignored otherwise. Default is ``None``. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + - ``slice_weights``: Shape :math:`(B, N, H, S)`, normalized weights + for each slice per token. + - ``slice_token``: Shape :math:`(B, H, S, D)`, aggregated features + per slice. + """ + # Compute temperature-scaled softmax over slices + if plus and proj_temperature is not None: + # Transolver++ uses learned per-token temperature + temp = temperature + proj_temperature(fx) + clamped_temp = torch.clamp(temp, min=0.01).to(slice_projections.dtype) + slice_weights = gumbel_softmax(slice_projections, clamped_temp) # (B, N, H, S) + else: + # Standard Transolver uses global temperature + clamped_temp = torch.clamp(temperature, min=0.5, max=5).to( + slice_projections.dtype + ) + slice_weights = nn.functional.softmax( + slice_projections / clamped_temp, dim=-1 + ) # (B, N, H, S) + + # Cast to the computation type (since the parameter is probably fp32) + slice_weights = slice_weights.to(slice_projections.dtype) + + # Computing the slice tokens is a matmul followed by a normalization. + # It can, unfortunately, overflow in reduced precision, so normalize first: + slice_norm = slice_weights.sum(1) + 1e-2 # (B, H, S) + # Sharded note: slice_norm will be a partial sum at this point. 
+ # That's because we're summing over the tokens, which are distributed + normed_weights = slice_weights / (slice_norm[:, None, :, :]) + # Normed weights has shape (B, N, H, S) + + # Sharded note: normed_weights will resolve the partial slice_norm + # and the output normed_weights will be sharded. + # fx has shape (B, N, H, D) + # This matmul needs to contract over the tokens + # This should produce an output with shape (B, H, S, D) + + # Like the weight norm, this sum is a **partial** sum since we are summing + # over the tokens + + # Aggregate features: (B, N, H, S)^T @ (B, N, H, D) -> (B, H, S, D) + slice_token = torch.matmul( + normed_weights.permute(0, 2, 3, 1), fx.permute(0, 2, 1, 3) + ) + + # Return the original weights, not the normed weights: + return slice_weights, slice_token
""" - # Compute temperature-scaled softmax over slices - if self.plus: - # Transolver++ uses learned per-token temperature - temperature = self.temperature + self.proj_temperature(fx) - clamped_temp = torch.clamp(temperature, min=0.01).to( - slice_projections.dtype - ) - slice_weights = gumbel_softmax( - slice_projections, clamped_temp - ) # (B, N, H, S) - else: - # Standard Transolver uses global temperature - clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to( - slice_projections.dtype - ) - slice_weights = nn.functional.softmax( - slice_projections / clamped_temp, dim=-1 - ) # (B, N, H, S) - - # Cast to the computation type (since the parameter is probably fp32) - slice_weights = slice_weights.to(slice_projections.dtype) - - # This does the projection of the latent space fx by the weights: - - # Computing the slice tokens is a matmul followed by a normalization. - # It can, unfortunately, overflow in reduced precision, so normalize first: - slice_norm = slice_weights.sum(1) + 1e-2 # (B, H, S) - # Sharded note: slice_norm will be a partial sum at this point. - # That's because the we're summing over the tokens, which are distributed - normed_weights = slice_weights / (slice_norm[:, None, :, :]) - # Normed weights has shape (B, N, H, S) - - # Sharded note: normed_weights will resolve the partial slice_norm - # and the output normed_weights will be sharded. 
- # fx has shape (B, N, H, D) - # This matmul needs to contract over the tokens - # This should produce an output with shape (B, H, S, D) - - # Like the weight norm, this sum is a **partial** sum since we are summing - # over the tokens - - # Aggregate features: (B, N, H, S)^T @ (B, N, H, D) -> (B, H, S, D) - slice_token = torch.matmul( - normed_weights.permute(0, 2, 3, 1), fx.permute(0, 2, 1, 3) + proj_temp = getattr(self, "proj_temperature", None) if self.plus else None + return _compute_slices_from_projections( + slice_projections, fx, self.temperature, self.plus, proj_temp ) - # Return the original weights, not the normed weights: - - return slice_weights, slice_token - def _compute_slice_attention_te( self, slice_tokens: Float[torch.Tensor, "B H S D"] ) -> Float[torch.Tensor, "B H S D"]: @@ -544,22 +647,16 @@ def project_input_onto_slices( Projected tensors of shape :math:`(B, N, H, D)` where :math:`H` is number of attention heads and :math:`D` is dimension per head. """ - # Project and reshape to multi-head format - x_mid = rearrange( - self.in_project_x(x), "B N (H D) -> B N H D", H=self.heads, D=self.dim_head + fx = None if self.plus else self.in_project_fx + return _project_input( + x, + self.in_project_x, + self.heads, + self.dim_head, + "B N (H D) -> B N H D", + project_fx=fx, ) - if self.plus: - return x_mid - else: - fx_mid = rearrange( - self.in_project_fx(x), - "B N (H D) -> B N H D", - H=self.heads, - D=self.dim_head, - ) - return x_mid, fx_mid - class PhysicsAttentionStructuredMesh2D(PhysicsAttentionBase): r""" @@ -673,27 +770,16 @@ def project_input_onto_slices( x = x.view(B, self.H, self.W, C) x = x.permute(0, 3, 1, 2) - # Apply 2D convolution and reshape to multi-head format - input_projected_x = self.in_project_x(x) - input_projected_x = rearrange( - input_projected_x, + fx = None if self.plus else self.in_project_fx + return _project_input( + x, + self.in_project_x, + self.heads, + self.dim_head, "B (H D) h w -> B (h w) H D", - 
D=self.dim_head, - H=self.heads, + project_fx=fx, ) - if self.plus: - return input_projected_x - else: - input_projected_fx = self.in_project_fx(x) - input_projected_fx = rearrange( - input_projected_fx, - "B (H D) h w -> B (h w) H D", - D=self.dim_head, - H=self.heads, - ) - return input_projected_x, input_projected_fx - class PhysicsAttentionStructuredMesh3D(PhysicsAttentionBase): r""" @@ -808,23 +894,12 @@ def project_input_onto_slices( x = x.view(B, self.H, self.W, self.D, C) x = x.permute(0, 4, 1, 2, 3) - # Apply 3D convolution and reshape to multi-head format - input_projected_x = self.in_project_x(x) - input_projected_x = rearrange( - input_projected_x, - "B (H D) height width depth -> B (height width depth) H D", - D=self.dim_head, - H=self.heads, + fx = None if self.plus else self.in_project_fx + return _project_input( + x, + self.in_project_x, + self.heads, + self.dim_head, + "B (H D) h w d -> B (h w d) H D", + project_fx=fx, ) - - if self.plus: - return input_projected_x - else: - input_projected_fx = self.in_project_fx(x) - input_projected_fx = rearrange( - input_projected_fx, - "B (H D) height width depth -> B (height width depth) H D", - D=self.dim_head, - H=self.heads, - ) - return input_projected_x, input_projected_fx From 93084ebca4b9e9426ef71606958d51dfbcfabd48 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Thu, 30 Apr 2026 17:05:49 -0500 Subject: [PATCH 4/8] Finish wrapping up some code movements. 
--- .../models/geotransolver/__init__.py | 19 +- .../models/geotransolver/gale_fa.py | 242 ------------------ test/models/geotransolver/test_gale.py | 2 +- 3 files changed, 17 insertions(+), 246 deletions(-) delete mode 100644 physicsnemo/experimental/models/geotransolver/gale_fa.py diff --git a/physicsnemo/experimental/models/geotransolver/__init__.py b/physicsnemo/experimental/models/geotransolver/__init__.py index 80cbf6a5ae..dacc89c628 100644 --- a/physicsnemo/experimental/models/geotransolver/__init__.py +++ b/physicsnemo/experimental/models/geotransolver/__init__.py @@ -25,8 +25,10 @@ Main model class combining GALE attention with geometry and global context. GALE Geometry-Aware Latent Embeddings attention layer. +GALE_FA + GALE with FLARE self-attention backend. GALE_block - Transformer block using GALE attention. + Transformer block using GALE or GALE_FA attention. ContextProjector Projects context features onto physical state slices. GlobalContextBuilder @@ -56,14 +58,25 @@ collect_concrete_dropout_losses, get_concrete_dropout_rates, ) -from .context_projector import ContextProjector, GlobalContextBuilder -from .gale import GALE, GALE_block +from .context_projector import ( + ContextProjector, + GlobalContextBuilder, + StructuredContextProjector, +) +from .gale import ( + GALE, + GALE_FA, + GALE_block, + GALEStructuredMesh2D, + GALEStructuredMesh3D, +) from .geotransolver import GeoTransolver, GeoTransolverMetaData __all__ = [ "GeoTransolver", "GeoTransolverMetaData", "GALE", + "GALE_FA", "GALE_block", "GALEStructuredMesh2D", "GALEStructuredMesh3D", diff --git a/physicsnemo/experimental/models/geotransolver/gale_fa.py b/physicsnemo/experimental/models/geotransolver/gale_fa.py deleted file mode 100644 index 145f8c80e8..0000000000 --- a/physicsnemo/experimental/models/geotransolver/gale_fa.py +++ /dev/null @@ -1,242 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""GALE_FA (Geometry-Aware Latent Embeddings with FLARE self-Attention) attention layer. - -This module provides the GALE_FA attention mechanism, -an alternative to the GALE attention mechanism of the GeoTransolver. -""" - -from __future__ import annotations - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from jaxtyping import Float - -from physicsnemo.core.version_check import OptionalImport -from physicsnemo.nn import ConcreteDropout -from physicsnemo.nn.module.physics_attention import _project_input -from physicsnemo.experimental.nn.flare_attention import _flare_self_attention -from physicsnemo.experimental.models.geotransolver.gale import ( - _gale_cross_init, - _mix_self_and_cross, -) - -te = OptionalImport("transformer_engine.pytorch", "0.1.0") - - -class GALE_FA(nn.Module): - r"""GALE_FA: Geometry-Aware Latent Embeddings with FLARE self-Attention attention layer. - Adopted: - - FLARE attention: Fast Low-rank Attention Routing Engine - paper: https://arxiv.org/abs/2508.12594 - - GeoTransolver context: - paper: https://arxiv.org/abs/2512.20399 - - GALE_FA is an alternative to the GALE attention mechanism of the GeoTransolver - It supports cross-attention with a context vector, built from geometry and global embeddings. 
- GALE_FA combines FLARE self-attention on learned physical state slices with cross-attention - to geometry-aware context, using a learnable mixing weight to blend the two. - - Parameters - ---------- - dim : int - Input dimension of the features. - heads : int, optional - Number of attention heads. Default is 8. - dim_head : int, optional - Dimension of each attention head. Default is 64. - dropout : float, optional - Dropout rate. Default is 0.0. - n_global_queries : int, optional - Number of learned global queries. Default is 64. - use_te : bool, optional - Whether to use Transformer Engine backend when available. Default is False. - context_dim : int, optional - Dimension of the context vector for cross-attention. Default is 0. - concrete_dropout : bool, optional - Whether to use learned concrete dropout instead of standard dropout. - Default is ``False``. - state_mixing_mode : str, optional - How to blend self-attention and cross-attention outputs. ``"weighted"`` uses - a learnable sigmoid-gated weighted sum. ``"concat_project"`` - concatenates the two along the head dimension and projects back with a - linear layer. Default is ``"weighted"``. - - Forward - ------- - x : tuple[torch.Tensor, ...] - Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` is - batch size, :math:`N` is number of tokens, and :math:`C` is number of channels. - context : tuple[torch.Tensor, ...] | None, optional - Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` where - :math:`H` is number of heads, :math:`S_c` is number of context slices, and - :math:`D_c` is context dimension. If ``None``, only self-attention is applied. - Default is ``None``. - - Outputs - ------- - list[torch.Tensor] - List of output tensors, each of shape :math:`(B, N, C)`, same shape as inputs. 
- - Notes - ----- - The mixing between self-attention and cross-attention is controlled by a learnable - parameter ``state_mixing`` which is passed through a sigmoid function to ensure - the mixing weight stays in :math:`[0, 1]`. - - See Also - -------- - :class:`GALE` : Original GeoTransolver GALE attention class. - :class:`GALE_block` : Transformer block that calls GALE or GALE_FA attention. - - Examples - -------- - >>> import torch - >>> gale_fa = GALE_FA(dim=256, heads=8, dim_head=32, context_dim=32) - >>> x = (torch.randn(2, 100, 256),) # Single input tensor in tuple - >>> context = torch.randn(2, 8, 64, 32) # Context for cross-attention - >>> outputs = gale_fa(x, context) - >>> len(outputs) - 1 - >>> outputs[0].shape - torch.Size([2, 100, 256]) - """ - - def __init__( - self, - dim, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - n_global_queries: int = 64, - use_te: bool = True, - context_dim: int = 0, - concrete_dropout: bool = False, - state_mixing_mode: str = "weighted", - ): - if use_te: - raise ValueError( - "GALE_FA does not support Transformer Engine backend. " - "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention." - ) - super().__init__() - self.use_te = use_te - self.heads = heads - self.dim_head = dim_head - self.scale = 1.0 - # It is recommended by the FLARE authors to use self.scale = 1 if self.dim_head <= 8 else (self.dim_head ** -0.5) - # but we use self.scale = 1.0 because the recommended scaling is not tested yet. 
- inner_dim = dim_head * heads - - linear_layer = te.Linear if self.use_te else nn.Linear - - # Global queries for FLARE self-attention - self.q_global = nn.Parameter(torch.randn(1, heads, n_global_queries, dim_head)) - - # Linear projections for self-attention - self.in_project_x = linear_layer(dim, inner_dim) - self.self_k = linear_layer(dim_head, dim_head) - self.self_v = linear_layer(dim_head, dim_head) - - if context_dim > 0: - _gale_cross_init(self, dim_head, context_dim, use_te, state_mixing_mode) - - # Linear projection for output - self.out_linear = linear_layer(inner_dim, dim) - if concrete_dropout: - self.out_dropout = ConcreteDropout( - in_features=dim, - init_p=max(dropout, 0.05), - ) - else: - self.out_dropout = nn.Dropout(dropout) - - - def forward( - self, - x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], - context: Float[torch.Tensor, "batch heads context_slices context_dim"] - | None = None, - ) -> list[Float[torch.Tensor, "batch tokens channels"]]: - r"""Forward pass of the GALE_FA module. - - Applies GALE_FA attention to the input features. - - Parameters - ---------- - x : tuple[torch.Tensor, ...] - Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` - is batch size, :math:`N` is number of tokens, and :math:`C` is number - of channels. - context : torch.Tensor | None, optional - Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` - where :math:`H` is number of heads, :math:`S_c` is number of context - slices, and :math:`D_c` is context dimension. If ``None``, only - self-attention is applied. Default is ``None``. - - Returns - ------- - list[torch.Tensor] - List of output tensors, each of shape :math:`(B, N, C)``, same shape - as inputs. 
- """ - - # Input projection: (B, N, C) -> (B, N, H, D) -> (B, H, N, D) - x_mid = [ - _project_input( - _x, self.in_project_x, self.heads, self.dim_head, - "B N (H D) -> B N H D", - ).permute(0, 2, 1, 3) - for _x in x - ] - - # FLARE self-attention per input - self_attention = [ - _flare_self_attention( - _x_mid, self.q_global, self.self_k, self.self_v, self.scale, - ) - for _x_mid in x_mid - ] - - # Cross-attention with context and state mixing - if context is not None: - q = [self.cross_q(_x_mid) for _x_mid in x_mid] - k = self.cross_k(context) - v = self.cross_v(context) - cross_attention = [ - F.scaled_dot_product_attention(_q, k, v, scale=self.scale) - for _q in q - ] - outputs = [ - _mix_self_and_cross( - sa, ca, self.state_mixing_mode, - state_mixing=getattr(self, "state_mixing", None), - concat_project=getattr(self, "concat_project", None), - ) - for sa, ca in zip(self_attention, cross_attention) - ] - else: - outputs = self_attention - - # Back to token layout: (B, H, N, D) -> (B, N, H, D) - outputs = [_y.permute(0, 2, 1, 3) for _y in outputs] - outputs = [rearrange(_out, "b n h d -> b n (h d)") for _out in outputs] - outputs = [self.out_linear(_out) for _out in outputs] - return [self.out_dropout(_out) for _out in outputs] - diff --git a/test/models/geotransolver/test_gale.py b/test/models/geotransolver/test_gale.py index 79ef54d8c5..63db8c4a93 100644 --- a/test/models/geotransolver/test_gale.py +++ b/test/models/geotransolver/test_gale.py @@ -19,9 +19,9 @@ from physicsnemo.experimental.models.geotransolver.gale import ( GALE, + GALE_FA, GALE_block, ) -from physicsnemo.experimental.models.geotransolver.gale_fa import GALE_FA # ============================================================================= # GALE (Geometry-Aware Latent Embeddings) Attention Tests From ea2aa99bbe671ef57140ae3e0dbe8d9be5990d02 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Fri, 1 May 2026 19:38:08 +0000 Subject: [PATCH 5/8] 
Update transolver darcy example to use the 2D geotransolver or flare models, optionally --- examples/cfd/darcy_transolver/config_fix.yaml | 43 ++- .../cfd/darcy_transolver/model/flare.yaml | 39 +++ .../darcy_transolver/model/geotransolver.yaml | 40 +++ .../model/geotransolver_fa.yaml | 41 +++ .../darcy_transolver/model/transolver.yaml | 36 +++ .../train_transolver_darcy_fix.py | 262 ++++++++++++++---- 6 files changed, 388 insertions(+), 73 deletions(-) create mode 100644 examples/cfd/darcy_transolver/model/flare.yaml create mode 100644 examples/cfd/darcy_transolver/model/geotransolver.yaml create mode 100644 examples/cfd/darcy_transolver/model/geotransolver_fa.yaml create mode 100644 examples/cfd/darcy_transolver/model/transolver.yaml diff --git a/examples/cfd/darcy_transolver/config_fix.yaml b/examples/cfd/darcy_transolver/config_fix.yaml index ee90d84855..a9c4f64f93 100644 --- a/examples/cfd/darcy_transolver/config_fix.yaml +++ b/examples/cfd/darcy_transolver/config_fix.yaml @@ -20,33 +20,31 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -output_dir: ./output/darcy_transolver_fix -run_id: bf16_dev_r85_b8_s64 +defaults: + - model: transolver # transolver | flare | geotransolver | geotransolver_fa + - _self_ + +hydra: + output_subdir: null + run: + dir: .
+ +output_dir: ./output/ +run_id: ${hydra:runtime.choices.model}-muon_${precision}_r${resolution}_b${data.batch_size}_s${model.slice_num} + +resolution: 85 #421, 211, 141, 106, 85 all viable data: - train_path: /user_data/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth1.npz - test_path: /user_data/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth2.npz - resolution: 85 #421, 211, 141, 106, 85 all viable - batch_size: 8 # This is the GLOBAL batch size + train_path: //lustre/fsw/portfolios/coreai/users/coreya/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth1.npz + test_path: //lustre/fsw/portfolios/coreai/users/coreya/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth2.npz + resolution: ${resolution} + batch_size: 64 # This is the GLOBAL batch size -model: - functional_dim: 1 - out_dim: 1 - embedding_dim: 2 - n_layers: 4 - n_hidden: 128 - dropout: 0.0 - n_head: 4 - act: gelu - mlp_ratio: 4 - unified_pos: False - ref: 8 - slice_num: 64 - use_te: False - time_input: False - precision: bf16 +optimizer: + type: muon # adamw | muon (muon requires PyTorch >= 2.9) + normaliser: permeability: mean: 1.25 @@ -69,3 +67,4 @@ training: validation: sample_size: 200 validation_pseudo_epochs: 1 + save_plots: false diff --git a/examples/cfd/darcy_transolver/model/flare.yaml b/examples/cfd/darcy_transolver/model/flare.yaml new file mode 100644 index 0000000000..31e9c78dfd --- /dev/null +++ b/examples/cfd/darcy_transolver/model/flare.yaml @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# FLARE inherits from Transolver and replaces physics attention blocks with +# FLARE (Fast Low-rank Attention Routing Engine) blocks. +# use_te and plus are forced off internally by FLARE. + +_target_: physicsnemo.experimental.models.flare.FLARE +_convert_: all + +functional_dim: 1 +out_dim: 1 +embedding_dim: 2 +n_layers: 4 +n_hidden: 128 +dropout: 0.0 +n_head: 4 +act: gelu +mlp_ratio: 4 +slice_num: 64 +unified_pos: False +ref: 8 +structured_shape: + - ${resolution} + - ${resolution} +time_input: False diff --git a/examples/cfd/darcy_transolver/model/geotransolver.yaml b/examples/cfd/darcy_transolver/model/geotransolver.yaml new file mode 100644 index 0000000000..db585b8829 --- /dev/null +++ b/examples/cfd/darcy_transolver/model/geotransolver.yaml @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# GeoTransolver with GALE attention (default). 
+# Uses geometry-aware cross-attention instead of Transolver's structured grid. + +_target_: physicsnemo.experimental.models.geotransolver.GeoTransolver +_convert_: all + +functional_dim: 3 +out_dim: 1 +geometry_dim: 3 +n_layers: 4 +n_hidden: 128 +dropout: 0.0 +n_head: 4 +act: gelu +mlp_ratio: 4 +slice_num: 64 +use_te: False +time_input: False +attention_type: GALE +include_local_features: False +state_mixing_mode: weighted +structured_shape: + - ${resolution} + - ${resolution} diff --git a/examples/cfd/darcy_transolver/model/geotransolver_fa.yaml b/examples/cfd/darcy_transolver/model/geotransolver_fa.yaml new file mode 100644 index 0000000000..305a411489 --- /dev/null +++ b/examples/cfd/darcy_transolver/model/geotransolver_fa.yaml @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# GeoTransolver with GALE_FA (FLARE attention backend). +# Combines FLARE-style global-query self-attention with GALE geometry/context +# cross-attention. use_te must be False for GALE_FA. 
+ +_target_: physicsnemo.experimental.models.geotransolver.GeoTransolver +_convert_: all + +functional_dim: 3 +out_dim: 1 +geometry_dim: 3 +n_layers: 4 +n_hidden: 128 +dropout: 0.0 +n_head: 4 +act: gelu +mlp_ratio: 4 +slice_num: 64 +use_te: False +time_input: False +attention_type: GALE_FA +include_local_features: False +state_mixing_mode: concat_project +structured_shape: + - ${resolution} + - ${resolution} diff --git a/examples/cfd/darcy_transolver/model/transolver.yaml b/examples/cfd/darcy_transolver/model/transolver.yaml new file mode 100644 index 0000000000..6e01da1b3f --- /dev/null +++ b/examples/cfd/darcy_transolver/model/transolver.yaml @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +_target_: physicsnemo.models.transolver.Transolver +_convert_: all + +functional_dim: 1 +out_dim: 1 +embedding_dim: 2 +n_layers: 4 +n_hidden: 128 +dropout: 0.0 +n_head: 4 +act: gelu +mlp_ratio: 4 +slice_num: 64 +unified_pos: False +ref: 8 +structured_shape: + - ${resolution} + - ${resolution} +use_te: False +time_input: False diff --git a/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py b/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py index 0f528f6205..888a99360c 100644 --- a/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py +++ b/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py @@ -16,9 +16,12 @@ # Configuration imports: import hydra +from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf import json +import os import time +from datetime import datetime, timezone from math import ceil # Base PyTorch imports: @@ -38,8 +41,8 @@ from utils.testloss import TestLoss # Model imports from PhysicsNeMo -from physicsnemo.models.transolver import Transolver from physicsnemo.distributed import DistributedManager +from physicsnemo.optim import CombinedOptimizer from physicsnemo.utils import load_checkpoint, save_checkpoint from physicsnemo.utils.logging import PythonLogger, RankZeroLoggingWrapper @@ -54,8 +57,110 @@ prof = Profiler() +_GEOTRANSOLVER_TARGETS = { + "physicsnemo.experimental.models.geotransolver.GeoTransolver", +} + + +def make_model_forward(cfg: DictConfig) -> callable: + """ + Return a forward callable that uses the right keyword arguments for the + configured model. + + GeoTransolver uses (local_embedding, geometry) while Transolver/FLARE + use (fx, embedding). The decision is made once at startup from the Hydra + config, avoiding fragile isinstance checks through DDP/compile wrappers. + + Args: + cfg (DictConfig): Full Hydra config (reads model._target_). 
+ + Returns: + callable: ``fn(model, pos, x) -> Tensor`` + """ + if cfg.model._target_ in _GEOTRANSOLVER_TARGETS: + + def _forward(model, pos, x): + combined_inputs = torch.cat([pos, x.unsqueeze(-1)], dim=-1) + return model(local_embedding=combined_inputs, geometry=combined_inputs).squeeze(-1) + + else: + + def _forward(model, pos, x): + return model(embedding=pos, fx=x.unsqueeze(-1)).squeeze(-1) + + return _forward + + +def build_optimizer( + model: torch.nn.Module, + cfg: DictConfig, +) -> torch.optim.Optimizer: + """ + Build optimizer based on config. Supports AdamW and Muon. + + Muon is applied to 2D weight matrices; remaining parameters (biases, norms, + embeddings) are handled by AdamW. When both groups exist they are wrapped in + ``CombinedOptimizer``. + + Args: + model (torch.nn.Module): The model (possibly DDP-wrapped). + cfg (DictConfig): Full Hydra config (reads optimizer.type, scheduler.initial_lr, + scheduler.weight_decay). + + Returns: + torch.optim.Optimizer: The configured optimizer. + """ + opt_type = cfg.optimizer.type + lr = cfg.scheduler.initial_lr + weight_decay = cfg.scheduler.weight_decay + + if opt_type == "adamw": + return AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + + if opt_type == "muon": + if not hasattr(torch.optim, "Muon"): + raise ImportError( + "Muon optimizer requires PyTorch >= 2.9. " + "Install a newer PyTorch or use optimizer.type=adamw." 
+ ) + base_model = model.module if hasattr(model, "module") else model + muon_params = [p for p in base_model.parameters() if p.ndim == 2] + other_params = [p for p in base_model.parameters() if p.ndim != 2] + + if muon_params and other_params: + return CombinedOptimizer( + [ + torch.optim.Muon( + muon_params, + lr=lr, + weight_decay=weight_decay, + adjust_lr_fn="match_rms_adamw", + ), + AdamW( + other_params, + lr=lr, + weight_decay=weight_decay, + betas=(0.9, 0.999), + eps=1.0e-8, + ), + ] + ) + elif muon_params: + return torch.optim.Muon( + muon_params, + lr=lr, + weight_decay=weight_decay, + adjust_lr_fn="match_rms_adamw", + ) + else: + return AdamW(other_params, lr=lr, weight_decay=weight_decay) + + raise ValueError(f"Unsupported optimizer type: {opt_type!r}. Use 'adamw' or 'muon'.") + + def forward_train_full_loop( model: torch.nn.Module, + model_forward: callable, loss_fun: callable, optimizer: torch.optim.Optimizer, pos: torch.Tensor, @@ -70,6 +175,7 @@ def forward_train_full_loop( Args: model (torch.nn.Module): The model to train. + model_forward (callable): Forward callable from ``make_model_forward``. loss_fun (callable): Loss function. optimizer (torch.optim.Optimizer): Optimizer. pos (torch.Tensor): Position tensor (embedding). @@ -84,7 +190,7 @@ def forward_train_full_loop( """ dm = DistributedManager() with precision_context: - pred = model(embedding=pos, fx=x.unsqueeze(-1)).squeeze(-1) + pred = model_forward(model, pos, x) pred = y_normalizer.decode(pred) loss = loss_fun(pred, y) if scaler is not None: @@ -100,6 +206,7 @@ def forward_train_full_loop( def train_epoch( model: torch.nn.Module, + model_forward: callable, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler._LRScheduler, train_dataloader: DataLoader, @@ -113,6 +220,7 @@ def train_epoch( Args: model (torch.nn.Module): The model to train. + model_forward (callable): Forward callable from ``make_model_forward``. optimizer (torch.optim.Optimizer): Optimizer. 
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler. train_dataloader (DataLoader): Training data loader. @@ -128,6 +236,7 @@ def train_epoch( pos, x, y = batch loss = forward_train_full_loop( model, + model_forward, loss_fun, optimizer, pos, @@ -150,6 +259,7 @@ def train_epoch( def val_epoch( model: torch.nn.Module, + model_forward: callable, test_dataloader: DataLoader, loss_fun: callable, y_normalizer, @@ -159,6 +269,7 @@ def val_epoch( Args: model (torch.nn.Module): The model to validate. + model_forward (callable): Forward callable from ``make_model_forward``. test_dataloader (DataLoader): Validation data loader. loss_fun (callable): Loss function. y_normalizer: Normalizer for the target tensor. @@ -175,7 +286,7 @@ def val_epoch( for i, batch in enumerate(test_dataloader): pos, x, y = batch with torch.no_grad(): - pred = model(embedding=pos, fx=x.unsqueeze(-1)).squeeze(-1) + pred = model_forward(model, pos, x) pred = y_normalizer.decode(pred) loss = loss_fun(pred, y) @@ -227,12 +338,20 @@ def darcy_trainer(cfg: DictConfig) -> None: logger = RankZeroLoggingWrapper(PythonLogger(name="darcy_transolver"), dm) logger.file_logging() - # === TensorBoard SummaryWriter === - # Only rank 0 writes logs to avoid duplication in DDP - writer = None + # === TensorBoard SummaryWriters === + # Separate train/val writers so TensorBoard can overlay matching scalars + train_writer = None + val_writer = None + metrics_file = None if dm.rank == 0: log_dir = f"{cfg.output_dir}/runs/{cfg.run_id}" - writer = SummaryWriter(log_dir=log_dir) + train_writer = SummaryWriter(log_dir=f"{log_dir}/train") + val_writer = SummaryWriter(log_dir=f"{log_dir}/val") + + # === JSONL metrics log (append-safe for resumed runs) === + metrics_path = os.path.join(log_dir, "metrics.jsonl") + os.makedirs(log_dir, exist_ok=True) + metrics_file = open(metrics_path, "a") ######################################################################## # Print the configuration to log @@ -242,23 
+361,8 @@ def darcy_trainer(cfg: DictConfig) -> None: ######################################################################## # define model ######################################################################## - model = Transolver( - functional_dim=cfg.model.functional_dim, - out_dim=cfg.model.out_dim, - embedding_dim=cfg.model.embedding_dim, - n_layers=cfg.model.n_layers, - n_hidden=cfg.model.n_hidden, - dropout=cfg.model.dropout, - n_head=cfg.model.n_head, - act=cfg.model.act, - mlp_ratio=cfg.model.mlp_ratio, - slice_num=cfg.model.slice_num, - unified_pos=cfg.model.unified_pos, - ref=cfg.model.ref, - structured_shape=[cfg.data.resolution, cfg.data.resolution], - use_te=cfg.model.use_te, - time_input=cfg.model.time_input, - ).to(dm.device) + model = instantiate(cfg.model).to(dm.device) + model_forward = make_model_forward(cfg) logger.info(f"\n{torchinfo.summary(model, verbose=0)}") @@ -269,11 +373,7 @@ def darcy_trainer(cfg: DictConfig) -> None: # define loss and optimizer ######################################################################## loss_fun = TestLoss(size_average=True) - optimizer = AdamW( - model.parameters(), - lr=cfg.scheduler.initial_lr, - weight_decay=cfg.scheduler.weight_decay, - ) + optimizer = build_optimizer(model, cfg) ######################################################################## # Create the data pipes and samplers @@ -323,12 +423,30 @@ def darcy_trainer(cfg: DictConfig) -> None: cfg.training.pseudo_epoch_sample_size / cfg.data.batch_size ) - scheduler = lr_scheduler.OneCycleLR( - optimizer, - max_lr=cfg.scheduler.initial_lr, - steps_per_epoch=steps_per_pseudo_epoch, - epochs=cfg.training.max_pseudo_epochs, - ) + total_steps = steps_per_pseudo_epoch * cfg.training.max_pseudo_epochs + if cfg.optimizer.type == "muon": + warmup_steps = steps_per_pseudo_epoch * 2 + scheduler = lr_scheduler.SequentialLR( + optimizer, + schedulers=[ + lr_scheduler.LinearLR( + optimizer, start_factor=1e-2, total_iters=warmup_steps + ), + 
lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=total_steps - warmup_steps, + eta_min=cfg.scheduler.initial_lr * 0.1, + ), + ], + milestones=[warmup_steps], + ) + else: + scheduler = lr_scheduler.OneCycleLR( + optimizer, + max_lr=cfg.scheduler.initial_lr, + steps_per_epoch=steps_per_pseudo_epoch, + epochs=cfg.training.max_pseudo_epochs, + ) validator = GridValidator(output_dir=f"{cfg.output_dir}/runs/{cfg.run_id}/plots") @@ -340,6 +458,9 @@ def darcy_trainer(cfg: DictConfig) -> None: } loaded_pseudo_epoch = load_checkpoint(device=dm.device, **ckpt_args) + # Compile after checkpoint loading to avoid triggering recompilation + model = torch.compile(model) + validation_iters = ceil(cfg.validation.sample_size / cfg.data.batch_size) if cfg.training.pseudo_epoch_sample_size % cfg.data.batch_size != 0: @@ -381,6 +502,7 @@ def darcy_trainer(cfg: DictConfig) -> None: train_start = time.time() loss = train_epoch( model, + model_forward, optimizer, scheduler, train_dataloader, @@ -406,7 +528,7 @@ def darcy_trainer(cfg: DictConfig) -> None: logger.info(log_string) # --- TensorBoard logging (only on rank 0) --- - if dm.rank == 0 and writer is not None: + if dm.rank == 0 and train_writer is not None: # Images/sec/GPU: (num images processed in train_epoch) / train_time / num_gpus # Each batch processes batch_size // world_size images, for steps_per_pseudo_epoch steps images_per_epoch = len(train_dataloader) * ( @@ -414,12 +536,28 @@ def darcy_trainer(cfg: DictConfig) -> None: ) images_per_sec_per_gpu = images_per_epoch / train_time - writer.add_scalar("loss/train", loss.item(), pseudo_epoch) - writer.add_scalar("time_per_epoch/train", train_time, pseudo_epoch) - writer.add_scalar( - "images_per_sec_per_gpu/train", images_per_sec_per_gpu, pseudo_epoch + train_writer.add_scalar("loss", loss.item(), pseudo_epoch) + train_writer.add_scalar("time_per_epoch", train_time, pseudo_epoch) + train_writer.add_scalar( + "images_per_sec_per_gpu", images_per_sec_per_gpu, pseudo_epoch + ) + 
train_writer.add_scalar("learning_rate", lr, pseudo_epoch) + + # --- JSONL metrics record (training fields) --- + metrics_record = None + if dm.rank == 0 and metrics_file is not None: + images_per_epoch = len(train_dataloader) * ( + cfg.data.batch_size // dm.world_size ) - writer.add_scalar("learning_rate/train", lr, pseudo_epoch) + metrics_record = { + "pseudo_epoch": pseudo_epoch, + "timestamp": datetime.now(timezone.utc).isoformat(), + "train_loss": loss.item(), + "train_time_s": train_time, + "learning_rate": lr, + "images_per_sec_per_gpu": images_per_epoch / train_time, + "gpu_mem_reserved_gb": gpu_mem_reserved, + } # save checkpoint if pseudo_epoch % cfg.training.rec_results_freq == 0 and dm.rank == 0: @@ -429,7 +567,7 @@ def darcy_trainer(cfg: DictConfig) -> None: if pseudo_epoch % cfg.validation.validation_pseudo_epochs == 0: val_start = time.time() val_loss, pred, y, RL2 = val_epoch( - model, test_dataloader, loss_fun, y_normalizer + model, model_forward, test_dataloader, loss_fun, y_normalizer ) val_time = time.time() - val_start @@ -440,29 +578,51 @@ def darcy_trainer(cfg: DictConfig) -> None: logger.info(log_string) # --- TensorBoard logging (only on rank 0) --- - if dm.rank == 0 and writer is not None: + if dm.rank == 0 and val_writer is not None: # Validation images/sec/GPU val_images = validation_iters * ( cfg.data.batch_size // dm.world_size ) val_images_per_sec_per_gpu = val_images / val_time - writer.add_scalar("loss/val", val_loss.item(), pseudo_epoch) - writer.add_scalar("RL2/val", RL2.item(), pseudo_epoch) - writer.add_scalar("time_per_epoch/val", val_time, pseudo_epoch) - writer.add_scalar( - "images_per_sec_per_gpu/val", + val_writer.add_scalar("loss", val_loss.item(), pseudo_epoch) + val_writer.add_scalar("RL2", RL2.item(), pseudo_epoch) + val_writer.add_scalar("time_per_epoch", val_time, pseudo_epoch) + val_writer.add_scalar( + "images_per_sec_per_gpu", val_images_per_sec_per_gpu, pseudo_epoch, ) - if dm.rank == 0: + # --- JSONL metrics 
record (validation fields) --- + if metrics_record is not None: + val_images = validation_iters * ( + cfg.data.batch_size // dm.world_size + ) + metrics_record["val_loss"] = val_loss.item() + metrics_record["val_rl2"] = RL2.item() + metrics_record["val_time_s"] = val_time + metrics_record["val_images_per_sec_per_gpu"] = ( + val_images / val_time + ) + + if dm.rank == 0 and cfg.validation.save_plots: validator.make_plot(pred, y, pseudo_epoch, test_datapipe.s) + # --- Flush JSONL record for this pseudo-epoch --- + if metrics_record is not None: + metrics_file.write(json.dumps(metrics_record) + "\n") + metrics_file.flush() + # update learning rate # if pseudo_epoch % cfg.scheduler.decay_pseudo_epochs == 0: - if dm.rank == 0 and writer is not None: - writer.close() + if dm.rank == 0: + if train_writer is not None: + train_writer.close() + if val_writer is not None: + val_writer.close() + if metrics_file is not None: + metrics_file.close() logger.success("Training completed *yay*") From 81ddd6acdb0a1d845273b8c6d6b6d66534053894 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Fri, 1 May 2026 19:40:27 +0000 Subject: [PATCH 6/8] Update README --- examples/cfd/darcy_transolver/README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/cfd/darcy_transolver/README.md b/examples/cfd/darcy_transolver/README.md index fcfa96da19..f2b38f72db 100644 --- a/examples/cfd/darcy_transolver/README.md +++ b/examples/cfd/darcy_transolver/README.md @@ -89,6 +89,15 @@ To train Transolver on your own physics problem, modify the `dataloader` in `train_transolver_darcy.py` to use your own pre-computed data or on-the-fly solver. +#### Other models + +You can also use GeoTransolver, Flare, or GeoTransolver with flare attention +for the darcy example with the "fixed" usage. Change the `model` parameter +to update it! 
Note that the inputs to geotransolver here uses the positions +and the `x` diffusivity to both the model trunk and encoder paths. For these +problems, GeoTransolver is not yet extensively tested, and your results +may vary. + ## Additional Information More components are added for convenience. `Validators` calculate the loss From b2579d8e96361150994f8f43880240bdccb18fd3 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Fri, 1 May 2026 19:57:51 +0000 Subject: [PATCH 7/8] Add missing classes in comments --- .../experimental/models/geotransolver/__init__.py | 15 +++++++++++++++ .../models/geotransolver/context_projector.py | 2 ++ 2 files changed, 17 insertions(+) diff --git a/physicsnemo/experimental/models/geotransolver/__init__.py b/physicsnemo/experimental/models/geotransolver/__init__.py index dacc89c628..c27d9484bc 100644 --- a/physicsnemo/experimental/models/geotransolver/__init__.py +++ b/physicsnemo/experimental/models/geotransolver/__init__.py @@ -23,17 +23,32 @@ ------- GeoTransolver Main model class combining GALE attention with geometry and global context. +GeoTransolverMetaData + Data class for storing essential meta data needed for the GeoTransolver model. GALE Geometry-Aware Latent Embeddings attention layer. GALE_FA GALE with FLARE self-attention backend. GALE_block Transformer block using GALE or GALE_FA attention. +GALEStructuredMesh2D + GALE with Conv2d slice projection for 2D structured grids. +GALEStructuredMesh3D + GALE with Conv3d slice projection for 3D structured grids. ContextProjector Projects context features onto physical state slices. +StructuredContextProjector + Context projector with Conv2d/Conv3d geometry encoding on structured grids. GlobalContextBuilder Orchestrates context construction for the model. +Functions +--------- +collect_concrete_dropout_losses + Collect concrete dropout regularization losses from a model. +get_concrete_dropout_rates + Get concrete dropout rates from a model. 
+ Examples -------- Basic usage: diff --git a/physicsnemo/experimental/models/geotransolver/context_projector.py b/physicsnemo/experimental/models/geotransolver/context_projector.py index fb25b9a330..15f7d67ffc 100644 --- a/physicsnemo/experimental/models/geotransolver/context_projector.py +++ b/physicsnemo/experimental/models/geotransolver/context_projector.py @@ -23,6 +23,8 @@ ------- ContextProjector Projects context features onto physical state slices. +StructuredContextProjector + Context projector with Conv2d/Conv3d geometry encoding on structured grids. GeometricFeatureProcessor Processes geometric features at a single spatial scale using BQWarp. MultiScaleFeatureExtractor From 1350ff89204142a0ecaf0ae0ef8cf5c098ca6052 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Mon, 4 May 2026 13:55:02 -0500 Subject: [PATCH 8/8] Address review comments for geotransolver 2d/3d unification with FLARE as well --- examples/cfd/darcy_transolver/README.md | 2 +- .../darcy_transolver/train_transolver_darcy_fix.py | 12 +++++++----- .../experimental/models/geotransolver/__init__.py | 8 ++++++++ 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/examples/cfd/darcy_transolver/README.md b/examples/cfd/darcy_transolver/README.md index f2b38f72db..941eac3c96 100644 --- a/examples/cfd/darcy_transolver/README.md +++ b/examples/cfd/darcy_transolver/README.md @@ -95,7 +95,7 @@ You can also use GeoTransolver, Flare, or GeoTransolver with flare attention for the darcy example with the "fixed" usage. Change the `model` parameter to update it! Note that the inputs to geotransolver here uses the positions and the `x` diffusivity to both the model trunk and encoder paths. For these -problems, GeoTransolver is not yet extensively tested, and your results +problems, GeoTransolver is not yet extensively tested, and your results may vary. 
## Additional Information diff --git a/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py b/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py index 888a99360c..8c8b38a4e7 100644 --- a/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py +++ b/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py @@ -81,7 +81,9 @@ def make_model_forward(cfg: DictConfig) -> callable: def _forward(model, pos, x): combined_inputs = torch.cat([pos, x.unsqueeze(-1)], dim=-1) - return model(local_embedding=combined_inputs, geometry=combined_inputs).squeeze(-1) + return model( + local_embedding=combined_inputs, geometry=combined_inputs + ).squeeze(-1) else: @@ -155,7 +157,9 @@ def build_optimizer( else: return AdamW(other_params, lr=lr, weight_decay=weight_decay) - raise ValueError(f"Unsupported optimizer type: {opt_type!r}. Use 'adamw' or 'muon'.") + raise ValueError( + f"Unsupported optimizer type: {opt_type!r}. Use 'adamw' or 'muon'." + ) def forward_train_full_loop( @@ -601,9 +605,7 @@ def darcy_trainer(cfg: DictConfig) -> None: metrics_record["val_loss"] = val_loss.item() metrics_record["val_rl2"] = RL2.item() metrics_record["val_time_s"] = val_time - metrics_record["val_images_per_sec_per_gpu"] = ( - val_images / val_time - ) + metrics_record["val_images_per_sec_per_gpu"] = val_images / val_time if dm.rank == 0 and cfg.validation.save_plots: validator.make_plot(pred, y, pseudo_epoch, test_datapipe.s) diff --git a/physicsnemo/experimental/models/geotransolver/__init__.py b/physicsnemo/experimental/models/geotransolver/__init__.py index c27d9484bc..3fed9686b8 100644 --- a/physicsnemo/experimental/models/geotransolver/__init__.py +++ b/physicsnemo/experimental/models/geotransolver/__init__.py @@ -39,6 +39,10 @@ Projects context features onto physical state slices. StructuredContextProjector Context projector with Conv2d/Conv3d geometry encoding on structured grids. 
+GeometricFeatureProcessor + Processes geometric features at a single spatial scale using BQWarp. +MultiScaleFeatureExtractor + Multi-scale geometric feature extraction over multiple radii. GlobalContextBuilder Orchestrates context construction for the model. @@ -75,7 +79,9 @@ ) from .context_projector import ( ContextProjector, + GeometricFeatureProcessor, GlobalContextBuilder, + MultiScaleFeatureExtractor, StructuredContextProjector, ) from .gale import ( @@ -96,7 +102,9 @@ "GALEStructuredMesh2D", "GALEStructuredMesh3D", "ContextProjector", + "GeometricFeatureProcessor", "GlobalContextBuilder", + "MultiScaleFeatureExtractor", "StructuredContextProjector", "ConcreteDropout", "collect_concrete_dropout_losses",