diff --git a/examples/cfd/darcy_transolver/README.md b/examples/cfd/darcy_transolver/README.md index 941eac3c96..fcfa96da19 100644 --- a/examples/cfd/darcy_transolver/README.md +++ b/examples/cfd/darcy_transolver/README.md @@ -89,15 +89,6 @@ To train Transolver on your own physics problem, modify the `dataloader` in `train_transolver_darcy.py` to use your own pre-computed data or on-the-fly solver. -#### Other models - -You can also use GeoTransolver, Flare, or GeoTransolver with flare attention -for the darcy example with the "fixed" usage. Change the `model` parameter -to update it! Note that the inputs to geotransolver here uses the positions -and the `x` diffusivity to both the model trunk and encoder paths. For these -problems, GeoTransolver is not yet extensively tested, and your results -may vary. - ## Additional Information More components are added for convenience. `Validators` calculate the loss diff --git a/examples/cfd/darcy_transolver/config_fix.yaml b/examples/cfd/darcy_transolver/config_fix.yaml index a9c4f64f93..ee90d84855 100644 --- a/examples/cfd/darcy_transolver/config_fix.yaml +++ b/examples/cfd/darcy_transolver/config_fix.yaml @@ -20,31 +20,33 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -defaults: - - model: transolver # transolver | flare | geotransolver | geotransolver_fa - - _self_ - -hydra: - output_subdir: null - run: - dir: . - -output_dir: ./output/ -run_id: ${hydra:runtime.choices.model}-muon_${precision}_r${resolution}_b${data.batch_size}_s${model.slice_num} - -resolution: 85 #421, 211, 141, 106, 85 all viable +output_dir: ./output/darcy_transolver_fix +run_id: bf16_dev_r85_b8_s64 data: - train_path: //lustre/fsw/portfolios/coreai/users/coreya/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth1.npz - test_path: //lustre/fsw/portfolios/coreai/users/coreya/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth2.npz - resolution: ${resolution} - batch_size: 64 # This is the GLOBAL batch size + train_path: /user_data/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth1.npz + test_path: /user_data/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth2.npz + resolution: 85 #421, 211, 141, 106, 85 all viable + batch_size: 8 # This is the GLOBAL batch size +model: + functional_dim: 1 + out_dim: 1 + embedding_dim: 2 + n_layers: 4 + n_hidden: 128 + dropout: 0.0 + n_head: 4 + act: gelu + mlp_ratio: 4 + unified_pos: False + ref: 8 + slice_num: 64 + use_te: False + time_input: False + precision: bf16 -optimizer: - type: muon # adamw | muon (muon requires PyTorch >= 2.9) - normaliser: permeability: mean: 1.25 @@ -67,4 +69,3 @@ training: validation: sample_size: 200 validation_pseudo_epochs: 1 - save_plots: false diff --git a/examples/cfd/darcy_transolver/model/flare.yaml b/examples/cfd/darcy_transolver/model/flare.yaml deleted file mode 100644 index 31e9c78dfd..0000000000 --- a/examples/cfd/darcy_transolver/model/flare.yaml +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# FLARE inherits from Transolver and replaces physics attention blocks with -# FLARE (Fast Low-rank Attention Routing Engine) blocks. -# use_te and plus are forced off internally by FLARE. - -_target_: physicsnemo.experimental.models.flare.FLARE -_convert_: all - -functional_dim: 1 -out_dim: 1 -embedding_dim: 2 -n_layers: 4 -n_hidden: 128 -dropout: 0.0 -n_head: 4 -act: gelu -mlp_ratio: 4 -slice_num: 64 -unified_pos: False -ref: 8 -structured_shape: - - ${resolution} - - ${resolution} -time_input: False diff --git a/examples/cfd/darcy_transolver/model/geotransolver.yaml b/examples/cfd/darcy_transolver/model/geotransolver.yaml deleted file mode 100644 index db585b8829..0000000000 --- a/examples/cfd/darcy_transolver/model/geotransolver.yaml +++ /dev/null @@ -1,40 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# GeoTransolver with GALE attention (default). -# Uses geometry-aware cross-attention instead of Transolver's structured grid. - -_target_: physicsnemo.experimental.models.geotransolver.GeoTransolver -_convert_: all - -functional_dim: 3 -out_dim: 1 -geometry_dim: 3 -n_layers: 4 -n_hidden: 128 -dropout: 0.0 -n_head: 4 -act: gelu -mlp_ratio: 4 -slice_num: 64 -use_te: False -time_input: False -attention_type: GALE -include_local_features: False -state_mixing_mode: weighted -structured_shape: - - ${resolution} - - ${resolution} diff --git a/examples/cfd/darcy_transolver/model/geotransolver_fa.yaml b/examples/cfd/darcy_transolver/model/geotransolver_fa.yaml deleted file mode 100644 index 305a411489..0000000000 --- a/examples/cfd/darcy_transolver/model/geotransolver_fa.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# GeoTransolver with GALE_FA (FLARE attention backend). 
-# Combines FLARE-style global-query self-attention with GALE geometry/context -# cross-attention. use_te must be False for GALE_FA. - -_target_: physicsnemo.experimental.models.geotransolver.GeoTransolver -_convert_: all - -functional_dim: 3 -out_dim: 1 -geometry_dim: 3 -n_layers: 4 -n_hidden: 128 -dropout: 0.0 -n_head: 4 -act: gelu -mlp_ratio: 4 -slice_num: 64 -use_te: False -time_input: False -attention_type: GALE_FA -include_local_features: False -state_mixing_mode: concat_project -structured_shape: - - ${resolution} - - ${resolution} diff --git a/examples/cfd/darcy_transolver/model/transolver.yaml b/examples/cfd/darcy_transolver/model/transolver.yaml deleted file mode 100644 index 6e01da1b3f..0000000000 --- a/examples/cfd/darcy_transolver/model/transolver.yaml +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_target_: physicsnemo.models.transolver.Transolver -_convert_: all - -functional_dim: 1 -out_dim: 1 -embedding_dim: 2 -n_layers: 4 -n_hidden: 128 -dropout: 0.0 -n_head: 4 -act: gelu -mlp_ratio: 4 -slice_num: 64 -unified_pos: False -ref: 8 -structured_shape: - - ${resolution} - - ${resolution} -use_te: False -time_input: False diff --git a/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py b/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py index 8c8b38a4e7..0f528f6205 100644 --- a/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py +++ b/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py @@ -16,12 +16,9 @@ # Configuration imports: import hydra -from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf import json -import os import time -from datetime import datetime, timezone from math import ceil # Base PyTorch imports: @@ -41,8 +38,8 @@ from utils.testloss import TestLoss # Model imports from PhysicsNeMo +from physicsnemo.models.transolver import Transolver from physicsnemo.distributed import DistributedManager -from physicsnemo.optim import CombinedOptimizer from physicsnemo.utils import load_checkpoint, save_checkpoint from physicsnemo.utils.logging import PythonLogger, RankZeroLoggingWrapper @@ -57,114 +54,8 @@ prof = Profiler() -_GEOTRANSOLVER_TARGETS = { - "physicsnemo.experimental.models.geotransolver.GeoTransolver", -} - - -def make_model_forward(cfg: DictConfig) -> callable: - """ - Return a forward callable that uses the right keyword arguments for the - configured model. - - GeoTransolver uses (local_embedding, geometry) while Transolver/FLARE - use (fx, embedding). The decision is made once at startup from the Hydra - config, avoiding fragile isinstance checks through DDP/compile wrappers. - - Args: - cfg (DictConfig): Full Hydra config (reads model._target_). 
- - Returns: - callable: ``fn(model, pos, x) -> Tensor`` - """ - if cfg.model._target_ in _GEOTRANSOLVER_TARGETS: - - def _forward(model, pos, x): - combined_inputs = torch.cat([pos, x.unsqueeze(-1)], dim=-1) - return model( - local_embedding=combined_inputs, geometry=combined_inputs - ).squeeze(-1) - - else: - - def _forward(model, pos, x): - return model(embedding=pos, fx=x.unsqueeze(-1)).squeeze(-1) - - return _forward - - -def build_optimizer( - model: torch.nn.Module, - cfg: DictConfig, -) -> torch.optim.Optimizer: - """ - Build optimizer based on config. Supports AdamW and Muon. - - Muon is applied to 2D weight matrices; remaining parameters (biases, norms, - embeddings) are handled by AdamW. When both groups exist they are wrapped in - ``CombinedOptimizer``. - - Args: - model (torch.nn.Module): The model (possibly DDP-wrapped). - cfg (DictConfig): Full Hydra config (reads optimizer.type, scheduler.initial_lr, - scheduler.weight_decay). - - Returns: - torch.optim.Optimizer: The configured optimizer. - """ - opt_type = cfg.optimizer.type - lr = cfg.scheduler.initial_lr - weight_decay = cfg.scheduler.weight_decay - - if opt_type == "adamw": - return AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) - - if opt_type == "muon": - if not hasattr(torch.optim, "Muon"): - raise ImportError( - "Muon optimizer requires PyTorch >= 2.9. " - "Install a newer PyTorch or use optimizer.type=adamw." - ) - base_model = model.module if hasattr(model, "module") else model - muon_params = [p for p in base_model.parameters() if p.ndim == 2] - other_params = [p for p in base_model.parameters() if p.ndim != 2] - - if muon_params and other_params: - return CombinedOptimizer( - [ - torch.optim.Muon( - muon_params, - lr=lr, - weight_decay=weight_decay, - adjust_lr_fn="match_rms_adamw", - ), - AdamW( - other_params, - lr=lr, - weight_decay=weight_decay, - betas=(0.9, 0.999), - eps=1.0e-8, - ), - ] - ) - elif muon_params: - return torch.optim.Muon( - muon_params, - lr=lr, - weight_decay=weight_decay, - adjust_lr_fn="match_rms_adamw", - ) - else: - return AdamW(other_params, lr=lr, weight_decay=weight_decay) - - raise ValueError( - f"Unsupported optimizer type: {opt_type!r}. Use 'adamw' or 'muon'." - ) - - def forward_train_full_loop( model: torch.nn.Module, - model_forward: callable, loss_fun: callable, optimizer: torch.optim.Optimizer, pos: torch.Tensor, @@ -179,7 +70,6 @@ def forward_train_full_loop( Args: model (torch.nn.Module): The model to train. - model_forward (callable): Forward callable from ``make_model_forward``. loss_fun (callable): Loss function. optimizer (torch.optim.Optimizer): Optimizer. pos (torch.Tensor): Position tensor (embedding). @@ -194,7 +84,7 @@ def forward_train_full_loop( """ dm = DistributedManager() with precision_context: - pred = model_forward(model, pos, x) + pred = model(embedding=pos, fx=x.unsqueeze(-1)).squeeze(-1) pred = y_normalizer.decode(pred) loss = loss_fun(pred, y) if scaler is not None: @@ -210,7 +100,6 @@ def forward_train_full_loop( def train_epoch( model: torch.nn.Module, - model_forward: callable, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler._LRScheduler, train_dataloader: DataLoader, @@ -224,7 +113,6 @@ def train_epoch( Args: model (torch.nn.Module): The model to train. - model_forward (callable): Forward callable from ``make_model_forward``. optimizer (torch.optim.Optimizer): Optimizer. scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler. train_dataloader (DataLoader): Training data loader. 
@@ -240,7 +128,6 @@ def train_epoch( pos, x, y = batch loss = forward_train_full_loop( model, - model_forward, loss_fun, optimizer, pos, @@ -263,7 +150,6 @@ def train_epoch( def val_epoch( model: torch.nn.Module, - model_forward: callable, test_dataloader: DataLoader, loss_fun: callable, y_normalizer, @@ -273,7 +159,6 @@ def val_epoch( Args: model (torch.nn.Module): The model to validate. - model_forward (callable): Forward callable from ``make_model_forward``. test_dataloader (DataLoader): Validation data loader. loss_fun (callable): Loss function. y_normalizer: Normalizer for the target tensor. @@ -290,7 +175,7 @@ def val_epoch( for i, batch in enumerate(test_dataloader): pos, x, y = batch with torch.no_grad(): - pred = model_forward(model, pos, x) + pred = model(embedding=pos, fx=x.unsqueeze(-1)).squeeze(-1) pred = y_normalizer.decode(pred) loss = loss_fun(pred, y) @@ -342,20 +227,12 @@ def darcy_trainer(cfg: DictConfig) -> None: logger = RankZeroLoggingWrapper(PythonLogger(name="darcy_transolver"), dm) logger.file_logging() - # === TensorBoard SummaryWriters === - # Separate train/val writers so TensorBoard can overlay matching scalars - train_writer = None - val_writer = None - metrics_file = None + # === TensorBoard SummaryWriter === + # Only rank 0 writes logs to avoid duplication in DDP + writer = None if dm.rank == 0: log_dir = f"{cfg.output_dir}/runs/{cfg.run_id}" - train_writer = SummaryWriter(log_dir=f"{log_dir}/train") - val_writer = SummaryWriter(log_dir=f"{log_dir}/val") - - # === JSONL metrics log (append-safe for resumed runs) === - metrics_path = os.path.join(log_dir, "metrics.jsonl") - os.makedirs(log_dir, exist_ok=True) - metrics_file = open(metrics_path, "a") + writer = SummaryWriter(log_dir=log_dir) ######################################################################## # Print the configuration to log @@ -365,8 +242,23 @@ def darcy_trainer(cfg: DictConfig) -> None: ######################################################################## # define model ######################################################################## - model = instantiate(cfg.model).to(dm.device) - model_forward = make_model_forward(cfg) + model = Transolver( + functional_dim=cfg.model.functional_dim, + out_dim=cfg.model.out_dim, + embedding_dim=cfg.model.embedding_dim, + n_layers=cfg.model.n_layers, + n_hidden=cfg.model.n_hidden, + dropout=cfg.model.dropout, + n_head=cfg.model.n_head, + act=cfg.model.act, + mlp_ratio=cfg.model.mlp_ratio, + slice_num=cfg.model.slice_num, + unified_pos=cfg.model.unified_pos, + ref=cfg.model.ref, + structured_shape=[cfg.data.resolution, cfg.data.resolution], + use_te=cfg.model.use_te, + time_input=cfg.model.time_input, + ).to(dm.device) logger.info(f"\n{torchinfo.summary(model, verbose=0)}") @@ -377,7 +269,11 @@ def darcy_trainer(cfg: DictConfig) -> None: # define loss and optimizer ######################################################################## loss_fun = TestLoss(size_average=True) - optimizer = build_optimizer(model, cfg) + optimizer = AdamW( + model.parameters(), + lr=cfg.scheduler.initial_lr, + weight_decay=cfg.scheduler.weight_decay, + ) ######################################################################## # Create the data pipes and samplers @@ -427,30 +323,12 @@ def darcy_trainer(cfg: DictConfig) -> None: cfg.training.pseudo_epoch_sample_size / cfg.data.batch_size ) - total_steps = steps_per_pseudo_epoch * cfg.training.max_pseudo_epochs - if cfg.optimizer.type == "muon": - warmup_steps = steps_per_pseudo_epoch * 2 - scheduler = 
lr_scheduler.SequentialLR( - optimizer, - schedulers=[ - lr_scheduler.LinearLR( - optimizer, start_factor=1e-2, total_iters=warmup_steps - ), - lr_scheduler.CosineAnnealingLR( - optimizer, - T_max=total_steps - warmup_steps, - eta_min=cfg.scheduler.initial_lr * 0.1, - ), - ], - milestones=[warmup_steps], - ) - else: - scheduler = lr_scheduler.OneCycleLR( - optimizer, - max_lr=cfg.scheduler.initial_lr, - steps_per_epoch=steps_per_pseudo_epoch, - epochs=cfg.training.max_pseudo_epochs, - ) + scheduler = lr_scheduler.OneCycleLR( + optimizer, + max_lr=cfg.scheduler.initial_lr, + steps_per_epoch=steps_per_pseudo_epoch, + epochs=cfg.training.max_pseudo_epochs, + ) validator = GridValidator(output_dir=f"{cfg.output_dir}/runs/{cfg.run_id}/plots") @@ -462,9 +340,6 @@ def darcy_trainer(cfg: DictConfig) -> None: } loaded_pseudo_epoch = load_checkpoint(device=dm.device, **ckpt_args) - # Compile after checkpoint loading to avoid triggering recompilation - model = torch.compile(model) - validation_iters = ceil(cfg.validation.sample_size / cfg.data.batch_size) if cfg.training.pseudo_epoch_sample_size % cfg.data.batch_size != 0: @@ -506,7 +381,6 @@ def darcy_trainer(cfg: DictConfig) -> None: train_start = time.time() loss = train_epoch( model, - model_forward, optimizer, scheduler, train_dataloader, @@ -532,7 +406,7 @@ def darcy_trainer(cfg: DictConfig) -> None: logger.info(log_string) # --- TensorBoard logging (only on rank 0) --- - if dm.rank == 0 and train_writer is not None: + if dm.rank == 0 and writer is not None: # Images/sec/GPU: (num images processed in train_epoch) / train_time / num_gpus # Each batch processes batch_size // world_size images, for steps_per_pseudo_epoch steps images_per_epoch = len(train_dataloader) * ( @@ -540,28 +414,12 @@ def darcy_trainer(cfg: DictConfig) -> None: ) images_per_sec_per_gpu = images_per_epoch / train_time - train_writer.add_scalar("loss", loss.item(), pseudo_epoch) - train_writer.add_scalar("time_per_epoch", train_time, pseudo_epoch) - train_writer.add_scalar( - "images_per_sec_per_gpu", images_per_sec_per_gpu, pseudo_epoch - ) - train_writer.add_scalar("learning_rate", lr, pseudo_epoch) - - # --- JSONL metrics record (training fields) --- - metrics_record = None - if dm.rank == 0 and metrics_file is not None: - images_per_epoch = len(train_dataloader) * ( - cfg.data.batch_size // dm.world_size + writer.add_scalar("loss/train", loss.item(), pseudo_epoch) + writer.add_scalar("time_per_epoch/train", train_time, pseudo_epoch) + writer.add_scalar( + "images_per_sec_per_gpu/train", images_per_sec_per_gpu, pseudo_epoch ) - metrics_record = { - "pseudo_epoch": pseudo_epoch, - "timestamp": datetime.now(timezone.utc).isoformat(), - "train_loss": loss.item(), - "train_time_s": train_time, - "learning_rate": lr, - "images_per_sec_per_gpu": images_per_epoch / train_time, - "gpu_mem_reserved_gb": gpu_mem_reserved, - } + writer.add_scalar("learning_rate/train", lr, pseudo_epoch) # save checkpoint if pseudo_epoch % cfg.training.rec_results_freq == 0 and dm.rank == 0: @@ -571,7 +429,7 @@ def darcy_trainer(cfg: DictConfig) -> None: if pseudo_epoch % cfg.validation.validation_pseudo_epochs == 0: val_start = time.time() val_loss, pred, y, RL2 = val_epoch( - model, model_forward, test_dataloader, loss_fun, y_normalizer + model, test_dataloader, loss_fun, y_normalizer ) val_time = time.time() - val_start @@ -582,49 +440,29 @@ def darcy_trainer(cfg: DictConfig) -> None: logger.info(log_string) # --- TensorBoard logging (only on rank 0) --- - if dm.rank == 0 and val_writer is not 
None: + if dm.rank == 0 and writer is not None: # Validation images/sec/GPU val_images = validation_iters * ( cfg.data.batch_size // dm.world_size ) val_images_per_sec_per_gpu = val_images / val_time - val_writer.add_scalar("loss", val_loss.item(), pseudo_epoch) - val_writer.add_scalar("RL2", RL2.item(), pseudo_epoch) - val_writer.add_scalar("time_per_epoch", val_time, pseudo_epoch) - val_writer.add_scalar( - "images_per_sec_per_gpu", + writer.add_scalar("loss/val", val_loss.item(), pseudo_epoch) + writer.add_scalar("RL2/val", RL2.item(), pseudo_epoch) + writer.add_scalar("time_per_epoch/val", val_time, pseudo_epoch) + writer.add_scalar( + "images_per_sec_per_gpu/val", val_images_per_sec_per_gpu, pseudo_epoch, ) - # --- JSONL metrics record (validation fields) --- - if metrics_record is not None: - val_images = validation_iters * ( - cfg.data.batch_size // dm.world_size - ) - metrics_record["val_loss"] = val_loss.item() - metrics_record["val_rl2"] = RL2.item() - metrics_record["val_time_s"] = val_time - metrics_record["val_images_per_sec_per_gpu"] = val_images / val_time - - if dm.rank == 0 and cfg.validation.save_plots: + if dm.rank == 0: validator.make_plot(pred, y, pseudo_epoch, test_datapipe.s) - # --- Flush JSONL record for this pseudo-epoch --- - if metrics_record is not None: - metrics_file.write(json.dumps(metrics_record) + "\n") - metrics_file.flush() - # update learning rate # if pseudo_epoch % cfg.scheduler.decay_pseudo_epochs == 0: - if dm.rank == 0: - if train_writer is not None: - train_writer.close() - if val_writer is not None: - val_writer.close() - if metrics_file is not None: - metrics_file.close() + if dm.rank == 0 and writer is not None: + writer.close() logger.success("Training completed *yay*") diff --git a/physicsnemo/experimental/models/geotransolver/__init__.py b/physicsnemo/experimental/models/geotransolver/__init__.py index 3fed9686b8..21638bdc6e 100644 --- a/physicsnemo/experimental/models/geotransolver/__init__.py +++ b/physicsnemo/experimental/models/geotransolver/__init__.py @@ -23,36 +23,15 @@ ------- GeoTransolver Main model class combining GALE attention with geometry and global context. -GeoTransolverMetaData - Data class for storing essential meta data needed for the GeoTransolver model. GALE Geometry-Aware Latent Embeddings attention layer. -GALE_FA - GALE with FLARE self-attention backend. GALE_block - Transformer block using GALE or GALE_FA attention. -GALEStructuredMesh2D - GALE with Conv2d slice projection for 2D structured grids. -GALEStructuredMesh3D - GALE with Conv3d slice projection for 3D structured grids. + Transformer block using GALE attention. ContextProjector Projects context features onto physical state slices. -StructuredContextProjector - Context projector with Conv2d/Conv3d geometry encoding on structured grids. -GeometricFeatureProcessor - Processes geometric features at a single spatial scale using BQWarp. -MultiScaleFeatureExtractor - Multi-scale geometric feature extraction over multiple radii. GlobalContextBuilder Orchestrates context construction for the model. -Functions ---------- -collect_concrete_dropout_losses - Collect concrete dropout regularization losses from a model. -get_concrete_dropout_rates - Get concrete dropout rates from a model. 
- Examples -------- Basic usage: @@ -77,35 +56,17 @@ collect_concrete_dropout_losses, get_concrete_dropout_rates, ) -from .context_projector import ( - ContextProjector, - GeometricFeatureProcessor, - GlobalContextBuilder, - MultiScaleFeatureExtractor, - StructuredContextProjector, -) -from .gale import ( - GALE, - GALE_FA, - GALE_block, - GALEStructuredMesh2D, - GALEStructuredMesh3D, -) +from .context_projector import ContextProjector, GlobalContextBuilder +from .gale import GALE, GALE_block from .geotransolver import GeoTransolver, GeoTransolverMetaData __all__ = [ "GeoTransolver", "GeoTransolverMetaData", "GALE", - "GALE_FA", "GALE_block", - "GALEStructuredMesh2D", - "GALEStructuredMesh3D", "ContextProjector", - "GeometricFeatureProcessor", "GlobalContextBuilder", - "MultiScaleFeatureExtractor", - "StructuredContextProjector", "ConcreteDropout", "collect_concrete_dropout_losses", "get_concrete_dropout_rates", diff --git a/physicsnemo/experimental/models/geotransolver/context_projector.py b/physicsnemo/experimental/models/geotransolver/context_projector.py index 15f7d67ffc..a57b7ecfc1 100644 --- a/physicsnemo/experimental/models/geotransolver/context_projector.py +++ b/physicsnemo/experimental/models/geotransolver/context_projector.py @@ -23,8 +23,6 @@ ------- ContextProjector Projects context features onto physical state slices. -StructuredContextProjector - Context projector with Conv2d/Conv3d geometry encoding on structured grids. GeometricFeatureProcessor Processes geometric features at a single spatial scale using BQWarp. MultiScaleFeatureExtractor @@ -41,12 +39,9 @@ from jaxtyping import Float from physicsnemo.core.version_check import check_version_spec +from physicsnemo.nn import gumbel_softmax from physicsnemo.nn import BQWarp from physicsnemo.nn import Mlp -from physicsnemo.nn.module.physics_attention import ( - _compute_slices_from_projections, - _project_input, -) from physicsnemo.nn import ConcreteDropout @@ -56,148 +51,7 @@ import transformer_engine.pytorch as te -def _structured_grid_to_conv_input( - x: Float[torch.Tensor, "batch tokens channels"], - batch: int, - tokens: int, - channels: int, - ndim: int, - spatial_shape: tuple[int, ...], -) -> Float[torch.Tensor, "batch channels ..."]: - r"""Reshape flat token tensor to spatial layout for Conv2d/Conv3d. - - Converts :math:`(B, N, C)` to :math:`(B, C, H, W)` for 2D or - :math:`(B, C, H, W, D)` for 3D so that structured context projectors - can apply spatial convolutions. Validates that :math:`N` matches the - grid size. - - Parameters - ---------- - x : torch.Tensor - Input tensor of shape :math:`(B, N, C)` (batch, tokens, channels). - batch : int - Batch size :math:`B`. - tokens : int - Number of tokens :math:`N` (must equal :math:`H \\times W` or - :math:`H \\times W \\times D`). - channels : int - Channel dimension :math:`C`. - ndim : int - Number of spatial dimensions; must be 2 or 3. - spatial_shape : tuple[int, ...] - :math:`(H, W)` for 2D or :math:`(H, W, D)` for 3D. - - Returns - ------- - torch.Tensor - Reshaped tensor of shape :math:`(B, C, H, W)` or - :math:`(B, C, H, W, D)` for use as conv input. - - Raises - ------ - ValueError - If ``tokens`` does not match the product of ``spatial_shape``. 
- """ - if ndim == 2: - H, W = spatial_shape - if tokens != H * W: - raise ValueError( - f"Expected N={H * W} tokens for 2D grid, got N={tokens}" - ) - return x.view(batch, H, W, channels).permute(0, 3, 1, 2) - H, W, D = spatial_shape - if tokens != H * W * D: - raise ValueError( - f"Expected N={H * W * D} tokens for 3D grid, got N={tokens}" - ) - return x.view(batch, H, W, D, channels).permute(0, 4, 1, 2, 3) - - -class _SliceToContextMixin: - r"""Internal mixin providing shared slice-to-context init and slice aggregation. - - Used by :class:`ContextProjector` and :class:`StructuredContextProjector` to - avoid duplicating ``in_project_slice``, ``temperature``, ``proj_temperature``, - and the call to - :func:`~physicsnemo.nn.module.physics_attention._compute_slices_from_projections`. - """ - - def _init_slice_components( - self, - dim_head: int, - slice_num: int, - heads: int, - use_te: bool, - plus: bool, - ) -> None: - r"""Initialize slice projection, temperature, and optional adaptive temperature. - - Sets ``in_project_slice``, ``temperature``, and (when ``plus`` is True) - ``proj_temperature`` on this instance. Uses Transformer Engine linear - when ``use_te`` is True and TE is available. - - Parameters - ---------- - dim_head : int - Head dimension for the slice projection input. - slice_num : int - Number of slices (output dimension of ``in_project_slice``). - heads : int - Number of heads (used for temperature shape). - use_te : bool - Whether to prefer Transformer Engine for linear layers. - plus : bool - If True, add ``proj_temperature`` for Transolver++. - """ - linear_layer = te.Linear if (use_te and TE_AVAILABLE) else nn.Linear - self.in_project_slice = linear_layer(dim_head, slice_num) - self.temperature = nn.Parameter(torch.ones([1, 1, heads, 1]) * 0.5) - if plus: - self.proj_temperature = nn.Sequential( - linear_layer(dim_head, slice_num), - nn.GELU(), - linear_layer(slice_num, 1), - nn.GELU(), - ) - - def _compute_slices( - self, - slice_projections: Float[torch.Tensor, "batch tokens heads slices"], - fx: Float[torch.Tensor, "batch tokens heads dim"], - ) -> tuple[ - Float[torch.Tensor, "batch tokens heads slices"], - Float[torch.Tensor, "batch heads slices dim"], - ]: - r"""Compute slice weights and slice tokens from projections and latent features. - - Delegates to :func:`~physicsnemo.nn.module.physics_attention._compute_slices_from_projections`, - the shared free function that also backs - :meth:`~physicsnemo.nn.module.physics_attention.PhysicsAttentionBase._compute_slices_from_projections`. - - Parameters - ---------- - slice_projections : torch.Tensor - Shape :math:`(B, N, H, S)`. - fx : torch.Tensor - Shape :math:`(B, N, H, D)`. - - Returns - ------- - tuple[torch.Tensor, torch.Tensor] - ``(slice_weights, slice_token)`` with shapes :math:`(B, N, H, S)` - and :math:`(B, H, S, D)`. - """ - proj_temp = getattr(self, "proj_temperature", None) if self.plus else None - return _compute_slices_from_projections( - slice_projections, - fx, - self.temperature, - self.plus, - proj_temperature=proj_temp, - ) - - -class ContextProjector(_SliceToContextMixin, nn.Module): +class ContextProjector(nn.Module): r"""Projects context features onto physical state space. This context projector is conceptually similar to half of a GALE attention layer. 
@@ -284,8 +138,19 @@ def __init__( # Attention components self.softmax = nn.Softmax(dim=-1) + self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) + + # Transolver++ adaptive temperature projection + if plus: + self.proj_temperature = nn.Sequential( + linear_layer(self.dim_head, slice_num), + nn.GELU(), + linear_layer(slice_num, 1), + nn.GELU(), + ) - self._init_slice_components(dim_head, slice_num, heads, use_te, plus) + # Slice projection layer maps from head dimension to slice space + self.in_project_slice = linear_layer(dim_head, slice_num) # Concrete dropout on the output slice tokens if concrete_dropout: @@ -299,10 +164,10 @@ def __init__( def project_input_onto_slices( self, x: Float[torch.Tensor, "batch tokens channels"] ) -> ( - Float[torch.Tensor, "batch tokens heads dim"] + Float[torch.Tensor, "batch heads tokens dim"] | tuple[ - Float[torch.Tensor, "batch tokens heads dim"], - Float[torch.Tensor, "batch tokens heads dim"], + Float[torch.Tensor, "batch heads tokens dim"], + Float[torch.Tensor, "batch heads tokens dim"], ] ): r"""Project the input onto the slice space. @@ -316,17 +181,95 @@ def project_input_onto_slices( Returns ------- torch.Tensor or tuple[torch.Tensor, torch.Tensor] - If ``plus=True``, returns single tensor of shape :math:`(B, N, H, D)` where + If ``plus=True``, returns single tensor of shape :math:`(B, H, N, D)` where :math:`H` is number of heads and :math:`D` is head dimension. If ``plus=False``, - returns tuple of two tensors both of shape :math:`(B, N, H, D)`, representing + returns tuple of two tensors both of shape :math:`(B, H, N, D)`, representing the query and key projections respectively. """ - fx = None if self.plus else self.in_project_fx - return _project_input( - x, self.in_project_x, self.heads, self.dim_head, - "B N (H D) -> B N H D", project_fx=fx, + # Project input to multi-head representation: (B, N, C) -> (B, H, N, D) + projected_x = rearrange( + self.in_project_x(x), "B N (h d) -> B h N d", h=self.heads, d=self.dim_head ) + if self.plus: + # Transolver++ uses single projection for both paths + return projected_x + else: + # Standard Transolver uses separate query and key projections + feature_projection = rearrange( + self.in_project_fx(x), + "B N (h d) -> B h N d", + h=self.heads, + d=self.dim_head, + ) + return projected_x, feature_projection + + def compute_slices_from_projections( + self, + slice_projections: Float[torch.Tensor, "batch heads tokens slices"], + fx: Float[torch.Tensor, "batch heads tokens dim"], + ) -> tuple[ + Float[torch.Tensor, "batch heads tokens slices"], + Float[torch.Tensor, "batch heads slices dim"], + ]: + r"""Compute slice weights and slice tokens from input projections and latent features. + + Parameters + ---------- + slice_projections : torch.Tensor + Projected input tensor of shape :math:`(B, H, N, S)` where :math:`B` is batch size, + :math:`H` is number of heads, :math:`N` is number of tokens, and :math:`S` is number of + slices, representing the projection of each token onto each slice for each + attention head. + fx : torch.Tensor + Latent feature tensor of shape :math:`(B, H, N, D)` where :math:`D` is head dimension, + representing the learned states to be aggregated by the slice weights. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + - ``slice_weights``: Tensor of shape :math:`(B, H, N, S)`, normalized weights for + each slice per token and head. + - ``slice_token``: Tensor of shape :math:`(B, H, S, D)`, aggregated latent features + for each slice, head, and batch. 
+ + Notes + ----- + The function computes a temperature-scaled softmax over the slice projections to + obtain slice weights, then aggregates the latent features for each slice using + these weights. The aggregated features are normalized by the sum of weights for + numerical stability. + """ + # Compute temperature-adjusted softmax weights + if self.plus: + # Transolver++ uses adaptive temperature with Gumbel softmax + temperature = self.temperature + self.proj_temperature(fx) + clamped_temp = torch.clamp(temperature, min=0.01).to( + slice_projections.dtype + ) + slice_weights = gumbel_softmax(slice_projections, clamped_temp) + else: + # Standard Transolver uses fixed temperature with regular softmax + clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to( + slice_projections.dtype + ) + slice_weights = nn.functional.softmax( + slice_projections / clamped_temp, dim=-1 + ) + + # Ensure weights match the computation dtype + slice_weights = slice_weights.to(slice_projections.dtype) + + # Aggregate features by slice weights with normalization + # Normalize first to prevent overflow in reduced precision + slice_norm = slice_weights.sum(2) # Sum over tokens: (B, H, S) + normed_weights = slice_weights / (slice_norm[:, :, None, :] + 1e-2) + + # Weighted aggregation: (B, H, S, N) @ (B, H, N, D) -> (B, H, S, D) + slice_token = torch.matmul(normed_weights.transpose(2, 3), fx) + + return slice_weights, slice_token + def forward( self, x: Float[torch.Tensor, "batch tokens channels"] ) -> Float[torch.Tensor, "batch heads slices dim"]: @@ -370,120 +313,11 @@ def forward( else: projected_x, feature_projection = self.project_input_onto_slices(x) - # Project latent representations onto physical state slices: (B, N, H, D) -> (B, N, H, S) + # Project latent representations onto physical state slices: (B, H, N, D) -> (B, H, N, S) slice_projections = self.in_project_slice(projected_x) # Compute weighted aggregation of features into slice tokens - _, slice_tokens = self._compute_slices( - slice_projections, feature_projection - ) - - # Apply concrete dropout to output slice tokens - if self.output_dropout is not None: - slice_tokens = self.output_dropout(slice_tokens) - - return slice_tokens - - -class StructuredContextProjector(_SliceToContextMixin, nn.Module): - r"""Context projector with Conv2d/Conv3d geometry encoding on structured grids. - - Same output interface as :class:`ContextProjector`—slice tokens - :math:`(B, H, S, D)`—but projects per-cell geometry via spatial convolutions - aligned with structured GALE attention. 
- """ - - def __init__( - self, - dim: int, - spatial_shape: tuple[int, ...], - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - slice_num: int = 64, - kernel: int = 3, - use_te: bool = True, - plus: bool = False, - concrete_dropout: bool = False, - ) -> None: - super().__init__() - if len(spatial_shape) not in (2, 3): - raise ValueError( - f"StructuredContextProjector expects spatial_shape of length 2 or 3, got {spatial_shape!r}" - ) - inner_dim = dim_head * heads - self.dim_head = dim_head - self.heads = heads - self.plus = plus - self.use_te = use_te - self.spatial_shape = tuple(int(s) for s in spatial_shape) - self._nd = len(self.spatial_shape) - pad = kernel // 2 - if self._nd == 2: - H, W = self.spatial_shape - self.H, self.W = H, W - self.in_project_x = nn.Conv2d(dim, inner_dim, kernel, 1, pad) - if not plus: - self.in_project_fx = nn.Conv2d(dim, inner_dim, kernel, 1, pad) - else: - H, W, D_ = self.spatial_shape - self.H, self.W, self.D = H, W, D_ - self.in_project_x = nn.Conv3d(dim, inner_dim, kernel, 1, pad) - if not plus: - self.in_project_fx = nn.Conv3d(dim, inner_dim, kernel, 1, pad) - - self.softmax = nn.Softmax(dim=-1) - self.dropout = nn.Dropout(dropout) - self._init_slice_components(dim_head, slice_num, heads, use_te, plus) - - # Concrete dropout on the output slice tokens - if concrete_dropout: - self.output_dropout = ConcreteDropout( - in_features=dim_head, - init_p=max(dropout, 0.05), - ) - else: - self.output_dropout = None - - def _grid_project( - self, x: Float[torch.Tensor, "batch tokens channels"] - ) -> ( - Float[torch.Tensor, "batch tokens heads dim"] - | tuple[ - Float[torch.Tensor, "batch tokens heads dim"], - Float[torch.Tensor, "batch tokens heads dim"], - ] - ): - B, N, C = x.shape - grid = _structured_grid_to_conv_input( - x, B, N, C, self._nd, self.spatial_shape - ) - pattern = ( - "B (H D) h w -> B (h w) H D" - if self._nd == 2 - else "B (H D) h w d -> B (h w d) H D" - ) - fx = None if self.plus else self.in_project_fx - return _project_input( - grid, self.in_project_x, self.heads, self.dim_head, - pattern, project_fx=fx, - ) - - def forward( - self, x: Float[torch.Tensor, "batch tokens channels"] - ) -> Float[torch.Tensor, "batch heads slices dim"]: - if not torch.compiler.is_compiling(): - if x.ndim != 3: - raise ValueError( - f"Expected 3D input (B, N, C), got {x.ndim}D shape {tuple(x.shape)}" - ) - if self.plus: - projected_x = self._grid_project(x) - feature_projection = projected_x - else: - projected_x, feature_projection = self._grid_project(x) - slice_projections = self.in_project_slice(projected_x) - _, slice_tokens = self._compute_slices( + _, slice_tokens = self.compute_slices_from_projections( slice_projections, feature_projection ) @@ -805,10 +639,6 @@ class GlobalContextBuilder(nn.Module): Whether to use Transolver++ features. Default is ``False``. include_local_features : bool, optional Enable local feature extraction. Default is ``False``. - structured_shape : tuple[int, ...] | None, optional - If set, disables ball-query extractors and uses - :class:`StructuredContextProjector` for geometry when ``geometry_dim`` - is set. Default is ``None``. Forward ------- @@ -857,7 +687,6 @@ def __init__( use_te: bool = True, plus: bool = False, include_local_features: bool = False, - structured_shape: tuple[int, ...] 
| None = None, concrete_dropout: bool = False, ) -> None: super().__init__() @@ -870,17 +699,9 @@ def __init__( dim_head = n_hidden // n_head context_dim = 0 - self.structured_shape = structured_shape - - # Ball-query local features are not used on structured grids - use_local_bq = ( - geometry_dim is not None - and include_local_features - and structured_shape is None - ) # Multi-scale extractors for local features (one per functional dim) - if use_local_bq: + if geometry_dim is not None and include_local_features: self.local_extractors = nn.ModuleList( [ MultiScaleFeatureExtractor( @@ -905,23 +726,16 @@ def __init__( # Geometry tokenizer for global geometry context if geometry_dim is not None: - if structured_shape is not None: - self.geometry_tokenizer = StructuredContextProjector( - geometry_dim, - structured_shape, - n_head, - dim_head, - dropout, - slice_num, - use_te=use_te, - plus=plus, - concrete_dropout=concrete_dropout, - ) - else: - self.geometry_tokenizer = ContextProjector( - geometry_dim, n_head, dim_head, dropout, slice_num, use_te, plus=plus, - concrete_dropout=concrete_dropout, - ) + self.geometry_tokenizer = ContextProjector( + geometry_dim, + n_head, + dim_head, + dropout, + slice_num, + use_te, + plus, + concrete_dropout=concrete_dropout, + ) context_dim += dim_head else: self.geometry_tokenizer = None diff --git a/physicsnemo/experimental/models/geotransolver/gale.py b/physicsnemo/experimental/models/geotransolver/gale.py index 8074c958eb..e87eed41b2 100644 --- a/physicsnemo/experimental/models/geotransolver/gale.py +++ b/physicsnemo/experimental/models/geotransolver/gale.py @@ -25,7 +25,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from einops import rearrange from jaxtyping import Float @@ -34,12 +33,9 @@ from physicsnemo.nn import Mlp from physicsnemo.nn.module.physics_attention import ( PhysicsAttentionIrregularMesh, - PhysicsAttentionStructuredMesh2D, - PhysicsAttentionStructuredMesh3D, - _project_input, ) -from physicsnemo.experimental.nn.flare_attention import _flare_self_attention +from physicsnemo.experimental.models.geotransolver.gale_fa import GALE_FA from physicsnemo.nn import ConcreteDropout # Check optional dependency availability @@ -47,191 +43,6 @@ te = OptionalImport("transformer_engine.pytorch", "0.1.0") -def _mix_self_and_cross( - self_attn: torch.Tensor, - cross_attn: torch.Tensor, - mode: str, - state_mixing: nn.Parameter | None = None, - concat_project: nn.Module | None = None, -) -> torch.Tensor: - r"""Blend self-attention and cross-attention outputs. - - Parameters - ---------- - self_attn : torch.Tensor - Self-attention output. - cross_attn : torch.Tensor - Cross-attention output (same shape as ``self_attn``). - mode : str - ``"weighted"`` for sigmoid-gated sum, ``"concat_project"`` for - concatenation followed by a learned projection. - state_mixing : nn.Parameter or None - Learnable scalar for ``"weighted"`` mode. - concat_project : nn.Module or None - Projection module for ``"concat_project"`` mode. - - Returns - ------- - torch.Tensor - Blended output, same shape as inputs. 
- """ - match mode: - case "weighted": - w = torch.sigmoid(state_mixing) - return w * self_attn + (1 - w) * cross_attn - case "concat_project": - return concat_project(torch.cat([self_attn, cross_attn], dim=-1)) - case _: - raise ValueError(f"Invalid state_mixing_mode: {mode!r}") - - -def _gale_compute_slice_attention_cross( - module: nn.Module, - slice_tokens: list[Float[torch.Tensor, "batch heads slices dim"]], - context: Float[torch.Tensor, "batch heads context_slices context_dim"], -) -> list[Float[torch.Tensor, "batch heads slices dim"]]: - r"""Shared cross-attention between slice tokens and context. - - Used by :class:`GALE` and :class:`_GALEStructuredForwardMixin` so the - cross-attention implementation lives in one place. Projects queries from - concatenated slice tokens, keys and values from context; runs Transformer - Engine or SDPA attention; splits the result back to one tensor per input. - - Parameters - ---------- - module : nn.Module - Module with ``cross_q``, ``cross_k``, ``cross_v``, ``use_te``, - ``heads``, ``dim_head``, and (if ``use_te``) ``attn_fn``. - slice_tokens : list[torch.Tensor] - One tensor per input, each of shape :math:`(B, H, S, D)`. - context : torch.Tensor - Context tensor of shape :math:`(B, H, S_c, D_c)`. - - Returns - ------- - list[torch.Tensor] - One cross-attention output per element of ``slice_tokens``, each - of shape :math:`(B, H, S, D)`. - """ - q_input = torch.cat(slice_tokens, dim=-2) - q = module.cross_q(q_input) - k = module.cross_k(context) - v = module.cross_v(context) - if module.use_te: - q = rearrange(q, "b h s d -> b s h d") - k = rearrange(k, "b h s d -> b s h d") - v = rearrange(v, "b h s d -> b s h d") - cross_attention = module.attn_fn(q, k, v) - cross_attention = rearrange( - cross_attention, - "b s (h d) -> b h s d", - h=module.heads, - d=module.dim_head, - ) - else: - cross_attention = torch.nn.functional.scaled_dot_product_attention( - q, k, v, is_causal=False - ) - cross_attention = torch.split( - cross_attention, slice_tokens[0].shape[-2], dim=-2 - ) - return list(cross_attention) - - -def _gale_forward_impl( - module: nn.Module, - x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], - context: Float[torch.Tensor, "batch heads context_slices context_dim"] - | None, -) -> list[Float[torch.Tensor, "batch tokens channels"]]: - r"""Single implementation of the GALE forward pipeline. - - Shared by :class:`GALE` and :class:`_GALEStructuredForwardMixin`. Steps: - validate inputs; project onto slices; compute slice weights and tokens; - apply self-attention on slices; optionally cross-attend to context and - mix with ``state_mixing``; project attention outputs back to token space. - - Parameters - ---------- - module : nn.Module - GALE-like module with ``project_input_onto_slices``, - ``in_project_slice``, ``_compute_slices_from_projections``, - ``_compute_slice_attention_te``, ``_compute_slice_attention_sdpa``, - ``compute_slice_attention_cross``, ``_project_attention_outputs``, - plus attributes ``use_te``, ``plus``, ``state_mixing_mode``, and - ``state_mixing`` (if weighted) or ``concat_project`` (if concat). - x : tuple[torch.Tensor, ...] - Input tensors, each of shape :math:`(B, N, C)`; must be non-empty. - context : torch.Tensor or None - Optional context of shape :math:`(B, H, S_c, D_c)` for cross-attention. - If ``None``, only self-attention is applied. - - Returns - ------- - list[torch.Tensor] - One output tensor per input, each of shape :math:`(B, N, C)`. 
- - Raises - ------ - ValueError - If ``x`` is empty or any element is not 3D. - """ - if not torch.compiler.is_compiling(): - if len(x) == 0: - raise ValueError("Expected non-empty tuple of input tensors") - for i, tensor in enumerate(x): - if tensor.ndim != 3: - raise ValueError( - f"Expected 3D input tensor (B, N, C) at index {i}, " - f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}" - ) - if module.plus: - x_mid = [module.project_input_onto_slices(_x) for _x in x] - fx_mid = [_x_mid for _x_mid in x_mid] - else: - x_mid, fx_mid = zip( - *[module.project_input_onto_slices(_x) for _x in x] - ) - slice_projections = [module.in_project_slice(_x_mid) for _x_mid in x_mid] - slice_weights, slice_tokens = zip( - *[ - module._compute_slices_from_projections(proj, _fx_mid) - for proj, _fx_mid in zip(slice_projections, fx_mid) - ] - ) - if module.use_te: - self_slice_token = [ - module._compute_slice_attention_te(_slice_token) - for _slice_token in slice_tokens - ] - else: - self_slice_token = [ - module._compute_slice_attention_sdpa(_slice_token) - for _slice_token in slice_tokens - ] - if context is not None: - cross_slice_token = [ - module.compute_slice_attention_cross([_slice_token], context)[0] - for _slice_token in slice_tokens - ] - out_slice_token = [ - _mix_self_and_cross( - sst, cst, module.state_mixing_mode, - state_mixing=getattr(module, "state_mixing", None), - concat_project=getattr(module, "concat_project", None), - ) - for sst, cst in zip(self_slice_token, cross_slice_token) - ] - else: - # Use only self-attention when no context is provided - out_slice_token = self_slice_token - outputs = [ - module._project_attention_outputs(ost, sw) - for ost, sw in zip(out_slice_token, slice_weights) - ] - return outputs - - class GALE(PhysicsAttentionIrregularMesh): r"""Geometry-Aware Latent Embeddings (GALE) attention layer. @@ -258,10 +69,8 @@ class GALE(PhysicsAttentionIrregularMesh): Whether to use Transolver++ features. Default is False. context_dim : int, optional Dimension of the context vector for cross-attention. Default is 0. - concrete_dropout : bool, optional - Whether to use ConcreteDropout instead of standard dropout. Default is False. state_mixing_mode : str, optional - How to blend self-attention and cross-attention outputs. ``"weighted"`` uses + How to blend self-attention and cross-attention outputs. ``"weighted"`` uses a learnable sigmoid-gated weighted sum. ``"concat_project"`` concatenates the two along the head dimension and projects back with a linear layer. Default is ``"weighted"``. @@ -320,7 +129,32 @@ def __init__( state_mixing_mode: str = "weighted", ) -> None: super().__init__(dim, heads, dim_head, dropout, slice_num, use_te, plus) - _gale_cross_init(self, dim_head, context_dim, use_te, state_mixing_mode) + + if state_mixing_mode not in ("weighted", "concat_project"): + raise ValueError( + f"Invalid state_mixing_mode: {state_mixing_mode!r}. " + f"Expected 'weighted' or 'concat_project'." 
+ ) + self.state_mixing_mode = state_mixing_mode + + linear_layer = te.Linear if self.use_te else nn.Linear + + # Cross-attention projection layers for context integration + self.cross_q = linear_layer(dim_head, dim_head) + self.cross_k = linear_layer(context_dim, dim_head) + self.cross_v = linear_layer(context_dim, dim_head) + + # Mixing layers for blending self-attention and cross-attention + if state_mixing_mode == "weighted": + # Learnable mixing weight between self and cross attention + # Initialize near 0.0 since sigmoid(0) = 0.5, giving balanced initial mixing + self.state_mixing = nn.Parameter(torch.tensor(0.0)) + else: + # Concatenate self and cross attention and project back to dim_head + self.concat_project = nn.Sequential( + linear_layer(2 * dim_head, dim_head), + nn.GELU(), + ) # Replace inherited out_dropout with ConcreteDropout when enabled if concrete_dropout: @@ -351,10 +185,39 @@ def compute_slice_attention_cross( list[torch.Tensor] List of cross-attention outputs, each of shape :math:`(B, H, S, D)`. """ - return _gale_compute_slice_attention_cross( - self, slice_tokens, context + # Concatenate all slice tokens for batched projection + q_input = torch.cat(slice_tokens, dim=-2) # (B, H, total_slices, D) + + # Project queries from slice tokens + q = self.cross_q(q_input) # (B, H, total_slices, D) + + # Project keys and values from context + k = self.cross_k(context) # (B, H, S_c, D) + v = self.cross_v(context) # (B, H, S_c, D) + + # Compute cross-attention using appropriate backend + if self.use_te: + # Transformer Engine expects (B, S, H, D) format + q = rearrange(q, "b h s d -> b s h d") + k = rearrange(k, "b h s d -> b s h d") + v = rearrange(v, "b h s d -> b s h d") + cross_attention = self.attn_fn(q, k, v) + cross_attention = rearrange( + cross_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head + ) + else: + # Use PyTorch's scaled dot-product attention + cross_attention = torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=False + ) + + # Split back into individual slice token outputs + cross_attention = torch.split( + cross_attention, slice_tokens[0].shape[-2], dim=-2 ) + return list(cross_attention) + def forward( self, x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], @@ -384,324 +247,80 @@ def forward( List of output tensors, each of shape :math:`(B, N, C)``, same shape as inputs. """ - return _gale_forward_impl(self, x, context) - - -def _gale_cross_init( - self: nn.Module, - dim_head: int, - context_dim: int, - use_te: bool, - state_mixing_mode: str = "weighted", -) -> None: - # Match GALE: TE linear only when TE is installed (GALE_block already errors if use_te without TE) - linear_layer = te.Linear if (use_te and TE_AVAILABLE) else nn.Linear - self.cross_q = linear_layer(dim_head, dim_head) - self.cross_k = linear_layer(context_dim, dim_head) - self.cross_v = linear_layer(context_dim, dim_head) - - self.state_mixing_mode = state_mixing_mode - - match state_mixing_mode: - case "weighted": - # Learnable mixing weight between self and cross attention - # Initialize near 0.0 since sigmoid(0) = 0.5, giving balanced initial mixing - self.state_mixing = nn.Parameter(torch.tensor(0.0)) - case "concat_project": - # Concatenate self and cross attention and project back to dim_head - self.concat_project = nn.Sequential( - linear_layer(2 * dim_head, dim_head), - nn.GELU(), - ) - case _: - raise ValueError( - f"Invalid state_mixing_mode: {state_mixing_mode!r}. " - f"Expected 'weighted' or 'concat_project'." 
- ) - - -class _GALEStructuredForwardMixin: - """Shared cross-attention and forward for structured GALE (2D/3D conv projection).""" - - def compute_slice_attention_cross( - self, - slice_tokens: list[Float[torch.Tensor, "batch heads slices dim"]], - context: Float[torch.Tensor, "batch heads context_slices context_dim"], - ) -> list[Float[torch.Tensor, "batch heads slices dim"]]: - return _gale_compute_slice_attention_cross( - self, slice_tokens, context - ) - - def forward( - self, - x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], - context: Float[torch.Tensor, "batch heads context_slices context_dim"] - | None = None, - ) -> list[Float[torch.Tensor, "batch tokens channels"]]: - return _gale_forward_impl(self, x, context) - - -class GALEStructuredMesh2D(_GALEStructuredForwardMixin, PhysicsAttentionStructuredMesh2D): - r"""GALE with Conv2d slice projection for 2D structured grids (see :class:`GALE`).""" - - def __init__( - self, - dim: int, - spatial_shape: tuple[int, int], - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - slice_num: int = 64, - kernel: int = 3, - use_te: bool = True, - plus: bool = False, - context_dim: int = 0, - state_mixing_mode: str = "weighted", - ) -> None: - super().__init__( - dim, - spatial_shape, - heads, - dim_head, - dropout, - slice_num, - kernel, - use_te, - plus, - ) - _gale_cross_init(self, dim_head, context_dim, use_te, state_mixing_mode) - - -class GALEStructuredMesh3D(_GALEStructuredForwardMixin, PhysicsAttentionStructuredMesh3D): - r"""GALE with Conv3d slice projection for 3D structured grids (see :class:`GALE`).""" - - def __init__( - self, - dim: int, - spatial_shape: tuple[int, int, int], - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - slice_num: int = 64, - kernel: int = 3, - use_te: bool = True, - plus: bool = False, - context_dim: int = 0, - state_mixing_mode: str = "weighted", - ) -> None: - super().__init__( - dim, - spatial_shape, - heads, - dim_head, - dropout, - slice_num, - kernel, - use_te, - plus, - ) - _gale_cross_init(self, dim_head, context_dim, use_te, state_mixing_mode) - - -class GALE_FA(nn.Module): - r"""GALE_FA: Geometry-Aware Latent Embeddings with FLARE self-Attention attention layer. - - Adopted: - - - FLARE attention: Fast Low-rank Attention Routing Engine - paper: https://arxiv.org/abs/2508.12594 - - GeoTransolver context: - paper: https://arxiv.org/abs/2512.20399 - - GALE_FA is an alternative to the GALE attention mechanism of the GeoTransolver. - It supports cross-attention with a context vector, built from geometry and global embeddings. - GALE_FA combines FLARE self-attention on learned physical state slices with cross-attention - to geometry-aware context, using a learnable mixing weight to blend the two. - - Parameters - ---------- - dim : int - Input dimension of the features. - heads : int, optional - Number of attention heads. Default is 8. - dim_head : int, optional - Dimension of each attention head. Default is 64. - dropout : float, optional - Dropout rate. Default is 0.0. - n_global_queries : int, optional - Number of learned global queries. Default is 64. - use_te : bool, optional - Whether to use Transformer Engine backend when available. Default is False. - context_dim : int, optional - Dimension of the context vector for cross-attention. Default is 0. - concrete_dropout : bool, optional - Whether to use learned concrete dropout instead of standard dropout. - Default is ``False``. 
- state_mixing_mode : str, optional - How to blend self-attention and cross-attention outputs. ``"weighted"`` uses - a learnable sigmoid-gated weighted sum. ``"concat_project"`` - concatenates the two along the head dimension and projects back with a - linear layer. Default is ``"weighted"``. - - Forward - ------- - x : tuple[torch.Tensor, ...] - Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` is - batch size, :math:`N` is number of tokens, and :math:`C` is number of channels. - context : tuple[torch.Tensor, ...] | None, optional - Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` where - :math:`H` is number of heads, :math:`S_c` is number of context slices, and - :math:`D_c` is context dimension. If ``None``, only self-attention is applied. - Default is ``None``. - - Outputs - ------- - list[torch.Tensor] - List of output tensors, each of shape :math:`(B, N, C)`, same shape as inputs. - - Notes - ----- - The mixing between self-attention and cross-attention is controlled by a learnable - parameter ``state_mixing`` which is passed through a sigmoid function to ensure - the mixing weight stays in :math:`[0, 1]`. - - See Also - -------- - :class:`GALE` : Original GeoTransolver GALE attention class. - :class:`GALE_block` : Transformer block that calls GALE or GALE_FA attention. - - Examples - -------- - >>> import torch - >>> gale_fa = GALE_FA(dim=256, heads=8, dim_head=32, context_dim=32) - >>> x = (torch.randn(2, 100, 256),) # Single input tensor in tuple - >>> context = torch.randn(2, 8, 64, 32) # Context for cross-attention - >>> outputs = gale_fa(x, context) - >>> len(outputs) - 1 - >>> outputs[0].shape - torch.Size([2, 100, 256]) - """ + ### Input validation + if not torch.compiler.is_compiling(): + if len(x) == 0: + raise ValueError("Expected non-empty tuple of input tensors") + for i, tensor in enumerate(x): + if tensor.ndim != 3: + raise ValueError( + f"Expected 3D input tensor (B, N, C) at index {i}, " + f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}" + ) - def __init__( - self, - dim, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - n_global_queries: int = 64, - use_te: bool = True, - context_dim: int = 0, - concrete_dropout: bool = False, - state_mixing_mode: str = "weighted", - ): - if use_te: - raise ValueError( - "GALE_FA does not support Transformer Engine backend. " - "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention." + # Project inputs onto learned latent spaces + if self.plus: + x_mid = [self.project_input_onto_slices(_x) for _x in x] + # In Transolver++, x_mid is reused for both projections + fx_mid = [_x_mid for _x_mid in x_mid] + else: + x_mid, fx_mid = zip( + *[self.project_input_onto_slices(_x) for _x in x] ) - super().__init__() - self.use_te = use_te - self.heads = heads - self.dim_head = dim_head - self.scale = 1.0 - # It is recommended by the FLARE authors to use self.scale = 1 if self.dim_head <= 8 else (self.dim_head ** -0.5) - # but we use self.scale = 1.0 because the recommended scaling is not tested yet. 
- inner_dim = dim_head * heads - linear_layer = te.Linear if self.use_te else nn.Linear + # Project latent representations onto physical state slices + slice_projections = [self.in_project_slice(_x_mid) for _x_mid in x_mid] - # Global queries for FLARE self-attention - self.q_global = nn.Parameter(torch.randn(1, heads, n_global_queries, dim_head)) - - # Linear projections for self-attention - self.in_project_x = linear_layer(dim, inner_dim) - self.self_k = linear_layer(dim_head, dim_head) - self.self_v = linear_layer(dim_head, dim_head) - - if context_dim > 0: - _gale_cross_init(self, dim_head, context_dim, use_te, state_mixing_mode) + # Compute slice weights and aggregated slice tokens + slice_weights, slice_tokens = zip( + *[ + self._compute_slices_from_projections(proj, _fx_mid) + for proj, _fx_mid in zip(slice_projections, fx_mid) + ] + ) - # Linear projection for output - self.out_linear = linear_layer(inner_dim, dim) - if concrete_dropout: - self.out_dropout = ConcreteDropout( - in_features=dim, - init_p=max(dropout, 0.05), - ) + # Apply self-attention to slice tokens + if self.use_te: + self_slice_token = [ + self._compute_slice_attention_te(_slice_token) + for _slice_token in slice_tokens + ] else: - self.out_dropout = nn.Dropout(dropout) - - def forward( - self, - x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], - context: Float[torch.Tensor, "batch heads context_slices context_dim"] - | None = None, - ) -> list[Float[torch.Tensor, "batch tokens channels"]]: - r"""Forward pass of the GALE_FA module. - - Applies GALE_FA attention to the input features. - - Parameters - ---------- - x : tuple[torch.Tensor, ...] - Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` - is batch size, :math:`N` is number of tokens, and :math:`C` is number - of channels. - context : torch.Tensor | None, optional - Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` - where :math:`H` is number of heads, :math:`S_c` is number of context - slices, and :math:`D_c` is context dimension. If ``None``, only - self-attention is applied. Default is ``None``. - - Returns - ------- - list[torch.Tensor] - List of output tensors, each of shape :math:`(B, N, C)``, same shape - as inputs. 
- """ - # Input projection: (B, N, C) -> (B, N, H, D) -> (B, H, N, D) - x_mid = [ - _project_input( - _x, self.in_project_x, self.heads, self.dim_head, - "B N (H D) -> B N H D", - ).permute(0, 2, 1, 3) - for _x in x - ] - - # FLARE self-attention per input - self_attention = [ - _flare_self_attention( - _x_mid, self.q_global, self.self_k, self.self_v, self.scale, - ) - for _x_mid in x_mid - ] + self_slice_token = [ + self._compute_slice_attention_sdpa(_slice_token) + for _slice_token in slice_tokens + ] - # Cross-attention with context and state mixing + # Apply cross-attention with context if provided if context is not None: - q = [self.cross_q(_x_mid) for _x_mid in x_mid] - k = self.cross_k(context) - v = self.cross_v(context) - cross_attention = [ - F.scaled_dot_product_attention(_q, k, v, scale=self.scale) - for _q in q - ] - outputs = [ - _mix_self_and_cross( - sa, ca, self.state_mixing_mode, - state_mixing=getattr(self, "state_mixing", None), - concat_project=getattr(self, "concat_project", None), - ) - for sa, ca in zip(self_attention, cross_attention) + cross_slice_token = [ + self.compute_slice_attention_cross([_slice_token], context)[0] + for _slice_token in slice_tokens ] + + # Blend self-attention and cross-attention + if self.state_mixing_mode == "weighted": + mixing_weight = torch.sigmoid(self.state_mixing) + out_slice_token = [ + mixing_weight * sst + (1 - mixing_weight) * cst + for sst, cst in zip(self_slice_token, cross_slice_token) + ] + else: + out_slice_token = [ + self.concat_project(torch.cat([sst, cst], dim=-1)) + for sst, cst in zip(self_slice_token, cross_slice_token) + ] else: - outputs = self_attention + # Use only self-attention when no context is provided + out_slice_token = self_slice_token + + # Project attention outputs back to original space using slice weights + outputs = [ + self._project_attention_outputs(ost, sw) + for ost, sw in zip(out_slice_token, slice_weights) + ] - # Back to token layout: (B, H, N, D) -> (B, N, H, D) - outputs = [_y.permute(0, 2, 1, 3) for _y in outputs] - outputs = [rearrange(_out, "b n h d -> b n (h d)") for _out in outputs] - outputs = [self.out_linear(_out) for _out in outputs] - return [self.out_dropout(_out) for _out in outputs] + return outputs class GALE_block(nn.Module): @@ -735,16 +354,11 @@ class GALE_block(nn.Module): Whether to use Transolver++ features. Default is ``False``. context_dim : int, optional Dimension of the context vector for cross-attention. Default is 0. - spatial_shape : tuple[int, ...] | None, optional - If ``None``, uses irregular-mesh GALE. Length-2 tuple enables 2D Conv2d - projection; length-3 tuple enables 3D Conv3d projection (flattened - :math:`N = H \times W` or :math:`H \times W \times D`). Default is ``None``. attention_type : str, optional - Attention backend to use. ``"GALE"`` uses the standard physics-aware - slice attention; ``"GALE_FA"`` uses flash-attention variant. + attention_type is used to choose the attention type (GALE or GALE_FA). Default is ``"GALE"``. state_mixing_mode : str, optional - How to blend self-attention and cross-attention outputs. ``"weighted"`` uses + How to blend self-attention and cross-attention outputs. ``"weighted"`` uses a learnable sigmoid-gated weighted sum. ``"concat_project"`` concatenates the two along the head dimension and projects back with a linear layer. Default is ``"weighted"``. @@ -800,7 +414,6 @@ def __init__( use_te: bool = True, plus: bool = False, context_dim: int = 0, - spatial_shape: tuple[int, ...] 
| None = None, attention_type: str = "GALE", concrete_dropout: bool = False, state_mixing_mode: str = "weighted", @@ -821,62 +434,26 @@ def __init__( else: self.ln_1 = nn.LayerNorm(hidden_dim) - dim_head = hidden_dim // num_heads - # First match on attention backend, then on spatial shape + # Attention layer match attention_type: case 'GALE': - if spatial_shape is None: - self.Attn = GALE( - hidden_dim, - heads=num_heads, - dim_head=dim_head, - dropout=dropout, - slice_num=slice_num, - use_te=use_te, - plus=plus, - context_dim=context_dim, - concrete_dropout=concrete_dropout, - state_mixing_mode=state_mixing_mode, - ) - elif len(spatial_shape) == 2: - self.Attn = GALEStructuredMesh2D( - hidden_dim, - spatial_shape=(int(spatial_shape[0]), int(spatial_shape[1])), - heads=num_heads, - dim_head=dim_head, - dropout=dropout, - slice_num=slice_num, - use_te=use_te, - plus=plus, - context_dim=context_dim, - state_mixing_mode=state_mixing_mode, - ) - elif len(spatial_shape) == 3: - self.Attn = GALEStructuredMesh3D( - hidden_dim, - spatial_shape=( - int(spatial_shape[0]), - int(spatial_shape[1]), - int(spatial_shape[2]), - ), - heads=num_heads, - dim_head=dim_head, - dropout=dropout, - slice_num=slice_num, - use_te=use_te, - plus=plus, - context_dim=context_dim, - state_mixing_mode=state_mixing_mode, - ) - else: - raise ValueError( - f"spatial_shape must be None, length-2, or length-3; got {spatial_shape!r}" - ) + self.Attn = GALE( + hidden_dim, + heads=num_heads, + dim_head=hidden_dim // num_heads, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + context_dim=context_dim, + concrete_dropout=concrete_dropout, + state_mixing_mode=state_mixing_mode, + ) case 'GALE_FA': self.Attn = GALE_FA( hidden_dim, heads=num_heads, - dim_head=dim_head, + dim_head=hidden_dim // num_heads, dropout=dropout, n_global_queries=slice_num, use_te=use_te, diff --git a/physicsnemo/experimental/models/geotransolver/gale_fa.py b/physicsnemo/experimental/models/geotransolver/gale_fa.py new file mode 100644 index 0000000000..eb9f8a3478 --- /dev/null +++ b/physicsnemo/experimental/models/geotransolver/gale_fa.py @@ -0,0 +1,285 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GALE_FA (Geometry-Aware Latent Embeddings with FLARE self-Attention) attention layer. + +This module provides the GALE_FA attention mechanism, +an alternative to the GALE attention mechanism of the GeoTransolver. 
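+
+FLARE self-attention is a two-pass, low-rank routing: token features are first
+aggregated into ``n_global_queries`` learned global query slots, then the slot
+summaries are routed back to the tokens. A rough sketch of the SDPA code path in
+:meth:`GALE_FA.forward` (ignoring multi-head reshapes, with ``scale=1.0``)::
+
+    z = softmax(G @ k.transpose(-2, -1)) @ v   # tokens -> global slots
+    y = softmax(k @ G.transpose(-2, -1)) @ z   # global slots -> tokens
+
+where ``G`` are the learned global queries and ``k``, ``v`` are linear projections
+of the input tokens.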
+""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from jaxtyping import Float + +from physicsnemo.core.version_check import check_version_spec, OptionalImport +from physicsnemo.nn import ConcreteDropout + +# Check optional dependency availability +TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) +te = OptionalImport("transformer_engine.pytorch", "0.1.0") + + +class GALE_FA(nn.Module): + r"""GALE_FA: Geometry-Aware Latent Embeddings with FLARE self-Attention attention layer. + Adopted: + - FLARE attention: Fast Low-rank Attention Routing Engine + paper: https://arxiv.org/abs/2508.12594 + - GeoTransolver context: + paper: https://arxiv.org/abs/2512.20399 + + GALE_FA is an alternative to the GALE attention mechanism of the GeoTransolver + It supports cross-attention with a context vector, built from geometry and global embeddings. + GALE_FA combines FLARE self-attention on learned physical state slices with cross-attention + to geometry-aware context, using a learnable mixing weight to blend the two. + + Parameters + ---------- + dim : int + Input dimension of the features. + heads : int, optional + Number of attention heads. Default is 8. + dim_head : int, optional + Dimension of each attention head. Default is 64. + dropout : float, optional + Dropout rate. Default is 0.0. + n_global_queries : int, optional + Number of learned global queries. Default is 64. + use_te : bool, optional + Whether to use Transformer Engine backend when available. Default is False. + context_dim : int, optional + Dimension of the context vector for cross-attention. Default is 0. + concrete_dropout : bool, optional + Whether to use learned concrete dropout instead of standard dropout. + Default is ``False``. + state_mixing_mode : str, optional + How to blend self-attention and cross-attention outputs. ``"weighted"`` uses + a learnable sigmoid-gated weighted sum. ``"concat_project"`` + concatenates the two along the head dimension and projects back with a + linear layer. Default is ``"weighted"``. + + Forward + ------- + x : tuple[torch.Tensor, ...] + Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` is + batch size, :math:`N` is number of tokens, and :math:`C` is number of channels. + context : tuple[torch.Tensor, ...] | None, optional + Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` where + :math:`H` is number of heads, :math:`S_c` is number of context slices, and + :math:`D_c` is context dimension. If ``None``, only self-attention is applied. + Default is ``None``. + + Outputs + ------- + list[torch.Tensor] + List of output tensors, each of shape :math:`(B, N, C)`, same shape as inputs. + + Notes + ----- + The mixing between self-attention and cross-attention is controlled by a learnable + parameter ``state_mixing`` which is passed through a sigmoid function to ensure + the mixing weight stays in :math:`[0, 1]`. + + See Also + -------- + :class:`GALE` : Original GeoTransolver GALE attention class. + :class:`GALE_block` : Transformer block that calls GALE or GALE_FA attention. 
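+
+    Concretely, with ``state_mixing_mode="weighted"`` the blend computed in
+    ``forward`` is (sketch of the implemented gating):
+
+    .. math::
+
+        y = \sigma(w)\, y_{\text{self}} + \bigl(1 - \sigma(w)\bigr)\, y_{\text{cross}}
+
+    where :math:`w` is the learnable ``state_mixing`` scalar, :math:`y_{\text{self}}`
+    is the FLARE self-attention output, and :math:`y_{\text{cross}}` is the
+    cross-attention output against the context.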
+
+    Examples
+    --------
+    >>> import torch
+    >>> gale_fa = GALE_FA(dim=256, heads=8, dim_head=32, context_dim=32)
+    >>> x = (torch.randn(2, 100, 256),)  # Single input tensor in tuple
+    >>> context = torch.randn(2, 8, 64, 32)  # Context for cross-attention
+    >>> outputs = gale_fa(x, context)
+    >>> len(outputs)
+    1
+    >>> outputs[0].shape
+    torch.Size([2, 100, 256])
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        n_global_queries: int = 64,
+        use_te: bool = False,
+        context_dim: int = 0,
+        concrete_dropout: bool = False,
+        state_mixing_mode: str = "weighted",
+    ):
+        if use_te:
+            raise ValueError(
+                "GALE_FA does not support the Transformer Engine backend. "
+                "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention."
+            )
+        super().__init__()
+        if state_mixing_mode not in ("weighted", "concat_project"):
+            raise ValueError(
+                f"Invalid state_mixing_mode: {state_mixing_mode!r}. "
+                f"Expected 'weighted' or 'concat_project'."
+            )
+        self.state_mixing_mode = state_mixing_mode
+        self.use_te = use_te
+        self.heads = heads
+        self.dim_head = dim_head
+        self.scale = 1.0
+        # It is recommended by the FLARE authors to use self.scale = 1 if self.dim_head <= 8 else (self.dim_head ** -0.5)
+        # but we use self.scale = 1.0 because the recommended scaling is not tested yet.
+        inner_dim = dim_head * heads
+
+        linear_layer = te.Linear if self.use_te else nn.Linear
+
+        # Global queries for FLARE self-attention
+        self.q_global = nn.Parameter(torch.randn(1, heads, n_global_queries, dim_head))
+
+        # Linear projections for self-attention
+        self.in_project_x = linear_layer(dim, inner_dim)
+        self.self_k = linear_layer(dim_head, dim_head)
+        self.self_v = linear_layer(dim_head, dim_head)
+
+        if context_dim > 0:
+            # Linear projections for cross-attention
+            self.cross_q = linear_layer(dim_head, dim_head)
+            self.cross_k = linear_layer(context_dim, dim_head)
+            self.cross_v = linear_layer(context_dim, dim_head)
+
+            # Mixing layers for blending self-attention and cross-attention
+            if state_mixing_mode == "weighted":
+                # Learnable mixing weight between self and cross attention
+                self.state_mixing = nn.Parameter(torch.tensor(0.0))
+            else:
+                # Concatenate self and cross attention and project back to dim_head
+                self.concat_project = nn.Sequential(
+                    linear_layer(2 * dim_head, dim_head),
+                    nn.GELU(),
+                )
+
+        # te attention
+        if self.use_te:
+            self.attn_fn = te.DotProductAttention(
+                num_attention_heads=self.heads,
+                kv_channels=self.dim_head,
+                attention_dropout=dropout,
+                qkv_format="bshd",
+                softmax_scale=self.scale
+            )
+
+        # Linear projection for output
+        self.out_linear = linear_layer(inner_dim, dim)
+        if concrete_dropout:
+            self.out_dropout = ConcreteDropout(
+                in_features=dim,
+                init_p=max(dropout, 0.05),
+            )
+        else:
+            self.out_dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        x: tuple[Float[torch.Tensor, "batch tokens channels"], ...],
+        context: Float[torch.Tensor, "batch heads context_slices context_dim"]
+        | None = None,
+    ) -> list[Float[torch.Tensor, "batch tokens channels"]]:
+        r"""Forward pass of the GALE_FA module.
+
+        Applies GALE_FA attention to the input features.
+
+        Parameters
+        ----------
+        x : tuple[torch.Tensor, ...]
+            Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B`
+            is batch size, :math:`N` is number of tokens, and :math:`C` is number
+            of channels.
+        context : torch.Tensor | None, optional
+            Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)`
+            where :math:`H` is number of heads, :math:`S_c` is number of context
+            slices, and :math:`D_c` is context dimension. If ``None``, only
+            self-attention is applied. Default is ``None``.
+
+        Returns
+        -------
+        list[torch.Tensor]
+            List of output tensors, each of shape :math:`(B, N, C)`, same shape
+            as inputs.
+        """
+
+        # Project inputs and reshape to multi-head layout: (B, N, C) -> (B, H, N, D)
+        x_mid = [self.in_project_x(_x) for _x in x]
+        x_mid = [rearrange(
+            _x_mid, "B N (h d) -> B N h d", h=self.heads, d=self.dim_head
+        ) for _x_mid in x_mid]
+        x_mid = [_x_mid.permute(0, 2, 1, 3) for _x_mid in x_mid]  # [B, H, N, D]
+        G = [self.q_global.to(dtype=x_mid[0].dtype).expand(x_mid[0].shape[0], -1, -1, -1)] * len(x)
+        k = [self.self_k(_x_mid) for _x_mid in x_mid]
+        v = [self.self_v(_x_mid) for _x_mid in x_mid]
+
+        # FLARE: Self Attention
+        if self.use_te:
+            # Transformer Engine expects (B, S, H, D) format
+            G = [rearrange(_G, "b h s d -> b s h d") for _G in G]
+            k = [rearrange(_k, "b h s d -> b s h d") for _k in k]
+            v = [rearrange(_v, "b h s d -> b s h d") for _v in v]
+            z = [self.attn_fn(_G, _k, _v) for _G, _k, _v in zip(G, k, v)]
+            z = [rearrange(
+                _z, "b s (h d) -> b s h d", h=self.heads, d=self.dim_head
+            ) for _z in z]
+            self_attention = [self.attn_fn(_k, _G, _z) for _k, _G, _z in zip(k, G, z)]
+            self_attention = [rearrange(
+                _self_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head
+            ) for _self_attention in self_attention]
+        else:
+            # Use PyTorch's scaled dot-product attention
+            z = [F.scaled_dot_product_attention(_G, _k, _v, scale=self.scale) for _G, _k, _v in zip(G, k, v)]
+            self_attention = [F.scaled_dot_product_attention(_k, _G, _z, scale=self.scale) for _k, _G, _z in zip(k, G, z)]
+
+        # Cross-attention between token features and the geometry-aware context
+        if context is not None:
+            q = [self.cross_q(_x_mid) for _x_mid in x_mid]
+            k = self.cross_k(context)
+            v = self.cross_v(context)
+
+            if self.use_te:
+                q = [rearrange(_q, "b h s d -> b s h d") for _q in q]
+                k = rearrange(k, "b h s d -> b s h d")
+                v = rearrange(v, "b h s d -> b s h d")
+                cross_attention = [self.attn_fn(_q, k, v) for _q in q]
+                cross_attention = [rearrange(
+                    _cross_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head
+                ) for _cross_attention in cross_attention]
+            else:
+                cross_attention = [F.scaled_dot_product_attention(_q, k, v, scale=self.scale) for _q in q]
+
+            # Blend self-attention and cross-attention
+            if self.state_mixing_mode == "weighted":
+                mixing_weight = torch.sigmoid(self.state_mixing)
+                outputs = [mixing_weight * _ys + (1 - mixing_weight) * _yc for _ys, _yc in zip(self_attention, cross_attention)]
+            else:
+                outputs = [self.concat_project(torch.cat([_ys, _yc], dim=-1)) for _ys, _yc in zip(self_attention, cross_attention)]
+        else:
+            outputs = self_attention
+
+        outputs = [_y.permute(0, 2, 1, 3) for _y in outputs]  # [B, N, H, D]
+        outputs = [rearrange(_out, "b n h d -> b n (h d)") for _out in outputs]
+        outputs = [self.out_linear(_out) for _out in outputs]
+        return [self.out_dropout(_out) for _out in outputs]
+
diff --git a/physicsnemo/experimental/models/geotransolver/geotransolver.py b/physicsnemo/experimental/models/geotransolver/geotransolver.py
index 13689e345e..36fd32a482 100644
--- a/physicsnemo/experimental/models/geotransolver/geotransolver.py
+++ b/physicsnemo/experimental/models/geotransolver/geotransolver.py
@@ -23,7 +23,6 @@

 from __future__ import annotations

-import math
 from collections.abc import Sequence
 from 
dataclasses import dataclass @@ -145,47 +144,6 @@ def _normalize_tensor( raise TypeError(f"Invalid tensor structure") -def _structured_num_tokens(spatial_shape: tuple[int, ...]) -> int: - return int(math.prod(spatial_shape)) - - -def _flatten_for_structured( - t: torch.Tensor, - spatial_shape: tuple[int, ...], - name: str, -) -> torch.Tensor: - """Flatten (B,H,W,C) or (B,H,W,D,C) to (B,N,C); pass through (B,N,C) if N matches. - - Mirrors Transolver's structured flatten/unflatten behavior so the rest of - GeoTransolver can assume a single token layout (B, N, C). - """ - n = _structured_num_tokens(spatial_shape) - if t.ndim == 3: - if not torch.compiler.is_compiling() and t.shape[1] != n: - raise ValueError( - f"{name} token count {t.shape[1]} != structured grid size {n}" - ) - return t - if len(spatial_shape) == 2 and t.ndim == 4: - B, H, W, C = t.shape - if (H, W) != spatial_shape: - raise ValueError( - f"{name} spatial dims {(H, W)} != structured_shape {spatial_shape}" - ) - return t.reshape(B, n, C) - if len(spatial_shape) == 3 and t.ndim == 5: - B, H, W, D, C = t.shape - if (H, W, D) != spatial_shape: - raise ValueError( - f"{name} spatial dims {(H, W, D)} != structured_shape {spatial_shape}" - ) - return t.reshape(B, n, C) - raise ValueError( - f"{name}: expected (B,N,C) with N={n}, or spatial layout matching " - f"structured_shape {spatial_shape}; got shape {tuple(t.shape)}" - ) - - class GeoTransolver(Module): r"""GeoTransolver: Geometry-Aware Physics Attention Transformer. @@ -247,11 +205,6 @@ class GeoTransolver(Module): Neighbors in radius for the local features. Default is ``[8, 32]``. n_hidden_local : int, optional Hidden dimension for the local features. Default is 32. - structured_shape : tuple[int, ...] | None, optional - If set to ``(H, W)`` or ``(H, W, D)``, enables structured 2D/3D paths - (Conv2d/Conv3d GALE; no ball-query local features). Inputs may be - flattened :math:`(B, N, C)` with :math:`N = H W` or :math:`H W D`, or - spatial :math:`(B, H, W, C)` / :math:`(B, H, W, D, C)`. Default is ``None``. guard_config : dict | None, optional Configuration for the embedded OOD guard (:class:`~physicsnemo.experimental.guardrails.embedded.OODGuard`). @@ -276,9 +229,9 @@ class GeoTransolver(Module): Forward ------- local_embedding : torch.Tensor | tuple[torch.Tensor, ...] - Local embedding: unstructured :math:`(B, N, C)`; structured 2D - :math:`(B, H, W, C)` or flattened :math:`(B, H W, C)`; structured 3D - :math:`(B, H, W, D, C)` or flattened. Can be a tuple for multiple input types. + Local embedding of the input data of shape :math:`(B, N, C)` where :math:`B` + is batch size, :math:`N` is number of nodes/tokens, and :math:`C` is + ``functional_dim``. Can be a single tensor or tuple for multiple input types. local_positions : torch.Tensor | tuple[torch.Tensor, ...] | None, optional Local positions for each input, each of shape :math:`(B, N, 3)`. Required if ``include_local_features=True``. Default is ``None``. @@ -295,18 +248,14 @@ class GeoTransolver(Module): Outputs ------- - torch.Tensor | tuple[torch.Tensor, ...] - When ``return_embedding_states=False`` (default): output tensor(s) of + Float[torch.Tensor, "batch tokens out_dim"] | tuple[Float[torch.Tensor, "batch tokens out_dim"], ...] + When ``return_embedding_states=False`` (default), output tensor of shape :math:`(B, N, C_{out})`. Returns a single tensor if input was - a single tensor, or a tuple of tensors if input was a tuple - (multi-stream). 
For structured grids, output matches the input - layout—flattened :math:`(B, N, C_{out})` or spatial - :math:`(B, H, W, C_{out})` / :math:`(B, H, W, D, C_{out})` when - inputs were 4D/5D. - + a single tensor, or a tuple if input was a tuple. + When ``return_embedding_states=True``, returns a 2-tuple - ``(output, embedding_states)`` where ``output`` follows the same - rules above, and ``embedding_states`` is of shape + ``(output, embedding_states)`` where ``embedding_states`` is + ``Float[torch.Tensor, "batch heads slices context_dim"]`` of shape :math:`(B, H, S, D_c)` (geometry/global context), or ``None`` if no context sources were provided. @@ -322,9 +271,8 @@ class GeoTransolver(Module): Notes ----- - Unstructured mesh uses linear GALE projection; structured ``structured_shape`` - uses the same Conv2d/Conv3d slice projection as :class:`~physicsnemo.models.transolver.Transolver`. - Ball-query local features are disabled when ``structured_shape`` is set. + GeoTransolver currently supports unstructured mesh input only. Enhancements for + image-based and voxel-based inputs may be available in the future. For more details on Transolver, see: @@ -373,21 +321,6 @@ class GeoTransolver(Module): >>> output.shape torch.Size([2, 1000, 3]) - Structured 2D grid: - - >>> model = GeoTransolver( - ... functional_dim=3, - ... out_dim=1, - ... structured_shape=(8, 8), - ... n_hidden=64, - ... n_head=4, - ... n_layers=2, - ... use_te=False, - ... ) - >>> y = model(torch.randn(2, 8, 8, 3)) - >>> y.shape - torch.Size([2, 8, 8, 1]) - To also retrieve the geometry/global context embeddings: >>> output, emb_states = model( @@ -420,7 +353,6 @@ def __init__( radii: list[float] | None = None, neighbors_in_radius: list[int] | None = None, n_hidden_local: int = 32, - structured_shape: tuple[int, ...] | None = None, guard_config: dict | None = None, attention_type: str = "GALE", concrete_dropout: bool = False, @@ -435,22 +367,8 @@ def __init__( if neighbors_in_radius is None: neighbors_in_radius = [8, 32] - if structured_shape is not None: - if include_local_features: - raise ValueError( - "include_local_features=True is not supported with structured_shape " - "(ball-query path is mesh-only)." 
- ) - if len(structured_shape) not in (2, 3): - raise ValueError( - f"structured_shape must have length 2 or 3, got {structured_shape!r}" - ) - if not all(int(s) > 0 for s in structured_shape): - raise ValueError(f"structured_shape must be positive ints, got {structured_shape!r}") - self.include_local_features = include_local_features self.use_te = use_te - self.structured_shape = structured_shape # Validate head dimension compatibility if not n_hidden % n_head == 0: @@ -481,7 +399,6 @@ def __init__( use_te=use_te, plus=plus, include_local_features=self.include_local_features, - structured_shape=structured_shape, concrete_dropout=concrete_dropout, ) context_dim = self.context_builder.get_context_dim() @@ -530,7 +447,6 @@ def __init__( use_te=use_te, plus=plus, context_dim=context_dim, - spatial_shape=structured_shape, attention_type=attention_type, concrete_dropout=concrete_dropout, state_mixing_mode=state_mixing_mode, @@ -674,27 +590,6 @@ def forward( if local_positions is not None: local_positions = _normalize_tensor(local_positions) - unflatten_output = False - if self.structured_shape is not None: - unflatten_output = any(le.ndim in (4, 5) for le in local_embedding) - local_embedding = tuple( - _flatten_for_structured( - le, self.structured_shape, f"local_embedding[{i}]" - ) - for i, le in enumerate(local_embedding) - ) - if geometry is not None: - geometry = _flatten_for_structured( - geometry, self.structured_shape, "geometry" - ) - n_tok = _structured_num_tokens(self.structured_shape) - for i, le in enumerate(local_embedding): - if le.shape[1] != n_tok: - raise ValueError( - f"structured GeoTransolver: all streams must have N={n_tok} tokens; " - f"local_embedding[{i}] has N={le.shape[1]}" - ) - ### Input validation if not torch.compiler.is_compiling(): if len(local_embedding) == 0: @@ -751,16 +646,6 @@ def forward( # Project to output dimensions: (B, N, n_hidden) -> (B, N, out_dim) x = [self.ln_mlp_out[i](x[i]) for i in range(len(x))] - if self.structured_shape is not None and unflatten_output: - B = x[0].shape[0] - for i in range(len(x)): - if len(self.structured_shape) == 2: - H, W = self.structured_shape - x[i] = x[i].reshape(B, H, W, -1) - else: - H, W, D_ = self.structured_shape - x[i] = x[i].reshape(B, H, W, D_, -1) - # Return same format as input (single tensor or tuple) if single_input: x = x[0] diff --git a/physicsnemo/experimental/nn/flare_attention.py b/physicsnemo/experimental/nn/flare_attention.py index 4ecc68a927..70a6c14f57 100644 --- a/physicsnemo/experimental/nn/flare_attention.py +++ b/physicsnemo/experimental/nn/flare_attention.py @@ -29,50 +29,12 @@ from jaxtyping import Float from physicsnemo.core.version_check import check_version_spec, OptionalImport -from physicsnemo.nn.module.physics_attention import _project_input # Check optional dependency availability TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) te = OptionalImport("transformer_engine.pytorch", "0.1.0") -def _flare_self_attention( - x_mid: Float[torch.Tensor, "B H N D"], - q_global: nn.Parameter, - self_k: nn.Module, - self_v: nn.Module, - scale: float, -) -> Float[torch.Tensor, "B H N D"]: - r"""FLARE two-pass self-attention kernel. - - Computes low-rank attention via learned global queries: first aggregate - token values into global slots, then distribute back to tokens. - - Parameters - ---------- - x_mid : torch.Tensor - Projected input of shape :math:`(B, H, N, D)`. - q_global : nn.Parameter - Learned global queries of shape :math:`(1, H, S, D)`. 
- self_k : nn.Module - Key projection applied to ``x_mid``. - self_v : nn.Module - Value projection applied to ``x_mid``. - scale : float - Attention scale factor. - - Returns - ------- - torch.Tensor - Self-attended output of shape :math:`(B, H, N, D)`. - """ - G = q_global.to(dtype=x_mid.dtype).expand(x_mid.shape[0], -1, -1, -1) - k = self_k(x_mid) - v = self_v(x_mid) - z = F.scaled_dot_product_attention(G, k, v, scale=scale) - return F.scaled_dot_product_attention(k, G, z, scale=scale) - - class FLARE(nn.Module): r"""FLARE: Fast Low-rank Attention Routing Engine attention layer. Adopted: @@ -175,17 +137,35 @@ def forward(self, x: Float[torch.Tensor, "B N C"]) -> Float[torch.Tensor, "B N C Output tensor of shape :math:`(B, N, C)`, same shape as inputs. """ - x_mid = _project_input( - x, self.in_project_x, self.heads, self.dim_head, - "B N (H D) -> B N H D", + x_mid = self.in_project_x(x) + x_mid = rearrange( + x_mid, "B N (h d) -> B N h d", h=self.heads, d=self.dim_head ) - x_mid = x_mid.permute(0, 2, 1, 3) # (B, N, H, D) -> (B, H, N, D) + x_mid = x_mid.permute(0, 2, 1, 3) # [B, H, N, D] + G = self.q_global.to(dtype=x_mid.dtype).expand(x_mid.shape[0], -1, -1, -1) + k = self.self_k(x_mid) + v = self.self_v(x_mid) - y = _flare_self_attention( - x_mid, self.q_global, self.self_k, self.self_v, self.scale, - ) + # FLARE: Fast Low-rank Attention Routing Engine + if self.use_te: + # Transformer Engine expects (B, S, H, D) format + G = rearrange(G, "b h s d -> b s h d") + k = rearrange(k, "b h s d -> b s h d") + v = rearrange(v, "b h s d -> b s h d") + z = self.attn_fn(G, k, v) + z = rearrange( + z, "b s (h d) -> b s h d", h=self.heads, d=self.dim_head + ) + self_attention = self.attn_fn(k, G, z) + y = rearrange( + self_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head + ) + else: + # Use PyTorch's scaled dot-product attention + z = F.scaled_dot_product_attention(G, k, v, scale=self.scale) + y = F.scaled_dot_product_attention(k, G, z, scale=self.scale) - out_x = y.permute(0, 2, 1, 3) # (B, H, N, D) -> (B, N, H, D) + out_x = y.permute(0, 2, 1, 3) # [B, N, H, D] out_x = rearrange(out_x, "b n h d -> b n (h d)") out_x = self.out_linear(out_x) return self.out_dropout(out_x) diff --git a/physicsnemo/nn/module/physics_attention.py b/physicsnemo/nn/module/physics_attention.py index 22e0ff3ba8..2110a91331 100644 --- a/physicsnemo/nn/module/physics_attention.py +++ b/physicsnemo/nn/module/physics_attention.py @@ -35,6 +35,7 @@ SOFTWARE. """ +import importlib from abc import ABC, abstractmethod import torch @@ -44,165 +45,19 @@ from torch.autograd.profiler import record_function from torch.distributed.tensor.placement_types import Replicate -from physicsnemo.core.version_check import OptionalImport +from physicsnemo.core.version_check import check_version_spec from physicsnemo.nn import gumbel_softmax # Note: We use duck typing to check for ShardTensor instead of importing it # directly to avoid circular imports (domain_parallel imports from nn). # ShardTensor has a `redistribute` method that we check for. -te = OptionalImport("transformer_engine.pytorch") - - -def _project_input( - x: torch.Tensor, - project_x: nn.Module, - heads: int, - dim_head: int, - pattern: str, - project_fx: nn.Module | None = None, -) -> ( - Float[torch.Tensor, "B N H D"] - | tuple[ - Float[torch.Tensor, "B N H D"], - Float[torch.Tensor, "B N H D"], - ] -): - r"""Project input through one or two learned layers and rearrange to multi-head format. 
- - Universal building block for the ``(B, N, C) -> (B, N, H, D)`` projection - used by low attention mechanism in the repo: slice-based (Transolver / - GALE), FLARE, and future linear-attention variants. - - For structured grids the caller reshapes to spatial layout *before* calling - this function and passes the appropriate einops ``pattern``; the output is - always ``(B, N, H, D)`` with tokens in position 1, preserving - domain-parallel sharding over N. +TE_AVAILABLE = check_version_spec("transformer_engine", hard_fail=False) - Parameters - ---------- - x : torch.Tensor - Input tensor. For irregular meshes this is ``(B, N, C)``; for - structured grids it has already been reshaped to conv-friendly layout - (e.g. ``(B, C, H_s, W_s)``). - project_x : nn.Module - Primary projection layer (``nn.Linear``, ``nn.Conv2d``, etc.). - heads : int - Number of attention heads ``H``. - dim_head : int - Dimension per head ``D``. - pattern : str - Einops rearrange pattern that maps the projection output to - ``(B, N, H, D)``. Examples: - - - ``"B N (H D) -> B N H D"`` for linear projection - - ``"B (H D) h w -> B (h w) H D"`` for 2-D convolution - - ``"B (H D) h w d -> B (h w d) H D"`` for 3-D convolution - project_fx : nn.Module or None, optional - If provided, a second projection is applied to ``x`` and the function - returns a ``(px, pfx)`` tuple. Omit for single-projection callers - (Transolver++ ``plus=True``, FLARE, etc.). Default is ``None``. - - Returns - ------- - torch.Tensor or tuple[torch.Tensor, torch.Tensor] - Single tensor of shape :math:`(B, N, H, D)` when ``project_fx`` is - ``None``; otherwise a tuple ``(px, pfx)`` both of shape - :math:`(B, N, H, D)`. - """ - px = rearrange(project_x(x), pattern, H=heads, D=dim_head) - if project_fx is None: - return px - return px, rearrange(project_fx(x), pattern, H=heads, D=dim_head) - - -def _compute_slices_from_projections( - slice_projections: Float[torch.Tensor, "B N H S"], - fx: Float[torch.Tensor, "B N H D"], - temperature: torch.Tensor, - plus: bool, - proj_temperature: nn.Module | None = None, -) -> tuple[ - Float[torch.Tensor, "B N H S"], - Float[torch.Tensor, "B H S D"], -]: - r"""Compute slice weights and slice tokens from input projections. - - Standalone implementation of the temperature-scaled softmax slice - aggregation used by :class:`PhysicsAttentionBase` and reusable by any - module that needs the same project-to-slices-then-aggregate pattern - (e.g. :class:`~physicsnemo.experimental.models.geotransolver.context_projector.ContextProjector`). - - In domain-parallel settings, this performs an implicit allreduce when - summing over the sharded token dimension. - - Parameters - ---------- - slice_projections : torch.Tensor - Projected input of shape :math:`(B, N, H, S)` where :math:`H` is - number of attention heads and :math:`S` is number of physics slices. - fx : torch.Tensor - Latent features of shape :math:`(B, N, H, D)` where :math:`D` is - dimension per head. - temperature : torch.Tensor - Scalar temperature for softmax/gumbel, shape broadcastable to - ``slice_projections`` (typically :math:`(1, 1, H, 1)`). - plus : bool - If ``True``, use Gumbel softmax with optional adaptive temperature. - proj_temperature : nn.Module or None, optional - If ``plus`` is ``True``, module mapping :math:`(B, N, H, D)` to - adaptive temperature; ignored otherwise. Default is ``None``. - - Returns - ------- - tuple[torch.Tensor, torch.Tensor] - - ``slice_weights``: Shape :math:`(B, N, H, S)`, normalized weights - for each slice per token. 
- - ``slice_token``: Shape :math:`(B, H, S, D)`, aggregated features - per slice. - """ - # Compute temperature-scaled softmax over slices - if plus and proj_temperature is not None: - # Transolver++ uses learned per-token temperature - temp = temperature + proj_temperature(fx) - clamped_temp = torch.clamp(temp, min=0.01).to(slice_projections.dtype) - slice_weights = gumbel_softmax(slice_projections, clamped_temp) # (B, N, H, S) - else: - # Standard Transolver uses global temperature - clamped_temp = torch.clamp(temperature, min=0.5, max=5).to( - slice_projections.dtype - ) - slice_weights = nn.functional.softmax( - slice_projections / clamped_temp, dim=-1 - ) # (B, N, H, S) - - # Cast to the computation type (since the parameter is probably fp32) - slice_weights = slice_weights.to(slice_projections.dtype) - - # Computing the slice tokens is a matmul followed by a normalization. - # It can, unfortunately, overflow in reduced precision, so normalize first: - slice_norm = slice_weights.sum(1) + 1e-2 # (B, H, S) - # Sharded note: slice_norm will be a partial sum at this point. - # That's because the we're summing over the tokens, which are distributed - normed_weights = slice_weights / (slice_norm[:, None, :, :]) - # Normed weights has shape (B, N, H, S) - - # Sharded note: normed_weights will resolve the partial slice_norm - # and the output normed_weights will be sharded. - # fx has shape (B, N, H, D) - # This matmul needs to contract over the tokens - # This should produce an output with shape (B, H, S, D) - - # Like the weight norm, this sum is a **partial** sum since we are summing - # over the tokens - - # Aggregate features: (B, N, H, S)^T @ (B, N, H, D) -> (B, H, S, D) - slice_token = torch.matmul( - normed_weights.permute(0, 2, 3, 1), fx.permute(0, 2, 1, 3) - ) - - # Return the original weights, not the normed weights: - return slice_weights, slice_token +if TE_AVAILABLE: + te = importlib.import_module("transformer_engine.pytorch") +else: + te = None class PhysicsAttentionBase(nn.Module, ABC): @@ -370,9 +225,6 @@ def _compute_slices_from_projections( In domain-parallel settings, this performs an implicit allreduce when summing over the sharded token dimension. - Delegates to the module-level - :func:`_compute_slices_from_projections` free function. - Parameters ---------- slice_projections : torch.Tensor @@ -390,11 +242,56 @@ def _compute_slices_from_projections( - ``slice_token``: Shape :math:`(B, H, S, D)`, aggregated features per slice. 
""" - proj_temp = getattr(self, "proj_temperature", None) if self.plus else None - return _compute_slices_from_projections( - slice_projections, fx, self.temperature, self.plus, proj_temp + # Compute temperature-scaled softmax over slices + if self.plus: + # Transolver++ uses learned per-token temperature + temperature = self.temperature + self.proj_temperature(fx) + clamped_temp = torch.clamp(temperature, min=0.01).to( + slice_projections.dtype + ) + slice_weights = gumbel_softmax( + slice_projections, clamped_temp + ) # (B, N, H, S) + else: + # Standard Transolver uses global temperature + clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to( + slice_projections.dtype + ) + slice_weights = nn.functional.softmax( + slice_projections / clamped_temp, dim=-1 + ) # (B, N, H, S) + + # Cast to the computation type (since the parameter is probably fp32) + slice_weights = slice_weights.to(slice_projections.dtype) + + # This does the projection of the latent space fx by the weights: + + # Computing the slice tokens is a matmul followed by a normalization. + # It can, unfortunately, overflow in reduced precision, so normalize first: + slice_norm = slice_weights.sum(1) + 1e-2 # (B, H, S) + # Sharded note: slice_norm will be a partial sum at this point. + # That's because the we're summing over the tokens, which are distributed + normed_weights = slice_weights / (slice_norm[:, None, :, :]) + # Normed weights has shape (B, N, H, S) + + # Sharded note: normed_weights will resolve the partial slice_norm + # and the output normed_weights will be sharded. + # fx has shape (B, N, H, D) + # This matmul needs to contract over the tokens + # This should produce an output with shape (B, H, S, D) + + # Like the weight norm, this sum is a **partial** sum since we are summing + # over the tokens + + # Aggregate features: (B, N, H, S)^T @ (B, N, H, D) -> (B, H, S, D) + slice_token = torch.matmul( + normed_weights.permute(0, 2, 3, 1), fx.permute(0, 2, 1, 3) ) + # Return the original weights, not the normed weights: + + return slice_weights, slice_token + def _compute_slice_attention_te( self, slice_tokens: Float[torch.Tensor, "B H S D"] ) -> Float[torch.Tensor, "B H S D"]: @@ -647,16 +544,22 @@ def project_input_onto_slices( Projected tensors of shape :math:`(B, N, H, D)` where :math:`H` is number of attention heads and :math:`D` is dimension per head. 
""" - fx = None if self.plus else self.in_project_fx - return _project_input( - x, - self.in_project_x, - self.heads, - self.dim_head, - "B N (H D) -> B N H D", - project_fx=fx, + # Project and reshape to multi-head format + x_mid = rearrange( + self.in_project_x(x), "B N (H D) -> B N H D", H=self.heads, D=self.dim_head ) + if self.plus: + return x_mid + else: + fx_mid = rearrange( + self.in_project_fx(x), + "B N (H D) -> B N H D", + H=self.heads, + D=self.dim_head, + ) + return x_mid, fx_mid + class PhysicsAttentionStructuredMesh2D(PhysicsAttentionBase): r""" @@ -770,16 +673,27 @@ def project_input_onto_slices( x = x.view(B, self.H, self.W, C) x = x.permute(0, 3, 1, 2) - fx = None if self.plus else self.in_project_fx - return _project_input( - x, - self.in_project_x, - self.heads, - self.dim_head, + # Apply 2D convolution and reshape to multi-head format + input_projected_x = self.in_project_x(x) + input_projected_x = rearrange( + input_projected_x, "B (H D) h w -> B (h w) H D", - project_fx=fx, + D=self.dim_head, + H=self.heads, ) + if self.plus: + return input_projected_x + else: + input_projected_fx = self.in_project_fx(x) + input_projected_fx = rearrange( + input_projected_fx, + "B (H D) h w -> B (h w) H D", + D=self.dim_head, + H=self.heads, + ) + return input_projected_x, input_projected_fx + class PhysicsAttentionStructuredMesh3D(PhysicsAttentionBase): r""" @@ -894,12 +808,23 @@ def project_input_onto_slices( x = x.view(B, self.H, self.W, self.D, C) x = x.permute(0, 4, 1, 2, 3) - fx = None if self.plus else self.in_project_fx - return _project_input( - x, - self.in_project_x, - self.heads, - self.dim_head, - "B (H D) h w d -> B (h w d) H D", - project_fx=fx, + # Apply 3D convolution and reshape to multi-head format + input_projected_x = self.in_project_x(x) + input_projected_x = rearrange( + input_projected_x, + "B (H D) height width depth -> B (height width depth) H D", + D=self.dim_head, + H=self.heads, ) + + if self.plus: + return input_projected_x + else: + input_projected_fx = self.in_project_fx(x) + input_projected_fx = rearrange( + input_projected_fx, + "B (H D) height width depth -> B (height width depth) H D", + D=self.dim_head, + H=self.heads, + ) + return input_projected_x, input_projected_fx diff --git a/test/models/geotransolver/test_gale.py b/test/models/geotransolver/test_gale.py index 63db8c4a93..79ef54d8c5 100644 --- a/test/models/geotransolver/test_gale.py +++ b/test/models/geotransolver/test_gale.py @@ -19,9 +19,9 @@ from physicsnemo.experimental.models.geotransolver.gale import ( GALE, - GALE_FA, GALE_block, ) +from physicsnemo.experimental.models.geotransolver.gale_fa import GALE_FA # ============================================================================= # GALE (Geometry-Aware Latent Embeddings) Attention Tests diff --git a/test/models/geotransolver/test_geotransolver.py b/test/models/geotransolver/test_geotransolver.py index 821d1e923b..4888623b0b 100644 --- a/test/models/geotransolver/test_geotransolver.py +++ b/test/models/geotransolver/test_geotransolver.py @@ -558,98 +558,6 @@ def test_geotransolver_mismatched_functional_out_dims(): ) -def test_geotransolver_structured_rejects_local_features(): - """Ball-query local features are incompatible with structured_shape.""" - with pytest.raises(ValueError, match="include_local_features=True"): - GeoTransolver( - functional_dim=8, - out_dim=1, - structured_shape=(4, 4), - include_local_features=True, - geometry_dim=2, - use_te=False, - ) - - -def test_geotransolver_structured_2d_forward(device): - 
"""Structured 2D: spatial input (B,H,W,C) and flattened (B,N,C); optional geometry.""" - torch.manual_seed(0) - H, W = 4, 4 - model = GeoTransolver( - functional_dim=3, - out_dim=2, - structured_shape=(H, W), - geometry_dim=2, - global_dim=None, - n_layers=2, - n_hidden=32, - n_head=4, - slice_num=8, - mlp_ratio=2, - use_te=False, - ).to(device) - B = 2 - x4 = torch.randn(B, H, W, 3, device=device) - g = torch.randn(B, H, W, 2, device=device) - y4 = model(x4, geometry=g) - assert y4.shape == (B, H, W, 2) - assert not torch.isnan(y4).any() - - x3 = x4.reshape(B, H * W, 3) - g3 = g.reshape(B, H * W, 2) - y3 = model(x3, geometry=g3) - assert y3.shape == (B, H * W, 2) - - y_none = model(x4) - assert y_none.shape == (B, H, W, 2) - - -def test_geotransolver_structured_3d_forward(device): - """Structured 3D voxel input (B,H,W,D,C).""" - torch.manual_seed(1) - H, W, Dg = 2, 2, 2 - model = GeoTransolver( - functional_dim=4, - out_dim=1, - structured_shape=(H, W, Dg), - n_layers=1, - n_hidden=32, - n_head=4, - slice_num=4, - mlp_ratio=2, - use_te=False, - ).to(device) - B = 1 - x = torch.randn(B, H, W, Dg, 4, device=device) - y = model(x) - assert y.shape == (B, H, W, Dg, 1) - - -def test_geotransolver_structured_global_context(device): - """Structured grid with global embedding context.""" - torch.manual_seed(2) - H, W = 4, 4 - model = GeoTransolver( - functional_dim=2, - out_dim=1, - structured_shape=(H, W), - geometry_dim=2, - global_dim=8, - n_layers=2, - n_hidden=32, - n_head=4, - slice_num=8, - mlp_ratio=2, - use_te=False, - ).to(device) - B = 2 - x = torch.randn(B, H, W, 2, device=device) - geo = torch.randn(B, H, W, 2, device=device) - glob = torch.randn(B, 3, 8, device=device) - y = model(x, geometry=geo, global_embedding=glob) - assert y.shape == (B, H, W, 1) - - # ============================================================================= # Activation Function Tests # =============================================================================