diff --git a/examples/cfd/darcy_transolver/README.md b/examples/cfd/darcy_transolver/README.md
index fcfa96da19..941eac3c96 100644
--- a/examples/cfd/darcy_transolver/README.md
+++ b/examples/cfd/darcy_transolver/README.md
@@ -89,6 +89,15 @@ To train Transolver on your own physics problem, modify the `dataloader` in
 `train_transolver_darcy.py` to use your own pre-computed data or on-the-fly
 solver.
 
+#### Other models
+
+You can also run this example with FLARE, GeoTransolver, or GeoTransolver
+with FLARE attention in the "fixed" Darcy setup: change the `model` config
+group to select one. Note that GeoTransolver here receives the positions and
+the `x` diffusivity, concatenated, in both the model trunk and the encoder
+paths. For these problems, GeoTransolver is not yet extensively tested, and
+your results may vary.
+
 ## Additional Information
 
 More components are added for convenience. `Validators` calculate the loss
diff --git a/examples/cfd/darcy_transolver/config_fix.yaml b/examples/cfd/darcy_transolver/config_fix.yaml
index ee90d84855..a9c4f64f93 100644
--- a/examples/cfd/darcy_transolver/config_fix.yaml
+++ b/examples/cfd/darcy_transolver/config_fix.yaml
@@ -20,33 +20,31 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-output_dir: ./output/darcy_transolver_fix
-run_id: bf16_dev_r85_b8_s64
+defaults:
+  - model: transolver # transolver | flare | geotransolver | geotransolver_fa
+  - _self_
+
+hydra:
+  output_subdir: null
+  run:
+    dir: .
+
+output_dir: ./output/
+run_id: ${hydra:runtime.choices.model}-muon_${precision}_r${resolution}_b${data.batch_size}_s${model.slice_num}
+
+resolution: 85 #421, 211, 141, 106, 85 all viable
 
 data:
-  train_path: /user_data/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth1.npz
-  test_path: /user_data/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth2.npz
-  resolution: 85 #421, 211, 141, 106, 85 all viable
-  batch_size: 8 # This is the GLOBAL batch size
+  train_path: //lustre/fsw/portfolios/coreai/users/coreya/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth1.npz
+  test_path: //lustre/fsw/portfolios/coreai/users/coreya/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth2.npz
+  resolution: ${resolution}
+  batch_size: 64 # This is the GLOBAL batch size
 
-model:
-  functional_dim: 1
-  out_dim: 1
-  embedding_dim: 2
-  n_layers: 4
-  n_hidden: 128
-  dropout: 0.0
-  n_head: 4
-  act: gelu
-  mlp_ratio: 4
-  unified_pos: False
-  ref: 8
-  slice_num: 64
-  use_te: False
-  time_input: False
-  precision: bf16
+optimizer:
+  type: muon # adamw | muon (muon requires PyTorch >= 2.9)
+
 normaliser:
   permeability:
     mean: 1.25
@@ -69,3 +67,4 @@ training:
 validation:
   sample_size: 200
   validation_pseudo_epochs: 1
+  save_plots: false
diff --git a/examples/cfd/darcy_transolver/model/flare.yaml b/examples/cfd/darcy_transolver/model/flare.yaml
new file mode 100644
index 0000000000..31e9c78dfd
--- /dev/null
+++ b/examples/cfd/darcy_transolver/model/flare.yaml
@@ -0,0 +1,39 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# FLARE inherits from Transolver and replaces physics attention blocks with +# FLARE (Fast Low-rank Attention Routing Engine) blocks. +# use_te and plus are forced off internally by FLARE. + +_target_: physicsnemo.experimental.models.flare.FLARE +_convert_: all + +functional_dim: 1 +out_dim: 1 +embedding_dim: 2 +n_layers: 4 +n_hidden: 128 +dropout: 0.0 +n_head: 4 +act: gelu +mlp_ratio: 4 +slice_num: 64 +unified_pos: False +ref: 8 +structured_shape: + - ${resolution} + - ${resolution} +time_input: False diff --git a/examples/cfd/darcy_transolver/model/geotransolver.yaml b/examples/cfd/darcy_transolver/model/geotransolver.yaml new file mode 100644 index 0000000000..db585b8829 --- /dev/null +++ b/examples/cfd/darcy_transolver/model/geotransolver.yaml @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# GeoTransolver with GALE attention (default). +# Uses geometry-aware cross-attention instead of Transolver's structured grid. + +_target_: physicsnemo.experimental.models.geotransolver.GeoTransolver +_convert_: all + +functional_dim: 3 +out_dim: 1 +geometry_dim: 3 +n_layers: 4 +n_hidden: 128 +dropout: 0.0 +n_head: 4 +act: gelu +mlp_ratio: 4 +slice_num: 64 +use_te: False +time_input: False +attention_type: GALE +include_local_features: False +state_mixing_mode: weighted +structured_shape: + - ${resolution} + - ${resolution} diff --git a/examples/cfd/darcy_transolver/model/geotransolver_fa.yaml b/examples/cfd/darcy_transolver/model/geotransolver_fa.yaml new file mode 100644 index 0000000000..305a411489 --- /dev/null +++ b/examples/cfd/darcy_transolver/model/geotransolver_fa.yaml @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# GeoTransolver with GALE_FA (FLARE attention backend). 
+# Combines FLARE-style global-query self-attention with GALE geometry/context +# cross-attention. use_te must be False for GALE_FA. + +_target_: physicsnemo.experimental.models.geotransolver.GeoTransolver +_convert_: all + +functional_dim: 3 +out_dim: 1 +geometry_dim: 3 +n_layers: 4 +n_hidden: 128 +dropout: 0.0 +n_head: 4 +act: gelu +mlp_ratio: 4 +slice_num: 64 +use_te: False +time_input: False +attention_type: GALE_FA +include_local_features: False +state_mixing_mode: concat_project +structured_shape: + - ${resolution} + - ${resolution} diff --git a/examples/cfd/darcy_transolver/model/transolver.yaml b/examples/cfd/darcy_transolver/model/transolver.yaml new file mode 100644 index 0000000000..6e01da1b3f --- /dev/null +++ b/examples/cfd/darcy_transolver/model/transolver.yaml @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +_target_: physicsnemo.models.transolver.Transolver +_convert_: all + +functional_dim: 1 +out_dim: 1 +embedding_dim: 2 +n_layers: 4 +n_hidden: 128 +dropout: 0.0 +n_head: 4 +act: gelu +mlp_ratio: 4 +slice_num: 64 +unified_pos: False +ref: 8 +structured_shape: + - ${resolution} + - ${resolution} +use_te: False +time_input: False diff --git a/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py b/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py index 0f528f6205..8c8b38a4e7 100644 --- a/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py +++ b/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py @@ -16,9 +16,12 @@ # Configuration imports: import hydra +from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf import json +import os import time +from datetime import datetime, timezone from math import ceil # Base PyTorch imports: @@ -38,8 +41,8 @@ from utils.testloss import TestLoss # Model imports from PhysicsNeMo -from physicsnemo.models.transolver import Transolver from physicsnemo.distributed import DistributedManager +from physicsnemo.optim import CombinedOptimizer from physicsnemo.utils import load_checkpoint, save_checkpoint from physicsnemo.utils.logging import PythonLogger, RankZeroLoggingWrapper @@ -54,8 +57,114 @@ prof = Profiler() +_GEOTRANSOLVER_TARGETS = { + "physicsnemo.experimental.models.geotransolver.GeoTransolver", +} + + +def make_model_forward(cfg: DictConfig) -> callable: + """ + Return a forward callable that uses the right keyword arguments for the + configured model. + + GeoTransolver uses (local_embedding, geometry) while Transolver/FLARE + use (fx, embedding). The decision is made once at startup from the Hydra + config, avoiding fragile isinstance checks through DDP/compile wrappers. + + Args: + cfg (DictConfig): Full Hydra config (reads model._target_). 
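As context for how the `model` config group is consumed, here is a minimal sketch of composing the fixed-Darcy config with a different model and optimizer through Hydra's compose API. The `config_fix` config name and the relative `config_path` are assumptions about how `train_transolver_darcy_fix.py` is launched; adjust them to your checkout.

```python
# Hedged sketch: compose the config with overrides instead of editing YAML.
# Assumes the config file is registered as "config_fix" in the current dir.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path=".", version_base=None):
    cfg = compose(
        config_name="config_fix",
        overrides=[
            "model=geotransolver_fa",  # or: transolver | flare | geotransolver
            "optimizer.type=adamw",    # fall back from muon on PyTorch < 2.9
            "resolution=106",
        ],
    )

# The training script builds the network the same way:
model = instantiate(cfg.model)
print(type(model).__name__, cfg.optimizer.type, cfg.resolution)
```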
+ + Returns: + callable: ``fn(model, pos, x) -> Tensor`` + """ + if cfg.model._target_ in _GEOTRANSOLVER_TARGETS: + + def _forward(model, pos, x): + combined_inputs = torch.cat([pos, x.unsqueeze(-1)], dim=-1) + return model( + local_embedding=combined_inputs, geometry=combined_inputs + ).squeeze(-1) + + else: + + def _forward(model, pos, x): + return model(embedding=pos, fx=x.unsqueeze(-1)).squeeze(-1) + + return _forward + + +def build_optimizer( + model: torch.nn.Module, + cfg: DictConfig, +) -> torch.optim.Optimizer: + """ + Build optimizer based on config. Supports AdamW and Muon. + + Muon is applied to 2D weight matrices; remaining parameters (biases, norms, + embeddings) are handled by AdamW. When both groups exist they are wrapped in + ``CombinedOptimizer``. + + Args: + model (torch.nn.Module): The model (possibly DDP-wrapped). + cfg (DictConfig): Full Hydra config (reads optimizer.type, scheduler.initial_lr, + scheduler.weight_decay). + + Returns: + torch.optim.Optimizer: The configured optimizer. + """ + opt_type = cfg.optimizer.type + lr = cfg.scheduler.initial_lr + weight_decay = cfg.scheduler.weight_decay + + if opt_type == "adamw": + return AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + + if opt_type == "muon": + if not hasattr(torch.optim, "Muon"): + raise ImportError( + "Muon optimizer requires PyTorch >= 2.9. " + "Install a newer PyTorch or use optimizer.type=adamw." + ) + base_model = model.module if hasattr(model, "module") else model + muon_params = [p for p in base_model.parameters() if p.ndim == 2] + other_params = [p for p in base_model.parameters() if p.ndim != 2] + + if muon_params and other_params: + return CombinedOptimizer( + [ + torch.optim.Muon( + muon_params, + lr=lr, + weight_decay=weight_decay, + adjust_lr_fn="match_rms_adamw", + ), + AdamW( + other_params, + lr=lr, + weight_decay=weight_decay, + betas=(0.9, 0.999), + eps=1.0e-8, + ), + ] + ) + elif muon_params: + return torch.optim.Muon( + muon_params, + lr=lr, + weight_decay=weight_decay, + adjust_lr_fn="match_rms_adamw", + ) + else: + return AdamW(other_params, lr=lr, weight_decay=weight_decay) + + raise ValueError( + f"Unsupported optimizer type: {opt_type!r}. Use 'adamw' or 'muon'." + ) + + def forward_train_full_loop( model: torch.nn.Module, + model_forward: callable, loss_fun: callable, optimizer: torch.optim.Optimizer, pos: torch.Tensor, @@ -70,6 +179,7 @@ def forward_train_full_loop( Args: model (torch.nn.Module): The model to train. + model_forward (callable): Forward callable from ``make_model_forward``. loss_fun (callable): Loss function. optimizer (torch.optim.Optimizer): Optimizer. pos (torch.Tensor): Position tensor (embedding). @@ -84,7 +194,7 @@ def forward_train_full_loop( """ dm = DistributedManager() with precision_context: - pred = model(embedding=pos, fx=x.unsqueeze(-1)).squeeze(-1) + pred = model_forward(model, pos, x) pred = y_normalizer.decode(pred) loss = loss_fun(pred, y) if scaler is not None: @@ -100,6 +210,7 @@ def forward_train_full_loop( def train_epoch( model: torch.nn.Module, + model_forward: callable, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler._LRScheduler, train_dataloader: DataLoader, @@ -113,6 +224,7 @@ def train_epoch( Args: model (torch.nn.Module): The model to train. + model_forward (callable): Forward callable from ``make_model_forward``. optimizer (torch.optim.Optimizer): Optimizer. scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler. train_dataloader (DataLoader): Training data loader. 
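As context for `build_optimizer`, the Muon/AdamW split above is purely by tensor rank: 2D weight matrices go to Muon, everything else to AdamW. A minimal sketch with a toy module, shown only to illustrate the partition (it avoids constructing `torch.optim.Muon` itself, which requires PyTorch >= 2.9):

```python
# Sketch of the parameter split used by build_optimizer: 2-D weights -> Muon,
# remaining tensors (biases, norms) -> AdamW. Runs on any recent PyTorch.
import torch
import torch.nn as nn

toy = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 4))

muon_params = [p for p in toy.parameters() if p.ndim == 2]
other_params = [p for p in toy.parameters() if p.ndim != 2]

print([tuple(p.shape) for p in muon_params])   # [(16, 8), (4, 16)]
print([tuple(p.shape) for p in other_params])  # [(16,), (16,), (16,), (4,)]
```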
@@ -128,6 +240,7 @@ def train_epoch( pos, x, y = batch loss = forward_train_full_loop( model, + model_forward, loss_fun, optimizer, pos, @@ -150,6 +263,7 @@ def train_epoch( def val_epoch( model: torch.nn.Module, + model_forward: callable, test_dataloader: DataLoader, loss_fun: callable, y_normalizer, @@ -159,6 +273,7 @@ def val_epoch( Args: model (torch.nn.Module): The model to validate. + model_forward (callable): Forward callable from ``make_model_forward``. test_dataloader (DataLoader): Validation data loader. loss_fun (callable): Loss function. y_normalizer: Normalizer for the target tensor. @@ -175,7 +290,7 @@ def val_epoch( for i, batch in enumerate(test_dataloader): pos, x, y = batch with torch.no_grad(): - pred = model(embedding=pos, fx=x.unsqueeze(-1)).squeeze(-1) + pred = model_forward(model, pos, x) pred = y_normalizer.decode(pred) loss = loss_fun(pred, y) @@ -227,12 +342,20 @@ def darcy_trainer(cfg: DictConfig) -> None: logger = RankZeroLoggingWrapper(PythonLogger(name="darcy_transolver"), dm) logger.file_logging() - # === TensorBoard SummaryWriter === - # Only rank 0 writes logs to avoid duplication in DDP - writer = None + # === TensorBoard SummaryWriters === + # Separate train/val writers so TensorBoard can overlay matching scalars + train_writer = None + val_writer = None + metrics_file = None if dm.rank == 0: log_dir = f"{cfg.output_dir}/runs/{cfg.run_id}" - writer = SummaryWriter(log_dir=log_dir) + train_writer = SummaryWriter(log_dir=f"{log_dir}/train") + val_writer = SummaryWriter(log_dir=f"{log_dir}/val") + + # === JSONL metrics log (append-safe for resumed runs) === + metrics_path = os.path.join(log_dir, "metrics.jsonl") + os.makedirs(log_dir, exist_ok=True) + metrics_file = open(metrics_path, "a") ######################################################################## # Print the configuration to log @@ -242,23 +365,8 @@ def darcy_trainer(cfg: DictConfig) -> None: ######################################################################## # define model ######################################################################## - model = Transolver( - functional_dim=cfg.model.functional_dim, - out_dim=cfg.model.out_dim, - embedding_dim=cfg.model.embedding_dim, - n_layers=cfg.model.n_layers, - n_hidden=cfg.model.n_hidden, - dropout=cfg.model.dropout, - n_head=cfg.model.n_head, - act=cfg.model.act, - mlp_ratio=cfg.model.mlp_ratio, - slice_num=cfg.model.slice_num, - unified_pos=cfg.model.unified_pos, - ref=cfg.model.ref, - structured_shape=[cfg.data.resolution, cfg.data.resolution], - use_te=cfg.model.use_te, - time_input=cfg.model.time_input, - ).to(dm.device) + model = instantiate(cfg.model).to(dm.device) + model_forward = make_model_forward(cfg) logger.info(f"\n{torchinfo.summary(model, verbose=0)}") @@ -269,11 +377,7 @@ def darcy_trainer(cfg: DictConfig) -> None: # define loss and optimizer ######################################################################## loss_fun = TestLoss(size_average=True) - optimizer = AdamW( - model.parameters(), - lr=cfg.scheduler.initial_lr, - weight_decay=cfg.scheduler.weight_decay, - ) + optimizer = build_optimizer(model, cfg) ######################################################################## # Create the data pipes and samplers @@ -323,12 +427,30 @@ def darcy_trainer(cfg: DictConfig) -> None: cfg.training.pseudo_epoch_sample_size / cfg.data.batch_size ) - scheduler = lr_scheduler.OneCycleLR( - optimizer, - max_lr=cfg.scheduler.initial_lr, - steps_per_epoch=steps_per_pseudo_epoch, - epochs=cfg.training.max_pseudo_epochs, - ) + 
total_steps = steps_per_pseudo_epoch * cfg.training.max_pseudo_epochs + if cfg.optimizer.type == "muon": + warmup_steps = steps_per_pseudo_epoch * 2 + scheduler = lr_scheduler.SequentialLR( + optimizer, + schedulers=[ + lr_scheduler.LinearLR( + optimizer, start_factor=1e-2, total_iters=warmup_steps + ), + lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=total_steps - warmup_steps, + eta_min=cfg.scheduler.initial_lr * 0.1, + ), + ], + milestones=[warmup_steps], + ) + else: + scheduler = lr_scheduler.OneCycleLR( + optimizer, + max_lr=cfg.scheduler.initial_lr, + steps_per_epoch=steps_per_pseudo_epoch, + epochs=cfg.training.max_pseudo_epochs, + ) validator = GridValidator(output_dir=f"{cfg.output_dir}/runs/{cfg.run_id}/plots") @@ -340,6 +462,9 @@ def darcy_trainer(cfg: DictConfig) -> None: } loaded_pseudo_epoch = load_checkpoint(device=dm.device, **ckpt_args) + # Compile after checkpoint loading to avoid triggering recompilation + model = torch.compile(model) + validation_iters = ceil(cfg.validation.sample_size / cfg.data.batch_size) if cfg.training.pseudo_epoch_sample_size % cfg.data.batch_size != 0: @@ -381,6 +506,7 @@ def darcy_trainer(cfg: DictConfig) -> None: train_start = time.time() loss = train_epoch( model, + model_forward, optimizer, scheduler, train_dataloader, @@ -406,7 +532,7 @@ def darcy_trainer(cfg: DictConfig) -> None: logger.info(log_string) # --- TensorBoard logging (only on rank 0) --- - if dm.rank == 0 and writer is not None: + if dm.rank == 0 and train_writer is not None: # Images/sec/GPU: (num images processed in train_epoch) / train_time / num_gpus # Each batch processes batch_size // world_size images, for steps_per_pseudo_epoch steps images_per_epoch = len(train_dataloader) * ( @@ -414,12 +540,28 @@ def darcy_trainer(cfg: DictConfig) -> None: ) images_per_sec_per_gpu = images_per_epoch / train_time - writer.add_scalar("loss/train", loss.item(), pseudo_epoch) - writer.add_scalar("time_per_epoch/train", train_time, pseudo_epoch) - writer.add_scalar( - "images_per_sec_per_gpu/train", images_per_sec_per_gpu, pseudo_epoch + train_writer.add_scalar("loss", loss.item(), pseudo_epoch) + train_writer.add_scalar("time_per_epoch", train_time, pseudo_epoch) + train_writer.add_scalar( + "images_per_sec_per_gpu", images_per_sec_per_gpu, pseudo_epoch + ) + train_writer.add_scalar("learning_rate", lr, pseudo_epoch) + + # --- JSONL metrics record (training fields) --- + metrics_record = None + if dm.rank == 0 and metrics_file is not None: + images_per_epoch = len(train_dataloader) * ( + cfg.data.batch_size // dm.world_size ) - writer.add_scalar("learning_rate/train", lr, pseudo_epoch) + metrics_record = { + "pseudo_epoch": pseudo_epoch, + "timestamp": datetime.now(timezone.utc).isoformat(), + "train_loss": loss.item(), + "train_time_s": train_time, + "learning_rate": lr, + "images_per_sec_per_gpu": images_per_epoch / train_time, + "gpu_mem_reserved_gb": gpu_mem_reserved, + } # save checkpoint if pseudo_epoch % cfg.training.rec_results_freq == 0 and dm.rank == 0: @@ -429,7 +571,7 @@ def darcy_trainer(cfg: DictConfig) -> None: if pseudo_epoch % cfg.validation.validation_pseudo_epochs == 0: val_start = time.time() val_loss, pred, y, RL2 = val_epoch( - model, test_dataloader, loss_fun, y_normalizer + model, model_forward, test_dataloader, loss_fun, y_normalizer ) val_time = time.time() - val_start @@ -440,29 +582,49 @@ def darcy_trainer(cfg: DictConfig) -> None: logger.info(log_string) # --- TensorBoard logging (only on rank 0) --- - if dm.rank == 0 and writer is not None: + if dm.rank 
== 0 and val_writer is not None: # Validation images/sec/GPU val_images = validation_iters * ( cfg.data.batch_size // dm.world_size ) val_images_per_sec_per_gpu = val_images / val_time - writer.add_scalar("loss/val", val_loss.item(), pseudo_epoch) - writer.add_scalar("RL2/val", RL2.item(), pseudo_epoch) - writer.add_scalar("time_per_epoch/val", val_time, pseudo_epoch) - writer.add_scalar( - "images_per_sec_per_gpu/val", + val_writer.add_scalar("loss", val_loss.item(), pseudo_epoch) + val_writer.add_scalar("RL2", RL2.item(), pseudo_epoch) + val_writer.add_scalar("time_per_epoch", val_time, pseudo_epoch) + val_writer.add_scalar( + "images_per_sec_per_gpu", val_images_per_sec_per_gpu, pseudo_epoch, ) - if dm.rank == 0: + # --- JSONL metrics record (validation fields) --- + if metrics_record is not None: + val_images = validation_iters * ( + cfg.data.batch_size // dm.world_size + ) + metrics_record["val_loss"] = val_loss.item() + metrics_record["val_rl2"] = RL2.item() + metrics_record["val_time_s"] = val_time + metrics_record["val_images_per_sec_per_gpu"] = val_images / val_time + + if dm.rank == 0 and cfg.validation.save_plots: validator.make_plot(pred, y, pseudo_epoch, test_datapipe.s) + # --- Flush JSONL record for this pseudo-epoch --- + if metrics_record is not None: + metrics_file.write(json.dumps(metrics_record) + "\n") + metrics_file.flush() + # update learning rate # if pseudo_epoch % cfg.scheduler.decay_pseudo_epochs == 0: - if dm.rank == 0 and writer is not None: - writer.close() + if dm.rank == 0: + if train_writer is not None: + train_writer.close() + if val_writer is not None: + val_writer.close() + if metrics_file is not None: + metrics_file.close() logger.success("Training completed *yay*") diff --git a/physicsnemo/experimental/models/geotransolver/__init__.py b/physicsnemo/experimental/models/geotransolver/__init__.py index 21638bdc6e..3fed9686b8 100644 --- a/physicsnemo/experimental/models/geotransolver/__init__.py +++ b/physicsnemo/experimental/models/geotransolver/__init__.py @@ -23,15 +23,36 @@ ------- GeoTransolver Main model class combining GALE attention with geometry and global context. +GeoTransolverMetaData + Data class for storing essential meta data needed for the GeoTransolver model. GALE Geometry-Aware Latent Embeddings attention layer. +GALE_FA + GALE with FLARE self-attention backend. GALE_block - Transformer block using GALE attention. + Transformer block using GALE or GALE_FA attention. +GALEStructuredMesh2D + GALE with Conv2d slice projection for 2D structured grids. +GALEStructuredMesh3D + GALE with Conv3d slice projection for 3D structured grids. ContextProjector Projects context features onto physical state slices. +StructuredContextProjector + Context projector with Conv2d/Conv3d geometry encoding on structured grids. +GeometricFeatureProcessor + Processes geometric features at a single spatial scale using BQWarp. +MultiScaleFeatureExtractor + Multi-scale geometric feature extraction over multiple radii. GlobalContextBuilder Orchestrates context construction for the model. +Functions +--------- +collect_concrete_dropout_losses + Collect concrete dropout regularization losses from a model. +get_concrete_dropout_rates + Get concrete dropout rates from a model. 
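The `metrics.jsonl` file written per pseudo-epoch above can be post-processed without TensorBoard. A minimal sketch, assuming the `{output_dir}/runs/{run_id}` layout used for the writers; the run_id in the path is hypothetical:

```python
# Read back the JSONL metrics appended by the training loop. Keys mirror the
# record built on rank 0: pseudo_epoch, train_loss, learning_rate, and the
# val_* fields on epochs where validation ran.
import json

records = []
with open("./output/runs/my_run_id/metrics.jsonl") as f:  # hypothetical run_id
    for line in f:
        records.append(json.loads(line))

for rec in records:
    # "val_rl2" is only present on pseudo-epochs where validation ran.
    print(rec["pseudo_epoch"], rec["train_loss"], rec.get("val_rl2"))
```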
+ Examples -------- Basic usage: @@ -56,17 +77,35 @@ collect_concrete_dropout_losses, get_concrete_dropout_rates, ) -from .context_projector import ContextProjector, GlobalContextBuilder -from .gale import GALE, GALE_block +from .context_projector import ( + ContextProjector, + GeometricFeatureProcessor, + GlobalContextBuilder, + MultiScaleFeatureExtractor, + StructuredContextProjector, +) +from .gale import ( + GALE, + GALE_FA, + GALE_block, + GALEStructuredMesh2D, + GALEStructuredMesh3D, +) from .geotransolver import GeoTransolver, GeoTransolverMetaData __all__ = [ "GeoTransolver", "GeoTransolverMetaData", "GALE", + "GALE_FA", "GALE_block", + "GALEStructuredMesh2D", + "GALEStructuredMesh3D", "ContextProjector", + "GeometricFeatureProcessor", "GlobalContextBuilder", + "MultiScaleFeatureExtractor", + "StructuredContextProjector", "ConcreteDropout", "collect_concrete_dropout_losses", "get_concrete_dropout_rates", diff --git a/physicsnemo/experimental/models/geotransolver/context_projector.py b/physicsnemo/experimental/models/geotransolver/context_projector.py index a57b7ecfc1..15f7d67ffc 100644 --- a/physicsnemo/experimental/models/geotransolver/context_projector.py +++ b/physicsnemo/experimental/models/geotransolver/context_projector.py @@ -23,6 +23,8 @@ ------- ContextProjector Projects context features onto physical state slices. +StructuredContextProjector + Context projector with Conv2d/Conv3d geometry encoding on structured grids. GeometricFeatureProcessor Processes geometric features at a single spatial scale using BQWarp. MultiScaleFeatureExtractor @@ -39,9 +41,12 @@ from jaxtyping import Float from physicsnemo.core.version_check import check_version_spec -from physicsnemo.nn import gumbel_softmax from physicsnemo.nn import BQWarp from physicsnemo.nn import Mlp +from physicsnemo.nn.module.physics_attention import ( + _compute_slices_from_projections, + _project_input, +) from physicsnemo.nn import ConcreteDropout @@ -51,7 +56,148 @@ import transformer_engine.pytorch as te -class ContextProjector(nn.Module): +def _structured_grid_to_conv_input( + x: Float[torch.Tensor, "batch tokens channels"], + batch: int, + tokens: int, + channels: int, + ndim: int, + spatial_shape: tuple[int, ...], +) -> Float[torch.Tensor, "batch channels ..."]: + r"""Reshape flat token tensor to spatial layout for Conv2d/Conv3d. + + Converts :math:`(B, N, C)` to :math:`(B, C, H, W)` for 2D or + :math:`(B, C, H, W, D)` for 3D so that structured context projectors + can apply spatial convolutions. Validates that :math:`N` matches the + grid size. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, C)` (batch, tokens, channels). + batch : int + Batch size :math:`B`. + tokens : int + Number of tokens :math:`N` (must equal :math:`H \\times W` or + :math:`H \\times W \\times D`). + channels : int + Channel dimension :math:`C`. + ndim : int + Number of spatial dimensions; must be 2 or 3. + spatial_shape : tuple[int, ...] + :math:`(H, W)` for 2D or :math:`(H, W, D)` for 3D. + + Returns + ------- + torch.Tensor + Reshaped tensor of shape :math:`(B, C, H, W)` or + :math:`(B, C, H, W, D)` for use as conv input. + + Raises + ------ + ValueError + If ``tokens`` does not match the product of ``spatial_shape``. 
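For readers, a standalone sketch of what the 2D branch of `_structured_grid_to_conv_input` does at the Darcy resolution of 85 (the 3D branch is analogous with one extra depth dimension):

```python
# Flat tokens (B, N, C) with N = H * W are viewed back onto the grid and
# permuted into channels-first layout for Conv2d.
import torch

B, H, W, C = 2, 85, 85, 3          # resolution 85 as in the Darcy configs
x = torch.randn(B, H * W, C)       # (B, N, C) token layout

grid = x.view(B, H, W, C).permute(0, 3, 1, 2)  # (B, C, H, W) conv input
assert grid.shape == (B, C, H, W)

# The round trip back to tokens recovers the original tensor exactly.
tokens = grid.permute(0, 2, 3, 1).reshape(B, H * W, C)
assert torch.equal(tokens, x)
```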
+ """ + if ndim == 2: + H, W = spatial_shape + if tokens != H * W: + raise ValueError( + f"Expected N={H * W} tokens for 2D grid, got N={tokens}" + ) + return x.view(batch, H, W, channels).permute(0, 3, 1, 2) + H, W, D = spatial_shape + if tokens != H * W * D: + raise ValueError( + f"Expected N={H * W * D} tokens for 3D grid, got N={tokens}" + ) + return x.view(batch, H, W, D, channels).permute(0, 4, 1, 2, 3) + + +class _SliceToContextMixin: + r"""Internal mixin providing shared slice-to-context init and slice aggregation. + + Used by :class:`ContextProjector` and :class:`StructuredContextProjector` to + avoid duplicating ``in_project_slice``, ``temperature``, ``proj_temperature``, + and the call to + :func:`~physicsnemo.nn.module.physics_attention._compute_slices_from_projections`. + """ + + def _init_slice_components( + self, + dim_head: int, + slice_num: int, + heads: int, + use_te: bool, + plus: bool, + ) -> None: + r"""Initialize slice projection, temperature, and optional adaptive temperature. + + Sets ``in_project_slice``, ``temperature``, and (when ``plus`` is True) + ``proj_temperature`` on this instance. Uses Transformer Engine linear + when ``use_te`` is True and TE is available. + + Parameters + ---------- + dim_head : int + Head dimension for the slice projection input. + slice_num : int + Number of slices (output dimension of ``in_project_slice``). + heads : int + Number of heads (used for temperature shape). + use_te : bool + Whether to prefer Transformer Engine for linear layers. + plus : bool + If True, add ``proj_temperature`` for Transolver++. + """ + linear_layer = te.Linear if (use_te and TE_AVAILABLE) else nn.Linear + self.in_project_slice = linear_layer(dim_head, slice_num) + self.temperature = nn.Parameter(torch.ones([1, 1, heads, 1]) * 0.5) + if plus: + self.proj_temperature = nn.Sequential( + linear_layer(dim_head, slice_num), + nn.GELU(), + linear_layer(slice_num, 1), + nn.GELU(), + ) + + def _compute_slices( + self, + slice_projections: Float[torch.Tensor, "batch tokens heads slices"], + fx: Float[torch.Tensor, "batch tokens heads dim"], + ) -> tuple[ + Float[torch.Tensor, "batch tokens heads slices"], + Float[torch.Tensor, "batch heads slices dim"], + ]: + r"""Compute slice weights and slice tokens from projections and latent features. + + Delegates to :func:`~physicsnemo.nn.module.physics_attention._compute_slices_from_projections`, + the shared free function that also backs + :meth:`~physicsnemo.nn.module.physics_attention.PhysicsAttentionBase._compute_slices_from_projections`. + + Parameters + ---------- + slice_projections : torch.Tensor + Shape :math:`(B, N, H, S)`. + fx : torch.Tensor + Shape :math:`(B, N, H, D)`. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(slice_weights, slice_token)`` with shapes :math:`(B, N, H, S)` + and :math:`(B, H, S, D)`. + """ + proj_temp = getattr(self, "proj_temperature", None) if self.plus else None + return _compute_slices_from_projections( + slice_projections, + fx, + self.temperature, + self.plus, + proj_temperature=proj_temp, + ) + + +class ContextProjector(_SliceToContextMixin, nn.Module): r"""Projects context features onto physical state space. This context projector is conceptually similar to half of a GALE attention layer. 
@@ -138,19 +284,8 @@ def __init__( # Attention components self.softmax = nn.Softmax(dim=-1) - self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) - - # Transolver++ adaptive temperature projection - if plus: - self.proj_temperature = nn.Sequential( - linear_layer(self.dim_head, slice_num), - nn.GELU(), - linear_layer(slice_num, 1), - nn.GELU(), - ) - # Slice projection layer maps from head dimension to slice space - self.in_project_slice = linear_layer(dim_head, slice_num) + self._init_slice_components(dim_head, slice_num, heads, use_te, plus) # Concrete dropout on the output slice tokens if concrete_dropout: @@ -164,10 +299,10 @@ def __init__( def project_input_onto_slices( self, x: Float[torch.Tensor, "batch tokens channels"] ) -> ( - Float[torch.Tensor, "batch heads tokens dim"] + Float[torch.Tensor, "batch tokens heads dim"] | tuple[ - Float[torch.Tensor, "batch heads tokens dim"], - Float[torch.Tensor, "batch heads tokens dim"], + Float[torch.Tensor, "batch tokens heads dim"], + Float[torch.Tensor, "batch tokens heads dim"], ] ): r"""Project the input onto the slice space. @@ -181,95 +316,17 @@ def project_input_onto_slices( Returns ------- torch.Tensor or tuple[torch.Tensor, torch.Tensor] - If ``plus=True``, returns single tensor of shape :math:`(B, H, N, D)` where + If ``plus=True``, returns single tensor of shape :math:`(B, N, H, D)` where :math:`H` is number of heads and :math:`D` is head dimension. If ``plus=False``, - returns tuple of two tensors both of shape :math:`(B, H, N, D)`, representing + returns tuple of two tensors both of shape :math:`(B, N, H, D)`, representing the query and key projections respectively. """ - # Project input to multi-head representation: (B, N, C) -> (B, H, N, D) - projected_x = rearrange( - self.in_project_x(x), "B N (h d) -> B h N d", h=self.heads, d=self.dim_head + fx = None if self.plus else self.in_project_fx + return _project_input( + x, self.in_project_x, self.heads, self.dim_head, + "B N (H D) -> B N H D", project_fx=fx, ) - if self.plus: - # Transolver++ uses single projection for both paths - return projected_x - else: - # Standard Transolver uses separate query and key projections - feature_projection = rearrange( - self.in_project_fx(x), - "B N (h d) -> B h N d", - h=self.heads, - d=self.dim_head, - ) - return projected_x, feature_projection - - def compute_slices_from_projections( - self, - slice_projections: Float[torch.Tensor, "batch heads tokens slices"], - fx: Float[torch.Tensor, "batch heads tokens dim"], - ) -> tuple[ - Float[torch.Tensor, "batch heads tokens slices"], - Float[torch.Tensor, "batch heads slices dim"], - ]: - r"""Compute slice weights and slice tokens from input projections and latent features. - - Parameters - ---------- - slice_projections : torch.Tensor - Projected input tensor of shape :math:`(B, H, N, S)` where :math:`B` is batch size, - :math:`H` is number of heads, :math:`N` is number of tokens, and :math:`S` is number of - slices, representing the projection of each token onto each slice for each - attention head. - fx : torch.Tensor - Latent feature tensor of shape :math:`(B, H, N, D)` where :math:`D` is head dimension, - representing the learned states to be aggregated by the slice weights. - - Returns - ------- - tuple[torch.Tensor, torch.Tensor] - - ``slice_weights``: Tensor of shape :math:`(B, H, N, S)`, normalized weights for - each slice per token and head. - - ``slice_token``: Tensor of shape :math:`(B, H, S, D)`, aggregated latent features - for each slice, head, and batch. 
- - Notes - ----- - The function computes a temperature-scaled softmax over the slice projections to - obtain slice weights, then aggregates the latent features for each slice using - these weights. The aggregated features are normalized by the sum of weights for - numerical stability. - """ - # Compute temperature-adjusted softmax weights - if self.plus: - # Transolver++ uses adaptive temperature with Gumbel softmax - temperature = self.temperature + self.proj_temperature(fx) - clamped_temp = torch.clamp(temperature, min=0.01).to( - slice_projections.dtype - ) - slice_weights = gumbel_softmax(slice_projections, clamped_temp) - else: - # Standard Transolver uses fixed temperature with regular softmax - clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to( - slice_projections.dtype - ) - slice_weights = nn.functional.softmax( - slice_projections / clamped_temp, dim=-1 - ) - - # Ensure weights match the computation dtype - slice_weights = slice_weights.to(slice_projections.dtype) - - # Aggregate features by slice weights with normalization - # Normalize first to prevent overflow in reduced precision - slice_norm = slice_weights.sum(2) # Sum over tokens: (B, H, S) - normed_weights = slice_weights / (slice_norm[:, :, None, :] + 1e-2) - - # Weighted aggregation: (B, H, S, N) @ (B, H, N, D) -> (B, H, S, D) - slice_token = torch.matmul(normed_weights.transpose(2, 3), fx) - - return slice_weights, slice_token - def forward( self, x: Float[torch.Tensor, "batch tokens channels"] ) -> Float[torch.Tensor, "batch heads slices dim"]: @@ -313,11 +370,120 @@ def forward( else: projected_x, feature_projection = self.project_input_onto_slices(x) - # Project latent representations onto physical state slices: (B, H, N, D) -> (B, H, N, S) + # Project latent representations onto physical state slices: (B, N, H, D) -> (B, N, H, S) slice_projections = self.in_project_slice(projected_x) # Compute weighted aggregation of features into slice tokens - _, slice_tokens = self.compute_slices_from_projections( + _, slice_tokens = self._compute_slices( + slice_projections, feature_projection + ) + + # Apply concrete dropout to output slice tokens + if self.output_dropout is not None: + slice_tokens = self.output_dropout(slice_tokens) + + return slice_tokens + + +class StructuredContextProjector(_SliceToContextMixin, nn.Module): + r"""Context projector with Conv2d/Conv3d geometry encoding on structured grids. + + Same output interface as :class:`ContextProjector`—slice tokens + :math:`(B, H, S, D)`—but projects per-cell geometry via spatial convolutions + aligned with structured GALE attention. 
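The per-class `compute_slices_from_projections` removed above is replaced by the shared `_compute_slices_from_projections` helper. As a reference, here is a hedged sketch of the standard (non-plus) slice computation in the `(B, N, H, *)` layout the new docstrings use; the shared helper's exact implementation, and the Transolver++ Gumbel-softmax path, may differ:

```python
# Temperature-scaled softmax over slices, then a normalized weighted
# aggregation of per-token features into per-slice tokens.
import torch

B, N, Hh, S, D = 2, 100, 4, 64, 32
slice_projections = torch.randn(B, N, Hh, S)   # token -> slice logits
fx = torch.randn(B, N, Hh, D)                  # per-token latent features
temperature = torch.full((1, 1, Hh, 1), 0.5)

temp = temperature.clamp(min=0.5, max=5)
slice_weights = torch.softmax(slice_projections / temp, dim=-1)  # (B, N, H, S)

# Normalize by the per-slice mass over tokens before aggregating, as in the
# original implementation, to keep reduced-precision sums well behaved.
slice_norm = slice_weights.sum(dim=1)                             # (B, H, S)
normed = slice_weights / (slice_norm[:, None, :, :] + 1e-2)
slice_tokens = torch.einsum("bnhs,bnhd->bhsd", normed, fx)        # (B, H, S, D)
print(slice_weights.shape, slice_tokens.shape)
```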
+ """ + + def __init__( + self, + dim: int, + spatial_shape: tuple[int, ...], + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + kernel: int = 3, + use_te: bool = True, + plus: bool = False, + concrete_dropout: bool = False, + ) -> None: + super().__init__() + if len(spatial_shape) not in (2, 3): + raise ValueError( + f"StructuredContextProjector expects spatial_shape of length 2 or 3, got {spatial_shape!r}" + ) + inner_dim = dim_head * heads + self.dim_head = dim_head + self.heads = heads + self.plus = plus + self.use_te = use_te + self.spatial_shape = tuple(int(s) for s in spatial_shape) + self._nd = len(self.spatial_shape) + pad = kernel // 2 + if self._nd == 2: + H, W = self.spatial_shape + self.H, self.W = H, W + self.in_project_x = nn.Conv2d(dim, inner_dim, kernel, 1, pad) + if not plus: + self.in_project_fx = nn.Conv2d(dim, inner_dim, kernel, 1, pad) + else: + H, W, D_ = self.spatial_shape + self.H, self.W, self.D = H, W, D_ + self.in_project_x = nn.Conv3d(dim, inner_dim, kernel, 1, pad) + if not plus: + self.in_project_fx = nn.Conv3d(dim, inner_dim, kernel, 1, pad) + + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + self._init_slice_components(dim_head, slice_num, heads, use_te, plus) + + # Concrete dropout on the output slice tokens + if concrete_dropout: + self.output_dropout = ConcreteDropout( + in_features=dim_head, + init_p=max(dropout, 0.05), + ) + else: + self.output_dropout = None + + def _grid_project( + self, x: Float[torch.Tensor, "batch tokens channels"] + ) -> ( + Float[torch.Tensor, "batch tokens heads dim"] + | tuple[ + Float[torch.Tensor, "batch tokens heads dim"], + Float[torch.Tensor, "batch tokens heads dim"], + ] + ): + B, N, C = x.shape + grid = _structured_grid_to_conv_input( + x, B, N, C, self._nd, self.spatial_shape + ) + pattern = ( + "B (H D) h w -> B (h w) H D" + if self._nd == 2 + else "B (H D) h w d -> B (h w d) H D" + ) + fx = None if self.plus else self.in_project_fx + return _project_input( + grid, self.in_project_x, self.heads, self.dim_head, + pattern, project_fx=fx, + ) + + def forward( + self, x: Float[torch.Tensor, "batch tokens channels"] + ) -> Float[torch.Tensor, "batch heads slices dim"]: + if not torch.compiler.is_compiling(): + if x.ndim != 3: + raise ValueError( + f"Expected 3D input (B, N, C), got {x.ndim}D shape {tuple(x.shape)}" + ) + if self.plus: + projected_x = self._grid_project(x) + feature_projection = projected_x + else: + projected_x, feature_projection = self._grid_project(x) + slice_projections = self.in_project_slice(projected_x) + _, slice_tokens = self._compute_slices( slice_projections, feature_projection ) @@ -639,6 +805,10 @@ class GlobalContextBuilder(nn.Module): Whether to use Transolver++ features. Default is ``False``. include_local_features : bool, optional Enable local feature extraction. Default is ``False``. + structured_shape : tuple[int, ...] | None, optional + If set, disables ball-query extractors and uses + :class:`StructuredContextProjector` for geometry when ``geometry_dim`` + is set. Default is ``None``. Forward ------- @@ -687,6 +857,7 @@ def __init__( use_te: bool = True, plus: bool = False, include_local_features: bool = False, + structured_shape: tuple[int, ...] 
| None = None, concrete_dropout: bool = False, ) -> None: super().__init__() @@ -699,9 +870,17 @@ def __init__( dim_head = n_hidden // n_head context_dim = 0 + self.structured_shape = structured_shape + + # Ball-query local features are not used on structured grids + use_local_bq = ( + geometry_dim is not None + and include_local_features + and structured_shape is None + ) # Multi-scale extractors for local features (one per functional dim) - if geometry_dim is not None and include_local_features: + if use_local_bq: self.local_extractors = nn.ModuleList( [ MultiScaleFeatureExtractor( @@ -726,16 +905,23 @@ def __init__( # Geometry tokenizer for global geometry context if geometry_dim is not None: - self.geometry_tokenizer = ContextProjector( - geometry_dim, - n_head, - dim_head, - dropout, - slice_num, - use_te, - plus, - concrete_dropout=concrete_dropout, - ) + if structured_shape is not None: + self.geometry_tokenizer = StructuredContextProjector( + geometry_dim, + structured_shape, + n_head, + dim_head, + dropout, + slice_num, + use_te=use_te, + plus=plus, + concrete_dropout=concrete_dropout, + ) + else: + self.geometry_tokenizer = ContextProjector( + geometry_dim, n_head, dim_head, dropout, slice_num, use_te, plus=plus, + concrete_dropout=concrete_dropout, + ) context_dim += dim_head else: self.geometry_tokenizer = None diff --git a/physicsnemo/experimental/models/geotransolver/gale.py b/physicsnemo/experimental/models/geotransolver/gale.py index e87eed41b2..8074c958eb 100644 --- a/physicsnemo/experimental/models/geotransolver/gale.py +++ b/physicsnemo/experimental/models/geotransolver/gale.py @@ -25,6 +25,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from einops import rearrange from jaxtyping import Float @@ -33,9 +34,12 @@ from physicsnemo.nn import Mlp from physicsnemo.nn.module.physics_attention import ( PhysicsAttentionIrregularMesh, + PhysicsAttentionStructuredMesh2D, + PhysicsAttentionStructuredMesh3D, + _project_input, ) +from physicsnemo.experimental.nn.flare_attention import _flare_self_attention -from physicsnemo.experimental.models.geotransolver.gale_fa import GALE_FA from physicsnemo.nn import ConcreteDropout # Check optional dependency availability @@ -43,6 +47,191 @@ te = OptionalImport("transformer_engine.pytorch", "0.1.0") +def _mix_self_and_cross( + self_attn: torch.Tensor, + cross_attn: torch.Tensor, + mode: str, + state_mixing: nn.Parameter | None = None, + concat_project: nn.Module | None = None, +) -> torch.Tensor: + r"""Blend self-attention and cross-attention outputs. + + Parameters + ---------- + self_attn : torch.Tensor + Self-attention output. + cross_attn : torch.Tensor + Cross-attention output (same shape as ``self_attn``). + mode : str + ``"weighted"`` for sigmoid-gated sum, ``"concat_project"`` for + concatenation followed by a learned projection. + state_mixing : nn.Parameter or None + Learnable scalar for ``"weighted"`` mode. + concat_project : nn.Module or None + Projection module for ``"concat_project"`` mode. + + Returns + ------- + torch.Tensor + Blended output, same shape as inputs. 
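For reference, the `"weighted"` branch of `_mix_self_and_cross` reduces to a sigmoid-gated convex combination of the two attention outputs; a minimal sketch:

```python
# A single learnable scalar, passed through a sigmoid, gates between the
# self-attention and cross-attention slice tokens. Initialized at 0.0 so
# sigmoid(0) = 0.5 gives a balanced blend at the start of training.
import torch
import torch.nn as nn

self_attn = torch.randn(2, 4, 64, 32)   # (B, H, S, D) self-attention output
cross_attn = torch.randn(2, 4, 64, 32)  # (B, H, S, D) cross-attention output
state_mixing = nn.Parameter(torch.tensor(0.0))

w = torch.sigmoid(state_mixing)
mixed = w * self_attn + (1 - w) * cross_attn
print(mixed.shape, float(w))            # torch.Size([2, 4, 64, 32]) 0.5
```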
+ """ + match mode: + case "weighted": + w = torch.sigmoid(state_mixing) + return w * self_attn + (1 - w) * cross_attn + case "concat_project": + return concat_project(torch.cat([self_attn, cross_attn], dim=-1)) + case _: + raise ValueError(f"Invalid state_mixing_mode: {mode!r}") + + +def _gale_compute_slice_attention_cross( + module: nn.Module, + slice_tokens: list[Float[torch.Tensor, "batch heads slices dim"]], + context: Float[torch.Tensor, "batch heads context_slices context_dim"], +) -> list[Float[torch.Tensor, "batch heads slices dim"]]: + r"""Shared cross-attention between slice tokens and context. + + Used by :class:`GALE` and :class:`_GALEStructuredForwardMixin` so the + cross-attention implementation lives in one place. Projects queries from + concatenated slice tokens, keys and values from context; runs Transformer + Engine or SDPA attention; splits the result back to one tensor per input. + + Parameters + ---------- + module : nn.Module + Module with ``cross_q``, ``cross_k``, ``cross_v``, ``use_te``, + ``heads``, ``dim_head``, and (if ``use_te``) ``attn_fn``. + slice_tokens : list[torch.Tensor] + One tensor per input, each of shape :math:`(B, H, S, D)`. + context : torch.Tensor + Context tensor of shape :math:`(B, H, S_c, D_c)`. + + Returns + ------- + list[torch.Tensor] + One cross-attention output per element of ``slice_tokens``, each + of shape :math:`(B, H, S, D)`. + """ + q_input = torch.cat(slice_tokens, dim=-2) + q = module.cross_q(q_input) + k = module.cross_k(context) + v = module.cross_v(context) + if module.use_te: + q = rearrange(q, "b h s d -> b s h d") + k = rearrange(k, "b h s d -> b s h d") + v = rearrange(v, "b h s d -> b s h d") + cross_attention = module.attn_fn(q, k, v) + cross_attention = rearrange( + cross_attention, + "b s (h d) -> b h s d", + h=module.heads, + d=module.dim_head, + ) + else: + cross_attention = torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=False + ) + cross_attention = torch.split( + cross_attention, slice_tokens[0].shape[-2], dim=-2 + ) + return list(cross_attention) + + +def _gale_forward_impl( + module: nn.Module, + x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], + context: Float[torch.Tensor, "batch heads context_slices context_dim"] + | None, +) -> list[Float[torch.Tensor, "batch tokens channels"]]: + r"""Single implementation of the GALE forward pipeline. + + Shared by :class:`GALE` and :class:`_GALEStructuredForwardMixin`. Steps: + validate inputs; project onto slices; compute slice weights and tokens; + apply self-attention on slices; optionally cross-attend to context and + mix with ``state_mixing``; project attention outputs back to token space. + + Parameters + ---------- + module : nn.Module + GALE-like module with ``project_input_onto_slices``, + ``in_project_slice``, ``_compute_slices_from_projections``, + ``_compute_slice_attention_te``, ``_compute_slice_attention_sdpa``, + ``compute_slice_attention_cross``, ``_project_attention_outputs``, + plus attributes ``use_te``, ``plus``, ``state_mixing_mode``, and + ``state_mixing`` (if weighted) or ``concat_project`` (if concat). + x : tuple[torch.Tensor, ...] + Input tensors, each of shape :math:`(B, N, C)`; must be non-empty. + context : torch.Tensor or None + Optional context of shape :math:`(B, H, S_c, D_c)` for cross-attention. + If ``None``, only self-attention is applied. + + Returns + ------- + list[torch.Tensor] + One output tensor per input, each of shape :math:`(B, N, C)`. 
+ + Raises + ------ + ValueError + If ``x`` is empty or any element is not 3D. + """ + if not torch.compiler.is_compiling(): + if len(x) == 0: + raise ValueError("Expected non-empty tuple of input tensors") + for i, tensor in enumerate(x): + if tensor.ndim != 3: + raise ValueError( + f"Expected 3D input tensor (B, N, C) at index {i}, " + f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}" + ) + if module.plus: + x_mid = [module.project_input_onto_slices(_x) for _x in x] + fx_mid = [_x_mid for _x_mid in x_mid] + else: + x_mid, fx_mid = zip( + *[module.project_input_onto_slices(_x) for _x in x] + ) + slice_projections = [module.in_project_slice(_x_mid) for _x_mid in x_mid] + slice_weights, slice_tokens = zip( + *[ + module._compute_slices_from_projections(proj, _fx_mid) + for proj, _fx_mid in zip(slice_projections, fx_mid) + ] + ) + if module.use_te: + self_slice_token = [ + module._compute_slice_attention_te(_slice_token) + for _slice_token in slice_tokens + ] + else: + self_slice_token = [ + module._compute_slice_attention_sdpa(_slice_token) + for _slice_token in slice_tokens + ] + if context is not None: + cross_slice_token = [ + module.compute_slice_attention_cross([_slice_token], context)[0] + for _slice_token in slice_tokens + ] + out_slice_token = [ + _mix_self_and_cross( + sst, cst, module.state_mixing_mode, + state_mixing=getattr(module, "state_mixing", None), + concat_project=getattr(module, "concat_project", None), + ) + for sst, cst in zip(self_slice_token, cross_slice_token) + ] + else: + # Use only self-attention when no context is provided + out_slice_token = self_slice_token + outputs = [ + module._project_attention_outputs(ost, sw) + for ost, sw in zip(out_slice_token, slice_weights) + ] + return outputs + + class GALE(PhysicsAttentionIrregularMesh): r"""Geometry-Aware Latent Embeddings (GALE) attention layer. @@ -69,8 +258,10 @@ class GALE(PhysicsAttentionIrregularMesh): Whether to use Transolver++ features. Default is False. context_dim : int, optional Dimension of the context vector for cross-attention. Default is 0. + concrete_dropout : bool, optional + Whether to use ConcreteDropout instead of standard dropout. Default is False. state_mixing_mode : str, optional - How to blend self-attention and cross-attention outputs. ``"weighted"`` uses + How to blend self-attention and cross-attention outputs. ``"weighted"`` uses a learnable sigmoid-gated weighted sum. ``"concat_project"`` concatenates the two along the head dimension and projects back with a linear layer. Default is ``"weighted"``. @@ -129,32 +320,7 @@ def __init__( state_mixing_mode: str = "weighted", ) -> None: super().__init__(dim, heads, dim_head, dropout, slice_num, use_te, plus) - - if state_mixing_mode not in ("weighted", "concat_project"): - raise ValueError( - f"Invalid state_mixing_mode: {state_mixing_mode!r}. " - f"Expected 'weighted' or 'concat_project'." 
- ) - self.state_mixing_mode = state_mixing_mode - - linear_layer = te.Linear if self.use_te else nn.Linear - - # Cross-attention projection layers for context integration - self.cross_q = linear_layer(dim_head, dim_head) - self.cross_k = linear_layer(context_dim, dim_head) - self.cross_v = linear_layer(context_dim, dim_head) - - # Mixing layers for blending self-attention and cross-attention - if state_mixing_mode == "weighted": - # Learnable mixing weight between self and cross attention - # Initialize near 0.0 since sigmoid(0) = 0.5, giving balanced initial mixing - self.state_mixing = nn.Parameter(torch.tensor(0.0)) - else: - # Concatenate self and cross attention and project back to dim_head - self.concat_project = nn.Sequential( - linear_layer(2 * dim_head, dim_head), - nn.GELU(), - ) + _gale_cross_init(self, dim_head, context_dim, use_te, state_mixing_mode) # Replace inherited out_dropout with ConcreteDropout when enabled if concrete_dropout: @@ -185,39 +351,10 @@ def compute_slice_attention_cross( list[torch.Tensor] List of cross-attention outputs, each of shape :math:`(B, H, S, D)`. """ - # Concatenate all slice tokens for batched projection - q_input = torch.cat(slice_tokens, dim=-2) # (B, H, total_slices, D) - - # Project queries from slice tokens - q = self.cross_q(q_input) # (B, H, total_slices, D) - - # Project keys and values from context - k = self.cross_k(context) # (B, H, S_c, D) - v = self.cross_v(context) # (B, H, S_c, D) - - # Compute cross-attention using appropriate backend - if self.use_te: - # Transformer Engine expects (B, S, H, D) format - q = rearrange(q, "b h s d -> b s h d") - k = rearrange(k, "b h s d -> b s h d") - v = rearrange(v, "b h s d -> b s h d") - cross_attention = self.attn_fn(q, k, v) - cross_attention = rearrange( - cross_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head - ) - else: - # Use PyTorch's scaled dot-product attention - cross_attention = torch.nn.functional.scaled_dot_product_attention( - q, k, v, is_causal=False - ) - - # Split back into individual slice token outputs - cross_attention = torch.split( - cross_attention, slice_tokens[0].shape[-2], dim=-2 + return _gale_compute_slice_attention_cross( + self, slice_tokens, context ) - return list(cross_attention) - def forward( self, x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], @@ -247,80 +384,324 @@ def forward( List of output tensors, each of shape :math:`(B, N, C)``, same shape as inputs. 
""" - ### Input validation - if not torch.compiler.is_compiling(): - if len(x) == 0: - raise ValueError("Expected non-empty tuple of input tensors") - for i, tensor in enumerate(x): - if tensor.ndim != 3: - raise ValueError( - f"Expected 3D input tensor (B, N, C) at index {i}, " - f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}" - ) - - # Project inputs onto learned latent spaces - if self.plus: - x_mid = [self.project_input_onto_slices(_x) for _x in x] - # In Transolver++, x_mid is reused for both projections - fx_mid = [_x_mid for _x_mid in x_mid] - else: - x_mid, fx_mid = zip( - *[self.project_input_onto_slices(_x) for _x in x] + return _gale_forward_impl(self, x, context) + + +def _gale_cross_init( + self: nn.Module, + dim_head: int, + context_dim: int, + use_te: bool, + state_mixing_mode: str = "weighted", +) -> None: + # Match GALE: TE linear only when TE is installed (GALE_block already errors if use_te without TE) + linear_layer = te.Linear if (use_te and TE_AVAILABLE) else nn.Linear + self.cross_q = linear_layer(dim_head, dim_head) + self.cross_k = linear_layer(context_dim, dim_head) + self.cross_v = linear_layer(context_dim, dim_head) + + self.state_mixing_mode = state_mixing_mode + + match state_mixing_mode: + case "weighted": + # Learnable mixing weight between self and cross attention + # Initialize near 0.0 since sigmoid(0) = 0.5, giving balanced initial mixing + self.state_mixing = nn.Parameter(torch.tensor(0.0)) + case "concat_project": + # Concatenate self and cross attention and project back to dim_head + self.concat_project = nn.Sequential( + linear_layer(2 * dim_head, dim_head), + nn.GELU(), + ) + case _: + raise ValueError( + f"Invalid state_mixing_mode: {state_mixing_mode!r}. " + f"Expected 'weighted' or 'concat_project'." 
) - # Project latent representations onto physical state slices - slice_projections = [self.in_project_slice(_x_mid) for _x_mid in x_mid] - # Compute slice weights and aggregated slice tokens - slice_weights, slice_tokens = zip( - *[ - self._compute_slices_from_projections(proj, _fx_mid) - for proj, _fx_mid in zip(slice_projections, fx_mid) - ] +class _GALEStructuredForwardMixin: + """Shared cross-attention and forward for structured GALE (2D/3D conv projection).""" + + def compute_slice_attention_cross( + self, + slice_tokens: list[Float[torch.Tensor, "batch heads slices dim"]], + context: Float[torch.Tensor, "batch heads context_slices context_dim"], + ) -> list[Float[torch.Tensor, "batch heads slices dim"]]: + return _gale_compute_slice_attention_cross( + self, slice_tokens, context ) - # Apply self-attention to slice tokens - if self.use_te: - self_slice_token = [ - self._compute_slice_attention_te(_slice_token) - for _slice_token in slice_tokens - ] - else: - self_slice_token = [ - self._compute_slice_attention_sdpa(_slice_token) - for _slice_token in slice_tokens - ] + def forward( + self, + x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], + context: Float[torch.Tensor, "batch heads context_slices context_dim"] + | None = None, + ) -> list[Float[torch.Tensor, "batch tokens channels"]]: + return _gale_forward_impl(self, x, context) - # Apply cross-attention with context if provided - if context is not None: - cross_slice_token = [ - self.compute_slice_attention_cross([_slice_token], context)[0] - for _slice_token in slice_tokens - ] - # Blend self-attention and cross-attention - if self.state_mixing_mode == "weighted": - mixing_weight = torch.sigmoid(self.state_mixing) - out_slice_token = [ - mixing_weight * sst + (1 - mixing_weight) * cst - for sst, cst in zip(self_slice_token, cross_slice_token) - ] - else: - out_slice_token = [ - self.concat_project(torch.cat([sst, cst], dim=-1)) - for sst, cst in zip(self_slice_token, cross_slice_token) - ] +class GALEStructuredMesh2D(_GALEStructuredForwardMixin, PhysicsAttentionStructuredMesh2D): + r"""GALE with Conv2d slice projection for 2D structured grids (see :class:`GALE`).""" + + def __init__( + self, + dim: int, + spatial_shape: tuple[int, int], + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + kernel: int = 3, + use_te: bool = True, + plus: bool = False, + context_dim: int = 0, + state_mixing_mode: str = "weighted", + ) -> None: + super().__init__( + dim, + spatial_shape, + heads, + dim_head, + dropout, + slice_num, + kernel, + use_te, + plus, + ) + _gale_cross_init(self, dim_head, context_dim, use_te, state_mixing_mode) + + +class GALEStructuredMesh3D(_GALEStructuredForwardMixin, PhysicsAttentionStructuredMesh3D): + r"""GALE with Conv3d slice projection for 3D structured grids (see :class:`GALE`).""" + + def __init__( + self, + dim: int, + spatial_shape: tuple[int, int, int], + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + kernel: int = 3, + use_te: bool = True, + plus: bool = False, + context_dim: int = 0, + state_mixing_mode: str = "weighted", + ) -> None: + super().__init__( + dim, + spatial_shape, + heads, + dim_head, + dropout, + slice_num, + kernel, + use_te, + plus, + ) + _gale_cross_init(self, dim_head, context_dim, use_te, state_mixing_mode) + + +class GALE_FA(nn.Module): + r"""GALE_FA: Geometry-Aware Latent Embeddings with FLARE self-Attention attention layer. 
+
+    Adapted from:
+
+    - FLARE attention: Fast Low-rank Attention Routing Engine
+      paper: https://arxiv.org/abs/2508.12594
+    - GeoTransolver context:
+      paper: https://arxiv.org/abs/2512.20399
+
+    GALE_FA is an alternative to the GALE attention mechanism of the GeoTransolver.
+    It supports cross-attention with a context vector built from geometry and global embeddings.
+    GALE_FA combines FLARE self-attention on learned physical state slices with cross-attention
+    to geometry-aware context, using a learnable mixing weight to blend the two.
+
+    Parameters
+    ----------
+    dim : int
+        Input dimension of the features.
+    heads : int, optional
+        Number of attention heads. Default is 8.
+    dim_head : int, optional
+        Dimension of each attention head. Default is 64.
+    dropout : float, optional
+        Dropout rate. Default is 0.0.
+    n_global_queries : int, optional
+        Number of learned global queries. Default is 64.
+    use_te : bool, optional
+        Must be ``False``: GALE_FA does not support the Transformer Engine
+        backend and raises a ``ValueError`` when ``True``. Default is ``False``.
+    context_dim : int, optional
+        Dimension of the context vector for cross-attention. Default is 0.
+    concrete_dropout : bool, optional
+        Whether to use learned concrete dropout instead of standard dropout.
+        Default is ``False``.
+    state_mixing_mode : str, optional
+        How to blend self-attention and cross-attention outputs. ``"weighted"`` uses
+        a learnable sigmoid-gated weighted sum. ``"concat_project"``
+        concatenates the two along the head dimension and projects back with a
+        linear layer. Default is ``"weighted"``.
+
+    Forward
+    -------
+    x : tuple[torch.Tensor, ...]
+        Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` is
+        batch size, :math:`N` is number of tokens, and :math:`C` is number of channels.
+    context : torch.Tensor | None, optional
+        Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` where
+        :math:`H` is number of heads, :math:`S_c` is number of context slices, and
+        :math:`D_c` is context dimension. If ``None``, only self-attention is applied.
+        Default is ``None``.
+
+    Outputs
+    -------
+    list[torch.Tensor]
+        List of output tensors, each of shape :math:`(B, N, C)`, same shape as inputs.
+
+    Notes
+    -----
+    The mixing between self-attention and cross-attention is controlled by a learnable
+    parameter ``state_mixing`` which is passed through a sigmoid function to ensure
+    the mixing weight stays in :math:`[0, 1]`.
+
+    See Also
+    --------
+    :class:`GALE` : Original GeoTransolver GALE attention class.
+    :class:`GALE_block` : Transformer block that calls GALE or GALE_FA attention.
+
+    Examples
+    --------
+    >>> import torch
+    >>> gale_fa = GALE_FA(dim=256, heads=8, dim_head=32, context_dim=32)
+    >>> x = (torch.randn(2, 100, 256),)  # Single input tensor in tuple
+    >>> context = torch.randn(2, 8, 64, 32)  # Context for cross-attention
+    >>> outputs = gale_fa(x, context)
+    >>> len(outputs)
+    1
+    >>> outputs[0].shape
+    torch.Size([2, 100, 256])
+    """
+
+    def __init__(
+        self,
+        dim,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        n_global_queries: int = 64,
+        use_te: bool = False,
+        context_dim: int = 0,
+        concrete_dropout: bool = False,
+        state_mixing_mode: str = "weighted",
+    ):
+        if use_te:
+            raise ValueError(
+                "GALE_FA does not support Transformer Engine backend. "
+                "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention."
+ ) + super().__init__() + self.use_te = use_te + self.heads = heads + self.dim_head = dim_head + self.scale = 1.0 + # It is recommended by the FLARE authors to use self.scale = 1 if self.dim_head <= 8 else (self.dim_head ** -0.5) + # but we use self.scale = 1.0 because the recommended scaling is not tested yet. + inner_dim = dim_head * heads + + linear_layer = te.Linear if self.use_te else nn.Linear + + # Global queries for FLARE self-attention + self.q_global = nn.Parameter(torch.randn(1, heads, n_global_queries, dim_head)) + + # Linear projections for self-attention + self.in_project_x = linear_layer(dim, inner_dim) + self.self_k = linear_layer(dim_head, dim_head) + self.self_v = linear_layer(dim_head, dim_head) + + if context_dim > 0: + _gale_cross_init(self, dim_head, context_dim, use_te, state_mixing_mode) + + # Linear projection for output + self.out_linear = linear_layer(inner_dim, dim) + if concrete_dropout: + self.out_dropout = ConcreteDropout( + in_features=dim, + init_p=max(dropout, 0.05), + ) else: - # Use only self-attention when no context is provided - out_slice_token = self_slice_token + self.out_dropout = nn.Dropout(dropout) - # Project attention outputs back to original space using slice weights - outputs = [ - self._project_attention_outputs(ost, sw) - for ost, sw in zip(out_slice_token, slice_weights) + def forward( + self, + x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], + context: Float[torch.Tensor, "batch heads context_slices context_dim"] + | None = None, + ) -> list[Float[torch.Tensor, "batch tokens channels"]]: + r"""Forward pass of the GALE_FA module. + + Applies GALE_FA attention to the input features. + + Parameters + ---------- + x : tuple[torch.Tensor, ...] + Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` + is batch size, :math:`N` is number of tokens, and :math:`C` is number + of channels. + context : torch.Tensor | None, optional + Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` + where :math:`H` is number of heads, :math:`S_c` is number of context + slices, and :math:`D_c` is context dimension. If ``None``, only + self-attention is applied. Default is ``None``. + + Returns + ------- + list[torch.Tensor] + List of output tensors, each of shape :math:`(B, N, C)``, same shape + as inputs. 
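To make the two ``state_mixing_mode`` branches used in this forward pass concrete, here is a minimal standalone sketch of how a self-attention output and a cross-attention output are blended. Tensor sizes are illustrative and plain ``nn`` layers stand in for the module's own projections; this is not the module's code.

```python
import torch
import torch.nn as nn

B, H, N, D = 2, 4, 128, 32
self_attn = torch.randn(B, H, N, D)   # stands in for the FLARE self-attention output
cross_attn = torch.randn(B, H, N, D)  # stands in for the cross-attention output

# "weighted": one learnable scalar squashed to (0, 1) by a sigmoid.
# Initializing it at 0.0 gives sigmoid(0) = 0.5, i.e. an even blend at the start.
state_mixing = nn.Parameter(torch.tensor(0.0))
w = torch.sigmoid(state_mixing)
mixed_weighted = w * self_attn + (1 - w) * cross_attn

# "concat_project": concatenate along the head dimension and project back to D.
concat_project = nn.Sequential(nn.Linear(2 * D, D), nn.GELU())
mixed_concat = concat_project(torch.cat([self_attn, cross_attn], dim=-1))

print(mixed_weighted.shape, mixed_concat.shape)  # both torch.Size([2, 4, 128, 32])
```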
+ """ + # Input projection: (B, N, C) -> (B, N, H, D) -> (B, H, N, D) + x_mid = [ + _project_input( + _x, self.in_project_x, self.heads, self.dim_head, + "B N (H D) -> B N H D", + ).permute(0, 2, 1, 3) + for _x in x + ] + + # FLARE self-attention per input + self_attention = [ + _flare_self_attention( + _x_mid, self.q_global, self.self_k, self.self_v, self.scale, + ) + for _x_mid in x_mid ] - return outputs + # Cross-attention with context and state mixing + if context is not None: + q = [self.cross_q(_x_mid) for _x_mid in x_mid] + k = self.cross_k(context) + v = self.cross_v(context) + cross_attention = [ + F.scaled_dot_product_attention(_q, k, v, scale=self.scale) + for _q in q + ] + outputs = [ + _mix_self_and_cross( + sa, ca, self.state_mixing_mode, + state_mixing=getattr(self, "state_mixing", None), + concat_project=getattr(self, "concat_project", None), + ) + for sa, ca in zip(self_attention, cross_attention) + ] + else: + outputs = self_attention + + # Back to token layout: (B, H, N, D) -> (B, N, H, D) + outputs = [_y.permute(0, 2, 1, 3) for _y in outputs] + outputs = [rearrange(_out, "b n h d -> b n (h d)") for _out in outputs] + outputs = [self.out_linear(_out) for _out in outputs] + return [self.out_dropout(_out) for _out in outputs] class GALE_block(nn.Module): @@ -354,11 +735,16 @@ class GALE_block(nn.Module): Whether to use Transolver++ features. Default is ``False``. context_dim : int, optional Dimension of the context vector for cross-attention. Default is 0. + spatial_shape : tuple[int, ...] | None, optional + If ``None``, uses irregular-mesh GALE. Length-2 tuple enables 2D Conv2d + projection; length-3 tuple enables 3D Conv3d projection (flattened + :math:`N = H \times W` or :math:`H \times W \times D`). Default is ``None``. attention_type : str, optional - attention_type is used to choose the attention type (GALE or GALE_FA). + Attention backend to use. ``"GALE"`` uses the standard physics-aware + slice attention; ``"GALE_FA"`` uses flash-attention variant. Default is ``"GALE"``. state_mixing_mode : str, optional - How to blend self-attention and cross-attention outputs. ``"weighted"`` uses + How to blend self-attention and cross-attention outputs. ``"weighted"`` uses a learnable sigmoid-gated weighted sum. ``"concat_project"`` concatenates the two along the head dimension and projects back with a linear layer. Default is ``"weighted"``. @@ -414,6 +800,7 @@ def __init__( use_te: bool = True, plus: bool = False, context_dim: int = 0, + spatial_shape: tuple[int, ...] 
| None = None, attention_type: str = "GALE", concrete_dropout: bool = False, state_mixing_mode: str = "weighted", @@ -434,26 +821,62 @@ def __init__( else: self.ln_1 = nn.LayerNorm(hidden_dim) - # Attention layer + dim_head = hidden_dim // num_heads + # First match on attention backend, then on spatial shape match attention_type: case 'GALE': - self.Attn = GALE( - hidden_dim, - heads=num_heads, - dim_head=hidden_dim // num_heads, - dropout=dropout, - slice_num=slice_num, - use_te=use_te, - plus=plus, - context_dim=context_dim, - concrete_dropout=concrete_dropout, - state_mixing_mode=state_mixing_mode, - ) + if spatial_shape is None: + self.Attn = GALE( + hidden_dim, + heads=num_heads, + dim_head=dim_head, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + context_dim=context_dim, + concrete_dropout=concrete_dropout, + state_mixing_mode=state_mixing_mode, + ) + elif len(spatial_shape) == 2: + self.Attn = GALEStructuredMesh2D( + hidden_dim, + spatial_shape=(int(spatial_shape[0]), int(spatial_shape[1])), + heads=num_heads, + dim_head=dim_head, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + context_dim=context_dim, + state_mixing_mode=state_mixing_mode, + ) + elif len(spatial_shape) == 3: + self.Attn = GALEStructuredMesh3D( + hidden_dim, + spatial_shape=( + int(spatial_shape[0]), + int(spatial_shape[1]), + int(spatial_shape[2]), + ), + heads=num_heads, + dim_head=dim_head, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + context_dim=context_dim, + state_mixing_mode=state_mixing_mode, + ) + else: + raise ValueError( + f"spatial_shape must be None, length-2, or length-3; got {spatial_shape!r}" + ) case 'GALE_FA': self.Attn = GALE_FA( hidden_dim, heads=num_heads, - dim_head=hidden_dim // num_heads, + dim_head=dim_head, dropout=dropout, n_global_queries=slice_num, use_te=use_te, diff --git a/physicsnemo/experimental/models/geotransolver/gale_fa.py b/physicsnemo/experimental/models/geotransolver/gale_fa.py deleted file mode 100644 index eb9f8a3478..0000000000 --- a/physicsnemo/experimental/models/geotransolver/gale_fa.py +++ /dev/null @@ -1,285 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""GALE_FA (Geometry-Aware Latent Embeddings with FLARE self-Attention) attention layer. - -This module provides the GALE_FA attention mechanism, -an alternative to the GALE attention mechanism of the GeoTransolver. 
-""" - -from __future__ import annotations - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from jaxtyping import Float - -from physicsnemo.core.version_check import check_version_spec, OptionalImport -from physicsnemo.nn import ConcreteDropout - -# Check optional dependency availability -TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) -te = OptionalImport("transformer_engine.pytorch", "0.1.0") - - -class GALE_FA(nn.Module): - r"""GALE_FA: Geometry-Aware Latent Embeddings with FLARE self-Attention attention layer. - Adopted: - - FLARE attention: Fast Low-rank Attention Routing Engine - paper: https://arxiv.org/abs/2508.12594 - - GeoTransolver context: - paper: https://arxiv.org/abs/2512.20399 - - GALE_FA is an alternative to the GALE attention mechanism of the GeoTransolver - It supports cross-attention with a context vector, built from geometry and global embeddings. - GALE_FA combines FLARE self-attention on learned physical state slices with cross-attention - to geometry-aware context, using a learnable mixing weight to blend the two. - - Parameters - ---------- - dim : int - Input dimension of the features. - heads : int, optional - Number of attention heads. Default is 8. - dim_head : int, optional - Dimension of each attention head. Default is 64. - dropout : float, optional - Dropout rate. Default is 0.0. - n_global_queries : int, optional - Number of learned global queries. Default is 64. - use_te : bool, optional - Whether to use Transformer Engine backend when available. Default is False. - context_dim : int, optional - Dimension of the context vector for cross-attention. Default is 0. - concrete_dropout : bool, optional - Whether to use learned concrete dropout instead of standard dropout. - Default is ``False``. - state_mixing_mode : str, optional - How to blend self-attention and cross-attention outputs. ``"weighted"`` uses - a learnable sigmoid-gated weighted sum. ``"concat_project"`` - concatenates the two along the head dimension and projects back with a - linear layer. Default is ``"weighted"``. - - Forward - ------- - x : tuple[torch.Tensor, ...] - Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` is - batch size, :math:`N` is number of tokens, and :math:`C` is number of channels. - context : tuple[torch.Tensor, ...] | None, optional - Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` where - :math:`H` is number of heads, :math:`S_c` is number of context slices, and - :math:`D_c` is context dimension. If ``None``, only self-attention is applied. - Default is ``None``. - - Outputs - ------- - list[torch.Tensor] - List of output tensors, each of shape :math:`(B, N, C)`, same shape as inputs. - - Notes - ----- - The mixing between self-attention and cross-attention is controlled by a learnable - parameter ``state_mixing`` which is passed through a sigmoid function to ensure - the mixing weight stays in :math:`[0, 1]`. - - See Also - -------- - :class:`GALE` : Original GeoTransolver GALE attention class. - :class:`GALE_block` : Transformer block that calls GALE or GALE_FA attention. 
- - Examples - -------- - >>> import torch - >>> gale_fa = GALE_FA(dim=256, heads=8, dim_head=32, context_dim=32) - >>> x = (torch.randn(2, 100, 256),) # Single input tensor in tuple - >>> context = torch.randn(2, 8, 64, 32) # Context for cross-attention - >>> outputs = gale_fa(x, context) - >>> len(outputs) - 1 - >>> outputs[0].shape - torch.Size([2, 100, 256]) - """ - - def __init__( - self, - dim, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - n_global_queries: int = 64, - use_te: bool = True, - context_dim: int = 0, - concrete_dropout: bool = False, - state_mixing_mode: str = "weighted", - ): - if use_te: - raise ValueError( - "GALE_FA does not support Transformer Engine backend. " - "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention." - ) - super().__init__() - if state_mixing_mode not in ("weighted", "concat_project"): - raise ValueError( - f"Invalid state_mixing_mode: {state_mixing_mode!r}. " - f"Expected 'weighted' or 'concat_project'." - ) - self.state_mixing_mode = state_mixing_mode - self.use_te = use_te - self.heads = heads - self.dim_head = dim_head - self.scale = 1.0 - # It is recommended by the FLARE authors to use self.scale = 1 if self.dim_head <= 8 else (self.dim_head ** -0.5) - # but we use self.scale = 1.0 because the recommended scaling is not tested yet. - inner_dim = dim_head * heads - - linear_layer = te.Linear if self.use_te else nn.Linear - - # Global queries for FLARE self-attention - self.q_global = nn.Parameter(torch.randn(1, heads, n_global_queries, dim_head)) - - # Linear projections for self-attention - self.in_project_x = linear_layer(dim, inner_dim) - self.self_k = linear_layer(dim_head, dim_head) - self.self_v = linear_layer(dim_head, dim_head) - - if context_dim > 0: - # Linear projections for cross-attention - self.cross_q = linear_layer(dim_head, dim_head) - self.cross_k = linear_layer(context_dim, dim_head) - self.cross_v = linear_layer(context_dim, dim_head) - - # Mixing layers for blending self-attention and cross-attention - if state_mixing_mode == "weighted": - # Learnable mixing weight between self and cross attention - self.state_mixing = nn.Parameter(torch.tensor(0.0)) - else: - # Concatenate self and cross attention and project back to dim_head - self.concat_project = nn.Sequential( - linear_layer(2 * dim_head, dim_head), - nn.GELU(), - ) - - # te attention - if self.use_te: - self.attn_fn = te.DotProductAttention( - num_attention_heads=self.heads, - kv_channels=self.dim_head, - attention_dropout=dropout, - qkv_format="bshd", - softmax_scale=self.scale - ) - - # Linear projection for output - self.out_linear = linear_layer(inner_dim, dim) - if concrete_dropout: - self.out_dropout = ConcreteDropout( - in_features=dim, - init_p=max(dropout, 0.05), - ) - else: - self.out_dropout = nn.Dropout(dropout) - - - def forward( - self, - x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], - context: Float[torch.Tensor, "batch heads context_slices context_dim"] - | None = None, - ) -> list[Float[torch.Tensor, "batch tokens channels"]]: - r"""Forward pass of the GALE_FA module. - - Applies GALE_FA attention to the input features. - - Parameters - ---------- - x : tuple[torch.Tensor, ...] - Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` - is batch size, :math:`N` is number of tokens, and :math:`C` is number - of channels. 
- context : torch.Tensor | None, optional - Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` - where :math:`H` is number of heads, :math:`S_c` is number of context - slices, and :math:`D_c` is context dimension. If ``None``, only - self-attention is applied. Default is ``None``. - - Returns - ------- - list[torch.Tensor] - List of output tensors, each of shape :math:`(B, N, C)``, same shape - as inputs. - """ - - # with record_function("forward"): - x_mid = [self.in_project_x(_x) for _x in x] - x_mid = [rearrange( - _x_mid, "B N (h d) -> B N h d", h=self.heads, d=self.dim_head - ) for _x_mid in x_mid] - x_mid = [_x_mid.permute(0, 2, 1, 3) for _x_mid in x_mid] # [B, H, N, D] - G = [self.q_global.to(dtype=x_mid[0].dtype).expand(x_mid[0].shape[0], -1, -1, -1)] * len(x) - k = [self.self_k(_x_mid) for _x_mid in x_mid] - v = [self.self_v(_x_mid) for _x_mid in x_mid] - - # FLARE: Self Attention - if self.use_te: - # Transformer Engine expects (B, S, H, D) format - G = [rearrange(_G, "b h s d -> b s h d") for _G in G] - k = [rearrange(_k, "b h s d -> b s h d") for _k in k] - v = [rearrange(_v, "b h s d -> b s h d") for _v in v] - z = [self.attn_fn(_G, _k, _v) for _G, _k, _v in zip(G, k, v)] - z = [rearrange( - _z, "b s (h d) -> b s h d", h=self.heads, d=self.dim_head - ) for _z in z] - self_attention = [self.attn_fn(_k, _G, _z) for _k, _G, _z in zip(k, G, z)] - self_attention = [rearrange( - _self_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head - ) for _self_attention in self_attention] - else: - # Use PyTorch's scaled dot-product attention - z = [F.scaled_dot_product_attention(_G, _k, _v, scale=self.scale) for _G, _k, _v in zip(G, k, v)] - self_attention = [F.scaled_dot_product_attention(_k, _G, _z, scale=self.scale) for _k, _G, _z in zip(k, G, z)] - - # apply cross-attention with physical states: - if context is not None: - q = [self.cross_q(_x_mid) for _x_mid in x_mid] - k = self.cross_k(context) - v = self.cross_v(context) - - if self.use_te: - q = [rearrange(_q, "b h s d -> b s h d") for _q in q] - k = rearrange(k, "b h s d -> b s h d") - v = rearrange(v, "b h s d -> b s h d") - cross_attention = [self.attn_fn(_q, k, v) for _q in q] - cross_attention = [rearrange( - _cross_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head - ) for _cross_attention in cross_attention] - else: - cross_attention = [F.scaled_dot_product_attention(_q, k, v, scale=self.scale) for _q in q] - - # Blend self-attention and cross-attention - if self.state_mixing_mode == "weighted": - mixing_weight = torch.sigmoid(self.state_mixing) - outputs = [mixing_weight * _ys + (1 - mixing_weight) * _yc for _ys, _yc in zip(self_attention, cross_attention)] - else: - outputs = [self.concat_project(torch.cat([_ys, _yc], dim=-1)) for _ys, _yc in zip(self_attention, cross_attention)] - else: - outputs = self_attention - - outputs = [_y.permute(0, 2, 1, 3) for _y in outputs] # [B, N, H, D] - outputs = [rearrange(_out, "b n h d -> b n (h d)") for _out in outputs] - outputs = [self.out_linear(_out) for _out in outputs] - return [self.out_dropout(_out) for _out in outputs] - diff --git a/physicsnemo/experimental/models/geotransolver/geotransolver.py b/physicsnemo/experimental/models/geotransolver/geotransolver.py index 36fd32a482..13689e345e 100644 --- a/physicsnemo/experimental/models/geotransolver/geotransolver.py +++ b/physicsnemo/experimental/models/geotransolver/geotransolver.py @@ -23,6 +23,7 @@ from __future__ import annotations +import math from collections.abc import Sequence from 
dataclasses import dataclass @@ -144,6 +145,47 @@ def _normalize_tensor( raise TypeError(f"Invalid tensor structure") +def _structured_num_tokens(spatial_shape: tuple[int, ...]) -> int: + return int(math.prod(spatial_shape)) + + +def _flatten_for_structured( + t: torch.Tensor, + spatial_shape: tuple[int, ...], + name: str, +) -> torch.Tensor: + """Flatten (B,H,W,C) or (B,H,W,D,C) to (B,N,C); pass through (B,N,C) if N matches. + + Mirrors Transolver's structured flatten/unflatten behavior so the rest of + GeoTransolver can assume a single token layout (B, N, C). + """ + n = _structured_num_tokens(spatial_shape) + if t.ndim == 3: + if not torch.compiler.is_compiling() and t.shape[1] != n: + raise ValueError( + f"{name} token count {t.shape[1]} != structured grid size {n}" + ) + return t + if len(spatial_shape) == 2 and t.ndim == 4: + B, H, W, C = t.shape + if (H, W) != spatial_shape: + raise ValueError( + f"{name} spatial dims {(H, W)} != structured_shape {spatial_shape}" + ) + return t.reshape(B, n, C) + if len(spatial_shape) == 3 and t.ndim == 5: + B, H, W, D, C = t.shape + if (H, W, D) != spatial_shape: + raise ValueError( + f"{name} spatial dims {(H, W, D)} != structured_shape {spatial_shape}" + ) + return t.reshape(B, n, C) + raise ValueError( + f"{name}: expected (B,N,C) with N={n}, or spatial layout matching " + f"structured_shape {spatial_shape}; got shape {tuple(t.shape)}" + ) + + class GeoTransolver(Module): r"""GeoTransolver: Geometry-Aware Physics Attention Transformer. @@ -205,6 +247,11 @@ class GeoTransolver(Module): Neighbors in radius for the local features. Default is ``[8, 32]``. n_hidden_local : int, optional Hidden dimension for the local features. Default is 32. + structured_shape : tuple[int, ...] | None, optional + If set to ``(H, W)`` or ``(H, W, D)``, enables structured 2D/3D paths + (Conv2d/Conv3d GALE; no ball-query local features). Inputs may be + flattened :math:`(B, N, C)` with :math:`N = H W` or :math:`H W D`, or + spatial :math:`(B, H, W, C)` / :math:`(B, H, W, D, C)`. Default is ``None``. guard_config : dict | None, optional Configuration for the embedded OOD guard (:class:`~physicsnemo.experimental.guardrails.embedded.OODGuard`). @@ -229,9 +276,9 @@ class GeoTransolver(Module): Forward ------- local_embedding : torch.Tensor | tuple[torch.Tensor, ...] - Local embedding of the input data of shape :math:`(B, N, C)` where :math:`B` - is batch size, :math:`N` is number of nodes/tokens, and :math:`C` is - ``functional_dim``. Can be a single tensor or tuple for multiple input types. + Local embedding: unstructured :math:`(B, N, C)`; structured 2D + :math:`(B, H, W, C)` or flattened :math:`(B, H W, C)`; structured 3D + :math:`(B, H, W, D, C)` or flattened. Can be a tuple for multiple input types. local_positions : torch.Tensor | tuple[torch.Tensor, ...] | None, optional Local positions for each input, each of shape :math:`(B, N, 3)`. Required if ``include_local_features=True``. Default is ``None``. @@ -248,14 +295,18 @@ class GeoTransolver(Module): Outputs ------- - Float[torch.Tensor, "batch tokens out_dim"] | tuple[Float[torch.Tensor, "batch tokens out_dim"], ...] - When ``return_embedding_states=False`` (default), output tensor of + torch.Tensor | tuple[torch.Tensor, ...] + When ``return_embedding_states=False`` (default): output tensor(s) of shape :math:`(B, N, C_{out})`. Returns a single tensor if input was - a single tensor, or a tuple if input was a tuple. - + a single tensor, or a tuple of tensors if input was a tuple + (multi-stream). 
For structured grids, output matches the input + layout—flattened :math:`(B, N, C_{out})` or spatial + :math:`(B, H, W, C_{out})` / :math:`(B, H, W, D, C_{out})` when + inputs were 4D/5D. + When ``return_embedding_states=True``, returns a 2-tuple - ``(output, embedding_states)`` where ``embedding_states`` is - ``Float[torch.Tensor, "batch heads slices context_dim"]`` of shape + ``(output, embedding_states)`` where ``output`` follows the same + rules above, and ``embedding_states`` is of shape :math:`(B, H, S, D_c)` (geometry/global context), or ``None`` if no context sources were provided. @@ -271,8 +322,9 @@ class GeoTransolver(Module): Notes ----- - GeoTransolver currently supports unstructured mesh input only. Enhancements for - image-based and voxel-based inputs may be available in the future. + Unstructured mesh uses linear GALE projection; structured ``structured_shape`` + uses the same Conv2d/Conv3d slice projection as :class:`~physicsnemo.models.transolver.Transolver`. + Ball-query local features are disabled when ``structured_shape`` is set. For more details on Transolver, see: @@ -321,6 +373,21 @@ class GeoTransolver(Module): >>> output.shape torch.Size([2, 1000, 3]) + Structured 2D grid: + + >>> model = GeoTransolver( + ... functional_dim=3, + ... out_dim=1, + ... structured_shape=(8, 8), + ... n_hidden=64, + ... n_head=4, + ... n_layers=2, + ... use_te=False, + ... ) + >>> y = model(torch.randn(2, 8, 8, 3)) + >>> y.shape + torch.Size([2, 8, 8, 1]) + To also retrieve the geometry/global context embeddings: >>> output, emb_states = model( @@ -353,6 +420,7 @@ def __init__( radii: list[float] | None = None, neighbors_in_radius: list[int] | None = None, n_hidden_local: int = 32, + structured_shape: tuple[int, ...] | None = None, guard_config: dict | None = None, attention_type: str = "GALE", concrete_dropout: bool = False, @@ -367,8 +435,22 @@ def __init__( if neighbors_in_radius is None: neighbors_in_radius = [8, 32] + if structured_shape is not None: + if include_local_features: + raise ValueError( + "include_local_features=True is not supported with structured_shape " + "(ball-query path is mesh-only)." 
+ ) + if len(structured_shape) not in (2, 3): + raise ValueError( + f"structured_shape must have length 2 or 3, got {structured_shape!r}" + ) + if not all(int(s) > 0 for s in structured_shape): + raise ValueError(f"structured_shape must be positive ints, got {structured_shape!r}") + self.include_local_features = include_local_features self.use_te = use_te + self.structured_shape = structured_shape # Validate head dimension compatibility if not n_hidden % n_head == 0: @@ -399,6 +481,7 @@ def __init__( use_te=use_te, plus=plus, include_local_features=self.include_local_features, + structured_shape=structured_shape, concrete_dropout=concrete_dropout, ) context_dim = self.context_builder.get_context_dim() @@ -447,6 +530,7 @@ def __init__( use_te=use_te, plus=plus, context_dim=context_dim, + spatial_shape=structured_shape, attention_type=attention_type, concrete_dropout=concrete_dropout, state_mixing_mode=state_mixing_mode, @@ -590,6 +674,27 @@ def forward( if local_positions is not None: local_positions = _normalize_tensor(local_positions) + unflatten_output = False + if self.structured_shape is not None: + unflatten_output = any(le.ndim in (4, 5) for le in local_embedding) + local_embedding = tuple( + _flatten_for_structured( + le, self.structured_shape, f"local_embedding[{i}]" + ) + for i, le in enumerate(local_embedding) + ) + if geometry is not None: + geometry = _flatten_for_structured( + geometry, self.structured_shape, "geometry" + ) + n_tok = _structured_num_tokens(self.structured_shape) + for i, le in enumerate(local_embedding): + if le.shape[1] != n_tok: + raise ValueError( + f"structured GeoTransolver: all streams must have N={n_tok} tokens; " + f"local_embedding[{i}] has N={le.shape[1]}" + ) + ### Input validation if not torch.compiler.is_compiling(): if len(local_embedding) == 0: @@ -646,6 +751,16 @@ def forward( # Project to output dimensions: (B, N, n_hidden) -> (B, N, out_dim) x = [self.ln_mlp_out[i](x[i]) for i in range(len(x))] + if self.structured_shape is not None and unflatten_output: + B = x[0].shape[0] + for i in range(len(x)): + if len(self.structured_shape) == 2: + H, W = self.structured_shape + x[i] = x[i].reshape(B, H, W, -1) + else: + H, W, D_ = self.structured_shape + x[i] = x[i].reshape(B, H, W, D_, -1) + # Return same format as input (single tensor or tuple) if single_input: x = x[0] diff --git a/physicsnemo/experimental/nn/flare_attention.py b/physicsnemo/experimental/nn/flare_attention.py index 70a6c14f57..4ecc68a927 100644 --- a/physicsnemo/experimental/nn/flare_attention.py +++ b/physicsnemo/experimental/nn/flare_attention.py @@ -29,12 +29,50 @@ from jaxtyping import Float from physicsnemo.core.version_check import check_version_spec, OptionalImport +from physicsnemo.nn.module.physics_attention import _project_input # Check optional dependency availability TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) te = OptionalImport("transformer_engine.pytorch", "0.1.0") +def _flare_self_attention( + x_mid: Float[torch.Tensor, "B H N D"], + q_global: nn.Parameter, + self_k: nn.Module, + self_v: nn.Module, + scale: float, +) -> Float[torch.Tensor, "B H N D"]: + r"""FLARE two-pass self-attention kernel. + + Computes low-rank attention via learned global queries: first aggregate + token values into global slots, then distribute back to tokens. + + Parameters + ---------- + x_mid : torch.Tensor + Projected input of shape :math:`(B, H, N, D)`. + q_global : nn.Parameter + Learned global queries of shape :math:`(1, H, S, D)`. 
+ self_k : nn.Module + Key projection applied to ``x_mid``. + self_v : nn.Module + Value projection applied to ``x_mid``. + scale : float + Attention scale factor. + + Returns + ------- + torch.Tensor + Self-attended output of shape :math:`(B, H, N, D)`. + """ + G = q_global.to(dtype=x_mid.dtype).expand(x_mid.shape[0], -1, -1, -1) + k = self_k(x_mid) + v = self_v(x_mid) + z = F.scaled_dot_product_attention(G, k, v, scale=scale) + return F.scaled_dot_product_attention(k, G, z, scale=scale) + + class FLARE(nn.Module): r"""FLARE: Fast Low-rank Attention Routing Engine attention layer. Adopted: @@ -137,35 +175,17 @@ def forward(self, x: Float[torch.Tensor, "B N C"]) -> Float[torch.Tensor, "B N C Output tensor of shape :math:`(B, N, C)`, same shape as inputs. """ - x_mid = self.in_project_x(x) - x_mid = rearrange( - x_mid, "B N (h d) -> B N h d", h=self.heads, d=self.dim_head + x_mid = _project_input( + x, self.in_project_x, self.heads, self.dim_head, + "B N (H D) -> B N H D", ) - x_mid = x_mid.permute(0, 2, 1, 3) # [B, H, N, D] - G = self.q_global.to(dtype=x_mid.dtype).expand(x_mid.shape[0], -1, -1, -1) - k = self.self_k(x_mid) - v = self.self_v(x_mid) + x_mid = x_mid.permute(0, 2, 1, 3) # (B, N, H, D) -> (B, H, N, D) - # FLARE: Fast Low-rank Attention Routing Engine - if self.use_te: - # Transformer Engine expects (B, S, H, D) format - G = rearrange(G, "b h s d -> b s h d") - k = rearrange(k, "b h s d -> b s h d") - v = rearrange(v, "b h s d -> b s h d") - z = self.attn_fn(G, k, v) - z = rearrange( - z, "b s (h d) -> b s h d", h=self.heads, d=self.dim_head - ) - self_attention = self.attn_fn(k, G, z) - y = rearrange( - self_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head - ) - else: - # Use PyTorch's scaled dot-product attention - z = F.scaled_dot_product_attention(G, k, v, scale=self.scale) - y = F.scaled_dot_product_attention(k, G, z, scale=self.scale) + y = _flare_self_attention( + x_mid, self.q_global, self.self_k, self.self_v, self.scale, + ) - out_x = y.permute(0, 2, 1, 3) # [B, N, H, D] + out_x = y.permute(0, 2, 1, 3) # (B, H, N, D) -> (B, N, H, D) out_x = rearrange(out_x, "b n h d -> b n (h d)") out_x = self.out_linear(out_x) return self.out_dropout(out_x) diff --git a/physicsnemo/nn/module/physics_attention.py b/physicsnemo/nn/module/physics_attention.py index 2110a91331..22e0ff3ba8 100644 --- a/physicsnemo/nn/module/physics_attention.py +++ b/physicsnemo/nn/module/physics_attention.py @@ -35,7 +35,6 @@ SOFTWARE. """ -import importlib from abc import ABC, abstractmethod import torch @@ -45,19 +44,165 @@ from torch.autograd.profiler import record_function from torch.distributed.tensor.placement_types import Replicate -from physicsnemo.core.version_check import check_version_spec +from physicsnemo.core.version_check import OptionalImport from physicsnemo.nn import gumbel_softmax # Note: We use duck typing to check for ShardTensor instead of importing it # directly to avoid circular imports (domain_parallel imports from nn). # ShardTensor has a `redistribute` method that we check for. 
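The two-pass routing that `_flare_self_attention` describes above can be sketched with plain `scaled_dot_product_attention` calls. This is a self-contained illustration only: the learned `self_k`/`self_v` projections are omitted, the sizes are arbitrary, and the `scale` keyword assumes a PyTorch version that exposes it (2.1+).

```python
import torch
import torch.nn.functional as F

B, H, N, S, D = 2, 4, 1024, 16, 32   # batch, heads, tokens, global slots, head dim
x = torch.randn(B, H, N, D)          # per-head token features
q_global = torch.randn(1, H, S, D)   # learned global queries (an nn.Parameter in the module)

G = q_global.expand(B, -1, -1, -1)

# Pass 1: the S global slots attend over all N tokens -> (B, H, S, D).
z = F.scaled_dot_product_attention(G, x, x, scale=1.0)

# Pass 2: each token attends over the S slots -> (B, H, N, D).
y = F.scaled_dot_product_attention(x, G, z, scale=1.0)

print(y.shape)  # torch.Size([2, 4, 1024, 32])
# Both passes cost O(N * S) rather than O(N^2), which is the point of the low-rank routing.
```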
-TE_AVAILABLE = check_version_spec("transformer_engine", hard_fail=False)
+te = OptionalImport("transformer_engine.pytorch")
+
+
+def _project_input(
+    x: torch.Tensor,
+    project_x: nn.Module,
+    heads: int,
+    dim_head: int,
+    pattern: str,
+    project_fx: nn.Module | None = None,
+) -> (
+    Float[torch.Tensor, "B N H D"]
+    | tuple[
+        Float[torch.Tensor, "B N H D"],
+        Float[torch.Tensor, "B N H D"],
+    ]
+):
+    r"""Project input through one or two learned layers and rearrange to multi-head format.
+
+    Universal building block for the ``(B, N, C) -> (B, N, H, D)`` projection
+    used by the low-rank attention mechanisms in the repo: slice-based (Transolver /
+    GALE), FLARE, and future linear-attention variants.
+
+    For structured grids the caller reshapes to spatial layout *before* calling
+    this function and passes the appropriate einops ``pattern``; the output is
+    always ``(B, N, H, D)`` with tokens in position 1, preserving
+    domain-parallel sharding over N.
 
-if TE_AVAILABLE:
-    te = importlib.import_module("transformer_engine.pytorch")
-else:
-    te = None
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input tensor. For irregular meshes this is ``(B, N, C)``; for
+        structured grids it has already been reshaped to conv-friendly layout
+        (e.g. ``(B, C, H_s, W_s)``).
+    project_x : nn.Module
+        Primary projection layer (``nn.Linear``, ``nn.Conv2d``, etc.).
+    heads : int
+        Number of attention heads ``H``.
+    dim_head : int
+        Dimension per head ``D``.
+    pattern : str
+        Einops rearrange pattern that maps the projection output to
+        ``(B, N, H, D)``. Examples:
+
+        - ``"B N (H D) -> B N H D"`` for linear projection
+        - ``"B (H D) h w -> B (h w) H D"`` for 2-D convolution
+        - ``"B (H D) h w d -> B (h w d) H D"`` for 3-D convolution
+    project_fx : nn.Module or None, optional
+        If provided, a second projection is applied to ``x`` and the function
+        returns a ``(px, pfx)`` tuple. Omit for single-projection callers
+        (Transolver++ ``plus=True``, FLARE, etc.). Default is ``None``.
+
+    Returns
+    -------
+    torch.Tensor or tuple[torch.Tensor, torch.Tensor]
+        Single tensor of shape :math:`(B, N, H, D)` when ``project_fx`` is
+        ``None``; otherwise a tuple ``(px, pfx)`` both of shape
+        :math:`(B, N, H, D)`.
+    """
+    px = rearrange(project_x(x), pattern, H=heads, D=dim_head)
+    if project_fx is None:
+        return px
+    return px, rearrange(project_fx(x), pattern, H=heads, D=dim_head)
+
+
+def _compute_slices_from_projections(
+    slice_projections: Float[torch.Tensor, "B N H S"],
+    fx: Float[torch.Tensor, "B N H D"],
+    temperature: torch.Tensor,
+    plus: bool,
+    proj_temperature: nn.Module | None = None,
+) -> tuple[
+    Float[torch.Tensor, "B N H S"],
+    Float[torch.Tensor, "B H S D"],
+]:
+    r"""Compute slice weights and slice tokens from input projections.
+
+    Standalone implementation of the temperature-scaled softmax slice
+    aggregation used by :class:`PhysicsAttentionBase` and reusable by any
+    module that needs the same project-to-slices-then-aggregate pattern
+    (e.g. :class:`~physicsnemo.experimental.models.geotransolver.context_projector.ContextProjector`).
+
+    In domain-parallel settings, this performs an implicit allreduce when
+    summing over the sharded token dimension.
+
+    Parameters
+    ----------
+    slice_projections : torch.Tensor
+        Projected input of shape :math:`(B, N, H, S)` where :math:`H` is
+        number of attention heads and :math:`S` is number of physics slices.
+    fx : torch.Tensor
+        Latent features of shape :math:`(B, N, H, D)` where :math:`D` is
+        dimension per head.
+ temperature : torch.Tensor + Scalar temperature for softmax/gumbel, shape broadcastable to + ``slice_projections`` (typically :math:`(1, 1, H, 1)`). + plus : bool + If ``True``, use Gumbel softmax with optional adaptive temperature. + proj_temperature : nn.Module or None, optional + If ``plus`` is ``True``, module mapping :math:`(B, N, H, D)` to + adaptive temperature; ignored otherwise. Default is ``None``. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + - ``slice_weights``: Shape :math:`(B, N, H, S)`, normalized weights + for each slice per token. + - ``slice_token``: Shape :math:`(B, H, S, D)`, aggregated features + per slice. + """ + # Compute temperature-scaled softmax over slices + if plus and proj_temperature is not None: + # Transolver++ uses learned per-token temperature + temp = temperature + proj_temperature(fx) + clamped_temp = torch.clamp(temp, min=0.01).to(slice_projections.dtype) + slice_weights = gumbel_softmax(slice_projections, clamped_temp) # (B, N, H, S) + else: + # Standard Transolver uses global temperature + clamped_temp = torch.clamp(temperature, min=0.5, max=5).to( + slice_projections.dtype + ) + slice_weights = nn.functional.softmax( + slice_projections / clamped_temp, dim=-1 + ) # (B, N, H, S) + + # Cast to the computation type (since the parameter is probably fp32) + slice_weights = slice_weights.to(slice_projections.dtype) + + # Computing the slice tokens is a matmul followed by a normalization. + # It can, unfortunately, overflow in reduced precision, so normalize first: + slice_norm = slice_weights.sum(1) + 1e-2 # (B, H, S) + # Sharded note: slice_norm will be a partial sum at this point. + # That's because the we're summing over the tokens, which are distributed + normed_weights = slice_weights / (slice_norm[:, None, :, :]) + # Normed weights has shape (B, N, H, S) + + # Sharded note: normed_weights will resolve the partial slice_norm + # and the output normed_weights will be sharded. + # fx has shape (B, N, H, D) + # This matmul needs to contract over the tokens + # This should produce an output with shape (B, H, S, D) + + # Like the weight norm, this sum is a **partial** sum since we are summing + # over the tokens + + # Aggregate features: (B, N, H, S)^T @ (B, N, H, D) -> (B, H, S, D) + slice_token = torch.matmul( + normed_weights.permute(0, 2, 3, 1), fx.permute(0, 2, 1, 3) + ) + + # Return the original weights, not the normed weights: + return slice_weights, slice_token class PhysicsAttentionBase(nn.Module, ABC): @@ -225,6 +370,9 @@ def _compute_slices_from_projections( In domain-parallel settings, this performs an implicit allreduce when summing over the sharded token dimension. + Delegates to the module-level + :func:`_compute_slices_from_projections` free function. + Parameters ---------- slice_projections : torch.Tensor @@ -242,56 +390,11 @@ def _compute_slices_from_projections( - ``slice_token``: Shape :math:`(B, H, S, D)`, aggregated features per slice. 
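As a quick numeric check of the normalize-then-matmul aggregation performed by the free function above, the following standalone snippet (random weights standing in for the softmax output; not imported from the module) shows that the permute-and-matmul form is the usual contraction over the token axis:

```python
import torch

B, N, H, S, D = 2, 64, 4, 8, 16
slice_weights = torch.rand(B, N, H, S)   # stands in for the (B, N, H, S) softmax output
fx = torch.randn(B, N, H, D)             # latent features (B, N, H, D)

# Normalize each slice's weights over the tokens before aggregating (avoids overflow).
slice_norm = slice_weights.sum(1) + 1e-2              # (B, H, S)
normed = slice_weights / slice_norm[:, None, :, :]    # (B, N, H, S)

# Permute-and-matmul form used above: (B, H, S, N) @ (B, H, N, D) -> (B, H, S, D)
slice_token = torch.matmul(normed.permute(0, 2, 3, 1), fx.permute(0, 2, 1, 3))

# Equivalent einsum contracting over the token axis n:
slice_token_ref = torch.einsum("bnhs,bnhd->bhsd", normed, fx)
print(torch.allclose(slice_token, slice_token_ref, atol=1e-6))  # True
```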
""" - # Compute temperature-scaled softmax over slices - if self.plus: - # Transolver++ uses learned per-token temperature - temperature = self.temperature + self.proj_temperature(fx) - clamped_temp = torch.clamp(temperature, min=0.01).to( - slice_projections.dtype - ) - slice_weights = gumbel_softmax( - slice_projections, clamped_temp - ) # (B, N, H, S) - else: - # Standard Transolver uses global temperature - clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to( - slice_projections.dtype - ) - slice_weights = nn.functional.softmax( - slice_projections / clamped_temp, dim=-1 - ) # (B, N, H, S) - - # Cast to the computation type (since the parameter is probably fp32) - slice_weights = slice_weights.to(slice_projections.dtype) - - # This does the projection of the latent space fx by the weights: - - # Computing the slice tokens is a matmul followed by a normalization. - # It can, unfortunately, overflow in reduced precision, so normalize first: - slice_norm = slice_weights.sum(1) + 1e-2 # (B, H, S) - # Sharded note: slice_norm will be a partial sum at this point. - # That's because the we're summing over the tokens, which are distributed - normed_weights = slice_weights / (slice_norm[:, None, :, :]) - # Normed weights has shape (B, N, H, S) - - # Sharded note: normed_weights will resolve the partial slice_norm - # and the output normed_weights will be sharded. - # fx has shape (B, N, H, D) - # This matmul needs to contract over the tokens - # This should produce an output with shape (B, H, S, D) - - # Like the weight norm, this sum is a **partial** sum since we are summing - # over the tokens - - # Aggregate features: (B, N, H, S)^T @ (B, N, H, D) -> (B, H, S, D) - slice_token = torch.matmul( - normed_weights.permute(0, 2, 3, 1), fx.permute(0, 2, 1, 3) + proj_temp = getattr(self, "proj_temperature", None) if self.plus else None + return _compute_slices_from_projections( + slice_projections, fx, self.temperature, self.plus, proj_temp ) - # Return the original weights, not the normed weights: - - return slice_weights, slice_token - def _compute_slice_attention_te( self, slice_tokens: Float[torch.Tensor, "B H S D"] ) -> Float[torch.Tensor, "B H S D"]: @@ -544,22 +647,16 @@ def project_input_onto_slices( Projected tensors of shape :math:`(B, N, H, D)` where :math:`H` is number of attention heads and :math:`D` is dimension per head. 
""" - # Project and reshape to multi-head format - x_mid = rearrange( - self.in_project_x(x), "B N (H D) -> B N H D", H=self.heads, D=self.dim_head + fx = None if self.plus else self.in_project_fx + return _project_input( + x, + self.in_project_x, + self.heads, + self.dim_head, + "B N (H D) -> B N H D", + project_fx=fx, ) - if self.plus: - return x_mid - else: - fx_mid = rearrange( - self.in_project_fx(x), - "B N (H D) -> B N H D", - H=self.heads, - D=self.dim_head, - ) - return x_mid, fx_mid - class PhysicsAttentionStructuredMesh2D(PhysicsAttentionBase): r""" @@ -673,27 +770,16 @@ def project_input_onto_slices( x = x.view(B, self.H, self.W, C) x = x.permute(0, 3, 1, 2) - # Apply 2D convolution and reshape to multi-head format - input_projected_x = self.in_project_x(x) - input_projected_x = rearrange( - input_projected_x, + fx = None if self.plus else self.in_project_fx + return _project_input( + x, + self.in_project_x, + self.heads, + self.dim_head, "B (H D) h w -> B (h w) H D", - D=self.dim_head, - H=self.heads, + project_fx=fx, ) - if self.plus: - return input_projected_x - else: - input_projected_fx = self.in_project_fx(x) - input_projected_fx = rearrange( - input_projected_fx, - "B (H D) h w -> B (h w) H D", - D=self.dim_head, - H=self.heads, - ) - return input_projected_x, input_projected_fx - class PhysicsAttentionStructuredMesh3D(PhysicsAttentionBase): r""" @@ -808,23 +894,12 @@ def project_input_onto_slices( x = x.view(B, self.H, self.W, self.D, C) x = x.permute(0, 4, 1, 2, 3) - # Apply 3D convolution and reshape to multi-head format - input_projected_x = self.in_project_x(x) - input_projected_x = rearrange( - input_projected_x, - "B (H D) height width depth -> B (height width depth) H D", - D=self.dim_head, - H=self.heads, + fx = None if self.plus else self.in_project_fx + return _project_input( + x, + self.in_project_x, + self.heads, + self.dim_head, + "B (H D) h w d -> B (h w d) H D", + project_fx=fx, ) - - if self.plus: - return input_projected_x - else: - input_projected_fx = self.in_project_fx(x) - input_projected_fx = rearrange( - input_projected_fx, - "B (H D) height width depth -> B (height width depth) H D", - D=self.dim_head, - H=self.heads, - ) - return input_projected_x, input_projected_fx diff --git a/test/models/geotransolver/test_gale.py b/test/models/geotransolver/test_gale.py index 79ef54d8c5..63db8c4a93 100644 --- a/test/models/geotransolver/test_gale.py +++ b/test/models/geotransolver/test_gale.py @@ -19,9 +19,9 @@ from physicsnemo.experimental.models.geotransolver.gale import ( GALE, + GALE_FA, GALE_block, ) -from physicsnemo.experimental.models.geotransolver.gale_fa import GALE_FA # ============================================================================= # GALE (Geometry-Aware Latent Embeddings) Attention Tests diff --git a/test/models/geotransolver/test_geotransolver.py b/test/models/geotransolver/test_geotransolver.py index 4888623b0b..821d1e923b 100644 --- a/test/models/geotransolver/test_geotransolver.py +++ b/test/models/geotransolver/test_geotransolver.py @@ -558,6 +558,98 @@ def test_geotransolver_mismatched_functional_out_dims(): ) +def test_geotransolver_structured_rejects_local_features(): + """Ball-query local features are incompatible with structured_shape.""" + with pytest.raises(ValueError, match="include_local_features=True"): + GeoTransolver( + functional_dim=8, + out_dim=1, + structured_shape=(4, 4), + include_local_features=True, + geometry_dim=2, + use_te=False, + ) + + +def test_geotransolver_structured_2d_forward(device): + 
"""Structured 2D: spatial input (B,H,W,C) and flattened (B,N,C); optional geometry.""" + torch.manual_seed(0) + H, W = 4, 4 + model = GeoTransolver( + functional_dim=3, + out_dim=2, + structured_shape=(H, W), + geometry_dim=2, + global_dim=None, + n_layers=2, + n_hidden=32, + n_head=4, + slice_num=8, + mlp_ratio=2, + use_te=False, + ).to(device) + B = 2 + x4 = torch.randn(B, H, W, 3, device=device) + g = torch.randn(B, H, W, 2, device=device) + y4 = model(x4, geometry=g) + assert y4.shape == (B, H, W, 2) + assert not torch.isnan(y4).any() + + x3 = x4.reshape(B, H * W, 3) + g3 = g.reshape(B, H * W, 2) + y3 = model(x3, geometry=g3) + assert y3.shape == (B, H * W, 2) + + y_none = model(x4) + assert y_none.shape == (B, H, W, 2) + + +def test_geotransolver_structured_3d_forward(device): + """Structured 3D voxel input (B,H,W,D,C).""" + torch.manual_seed(1) + H, W, Dg = 2, 2, 2 + model = GeoTransolver( + functional_dim=4, + out_dim=1, + structured_shape=(H, W, Dg), + n_layers=1, + n_hidden=32, + n_head=4, + slice_num=4, + mlp_ratio=2, + use_te=False, + ).to(device) + B = 1 + x = torch.randn(B, H, W, Dg, 4, device=device) + y = model(x) + assert y.shape == (B, H, W, Dg, 1) + + +def test_geotransolver_structured_global_context(device): + """Structured grid with global embedding context.""" + torch.manual_seed(2) + H, W = 4, 4 + model = GeoTransolver( + functional_dim=2, + out_dim=1, + structured_shape=(H, W), + geometry_dim=2, + global_dim=8, + n_layers=2, + n_hidden=32, + n_head=4, + slice_num=8, + mlp_ratio=2, + use_te=False, + ).to(device) + B = 2 + x = torch.randn(B, H, W, 2, device=device) + geo = torch.randn(B, H, W, 2, device=device) + glob = torch.randn(B, 3, 8, device=device) + y = model(x, geometry=geo, global_embedding=glob) + assert y.shape == (B, H, W, 1) + + # ============================================================================= # Activation Function Tests # =============================================================================