diff --git a/CHANGELOG.md b/CHANGELOG.md
index b2bc6486ff..b262fdef69 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -167,6 +167,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   inversion example to use the consolidated `physicsnemo.diffusion` API
   (preconditioners, samplers, losses, DPS guidance) and removed the
   recipe-local copies of these utilities under `utils/`.
+- Refactored the `examples/generative/topodiff` recipe to use the
+  consolidated `physicsnemo.diffusion` API (`MSEDSMLoss` with
+  `prediction_type="epsilon"`, `sample()`, `DPSScorePredictor`) plus a
+  recipe-local DDPM scheduler, solver, and classifier guidance. Removed
+  the now-unused `Diffusion`, `DatasetTopoDiff`, and `load_data_topodiff`
+  abstractions from `physicsnemo.models.topodiff`.
 - Significantly expanded CI test coverage for `physicsnemo.diffusion`,
   including new tests for samplers, solvers, preconditioners, losses,
   DPS guidance, multi-diffusion, and patching utilities, plus
diff --git a/examples/generative/topodiff/inference.py b/examples/generative/topodiff/inference.py
index 93c52b2184..e3c220e27b 100644
--- a/examples/generative/topodiff/inference.py
+++ b/examples/generative/topodiff/inference.py
@@ -14,20 +14,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
-import torch.nn.functional as F
-from tqdm import trange
-import numpy as np
-import matplotlib.pyplot as plt
-
-
 import hydra
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
 from omegaconf import DictConfig
-
-from physicsnemo.models.topodiff import TopoDiff, Diffusion
-from physicsnemo.models.topodiff import UNetEncoder
+from utils import (
+    ClassifierGuidance,
+    DDPMLinearNoiseScheduler,
+    DDPMSolver,
+    load_data,
+    load_data_topodiff,
+)
+
+from physicsnemo.diffusion.guidance import DPSScorePredictor
+from physicsnemo.diffusion.samplers import sample
+from physicsnemo.models.topodiff import TopoDiff, UNetEncoder
 from physicsnemo.utils.logging import PythonLogger
-from utils import load_data_topodiff, load_data
 
 
 @hydra.main(version_base="1.3", config_path="conf", config_name="config")
@@ -60,7 +63,9 @@ def main(cfg: DictConfig) -> None:
     classifier.load_state_dict(torch.load(cfg.model_path_classifier))
     classifier.to(device)
 
-    diffusion = Diffusion(n_steps=1000, device=device)
+    n_steps = cfg.diffusion_steps
+    scheduler = DDPMLinearNoiseScheduler(n_steps=n_steps)
+
     batch_size = cfg.batch_size
 
     data = load_data_topodiff(
         topologies,
@@ -71,59 +76,64 @@ def main(cfg: DictConfig) -> None:
     )
 
     _, cons = next(data)
     cons = cons.float().to(device)
-    n_steps = 1000
+
+    # Epsilon predictor (TopoDiff model with fixed conditions)
+    def eps_predictor(x, t):
+        with torch.no_grad():
+            return model(x, cons, t.long())
 
-    xt = torch.randn(batch_size, 1, 64, 64).to(device)
+    # X0 predictor (convert epsilon -> x0)
+    def x0_predictor(x, t):
+        eps = eps_predictor(x, t)
+        return scheduler.epsilon_to_x0(eps, x, t)
+
+    # Classifier guidance
     floating_labels = torch.tensor([1] * batch_size).long().to(device)
+    guidance = ClassifierGuidance(classifier, floating_labels, scale=0.2)
 
-    for i in reversed(trange(n_steps)):
-        with torch.no_grad():
-            t = torch.tensor([i] * batch_size, device=device)
-            noisy = diffusion.p_sample(model, xt, t, cons)
-
-        with torch.enable_grad():
-            xt.requires_grad_(True)
-            logits = classifier(xt, time_steps=t)
-            loss = F.cross_entropy(logits, floating_labels)
-
-            grad = torch.autograd.grad(loss, xt)[0]
-
-        xt = (
-            1
-            / diffusion.alphas[i].sqrt()
-            * (
-                xt
-                - noisy
-                * (1 - diffusion.alphas[i])
-                / (1 - diffusion.alpha_bars[i]).sqrt()
-            )
-        )
+    # DPS guided score predictor (framework component)
+    dps_score = DPSScorePredictor(
+        x0_predictor=x0_predictor,
+        x0_to_score_fn=scheduler.x0_to_score,
+        guidances=guidance,
+    )
+
+    # DDPM solver (no stochastic noise, matching original)
+    solver = DDPMSolver(dps_score, scheduler, stochastic=False)
 
-        if i > 0:
-            z = torch.zeros_like(xt).to(device)
-            xt = xt + diffusion.betas[i].sqrt() * (z * 0.8 + 0.2 * grad.float())
+    # Generate samples
+    xt = torch.randn(batch_size, 1, 64, 64).to(device)
+
+    # Note: the denoiser arg is required by sample() but unused when a custom
+    # Solver is provided — the DDPMSolver uses its own score_predictor internally.
+    with torch.inference_mode(False):
+        xt = sample(
+            denoiser=scheduler.get_denoiser(score_predictor=dps_score),
+            xN=xt,
+            noise_scheduler=scheduler,
+            num_steps=n_steps,
+            solver=solver,
+        )
 
     result = (xt.cpu().detach().numpy() + 1) * 2
 
     np.save(cfg.generation_path + "results_topology.npy", result)
 
     # plot images for the generated samples
-    fig, axes = plt.subplots(8, 8, figsize=(12, 6), dpi=300)
-
-    for i in range(8):
-        for j in range(8):
-            img = result[i * 4 + j][0]
-            axes[i, j].imshow(img, cmap="gray")
-            axes[i, j].set_xticks([])
-            axes[i, j].set_yticks([])
-
-    plt.xticks([])  # Remove x-axis ticks
-    plt.yticks([])  # Remove y-axis ticks
-    plt.gca().xaxis.set_visible(False)  # Optionally hide x-axis
-    plt.gca().yaxis.set_visible(False)  # Optionally hide y-axis
+    n_samples = result.shape[0]
+    ncols = min(8, n_samples)
+    nrows = min(8, (n_samples + ncols - 1) // ncols)
+    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 6), dpi=300, squeeze=False)
+
+    for idx in range(min(nrows * ncols, n_samples)):
+        r, c = divmod(idx, ncols)
+        axes[r, c].imshow(result[idx][0], cmap="gray")
+        axes[r, c].set_xticks([])
+        axes[r, c].set_yticks([])
+    for idx in range(n_samples, nrows * ncols):
+        r, c = divmod(idx, ncols)
+        axes[r, c].axis("off")
 
     plt.savefig(
         cfg.generation_path + "grid_topology.png", bbox_inches="tight", pad_inches=0
diff --git a/examples/generative/topodiff/train.py b/examples/generative/topodiff/train.py
index 3705708c74..ffa0521d9c 100644
--- a/examples/generative/topodiff/train.py
+++ b/examples/generative/topodiff/train.py
@@ -14,17 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import hydra
 import torch
+from omegaconf import DictConfig
 from torch.optim import AdamW
 from tqdm import trange
+from utils import DDPMLinearNoiseScheduler, load_data, load_data_topodiff
 
-
-import hydra
-from omegaconf import DictConfig
-
-from physicsnemo.models.topodiff import TopoDiff, Diffusion
+from physicsnemo.diffusion.metrics.losses import MSEDSMLoss
+from physicsnemo.models.topodiff import TopoDiff
 from physicsnemo.utils.logging import PythonLogger
-from utils import load_data_topodiff, load_data
 
 
 @hydra.main(version_base="1.3", config_path="conf", config_name="config")
@@ -34,7 +33,18 @@ def main(cfg: DictConfig) -> None:
     device = torch.device("cuda:0")
 
     model = TopoDiff(64, 6, 1, model_channels=128, attn_resolutions=[16, 8]).to(device)
-    diffusion = Diffusion(n_steps=1000, device=device)
+    scheduler = DDPMLinearNoiseScheduler(n_steps=1000)
+
+    # Adapt TopoDiff to DiffusionModel protocol (epsilon-predictor)
+    def diffusion_model(x, t, condition=None, **kwargs):
+        return model(x, condition, t.long())
+
+    loss_fn = MSEDSMLoss(
+        model=diffusion_model,
+        noise_scheduler=scheduler,
+        prediction_type="epsilon",
+        epsilon_to_x0_fn=scheduler.epsilon_to_x0,
+    )
 
     topologies = load_data(
         cfg.path_training_data_diffusion, cfg.prefix_topology_file, ".png", 0, 30000
@@ -67,7 +77,7 @@ def main(cfg: DictConfig) -> None:
         tops = tops.float().to(device)
         cons = cons.float().to(device)
 
-        losses = diffusion.train_loss(model, tops, cons)
+        losses = loss_fn(x0=tops, condition=cons)
 
         optimizer.zero_grad()
         losses.backward()
diff --git a/examples/generative/topodiff/train_classifier.py b/examples/generative/topodiff/train_classifier.py
index 7cba9057d6..34295b0059 100644
--- a/examples/generative/topodiff/train_classifier.py
+++ b/examples/generative/topodiff/train_classifier.py
@@ -14,20 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import hydra
+import numpy as np
 import torch
 import torch.nn.functional as F
+from omegaconf import DictConfig
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import LinearLR
-import numpy as np
-
-
-import hydra
-from omegaconf import DictConfig
+from utils import DDPMLinearNoiseScheduler, load_data_classifier
 
-from physicsnemo.models.topodiff import Diffusion
 from physicsnemo.models.topodiff import UNetEncoder
 from physicsnemo.utils.logging import PythonLogger
-from utils import load_data_classifier
 
 
 @hydra.main(version_base="1.3", config_path="conf", config_name="config")
@@ -44,7 +41,7 @@ def main(cfg: DictConfig) -> None:
 
     classifier = UNetEncoder(in_channels=1, out_channels=2).to(device)
 
-    diffusion = Diffusion(n_steps=cfg.diffusion_steps, device=device)
+    noise_scheduler = DDPMLinearNoiseScheduler(n_steps=cfg.diffusion_steps)
 
     batch_size = cfg.batch_size
 
@@ -64,8 +61,8 @@ def main(cfg: DictConfig) -> None:
         batch = torch.tensor(train_img[idx]).float().unsqueeze(1).to(device) * 2 - 1
         batch_labels = torch.tensor(train_labels[idx]).long().to(device)
 
-        t = torch.randint(0, cfg.diffusion_steps, (batch.shape[0],)).to(device)
-        batch = diffusion.q_sample(batch, t)
+        t = noise_scheduler.sample_time(batch.shape[0], device=device)
+        batch = noise_scheduler.add_noise(batch, t)
 
         logits = classifier(batch, time_steps=t)
         loss = F.cross_entropy(logits, batch_labels)
@@ -83,8 +80,8 @@ def main(cfg: DictConfig) -> None:
             batch_labels = torch.tensor(valid_labels[idx]).long().to(device)
 
             # Sample diffusion steps and get noised images
-            t = torch.randint(0, cfg.diffusion_steps, (batch.shape[0],)).to(device)
-            batch = diffusion.q_sample(batch, t)
+            t = noise_scheduler.sample_time(batch.shape[0], device=device)
+            batch = noise_scheduler.add_noise(batch, t)
 
             # Forward pass
             logits = classifier(batch, time_steps=t)
diff --git a/examples/generative/topodiff/train_regressor.py b/examples/generative/topodiff/train_regressor.py
index 3c7aa9e807..cb9e37a659 100644
--- a/examples/generative/topodiff/train_regressor.py
+++ b/examples/generative/topodiff/train_regressor.py
@@ -14,20 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import hydra
+import numpy as np
 import torch
 import torch.nn as nn
+from omegaconf import DictConfig
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import LinearLR
-import numpy as np
-
-
-import hydra
-from omegaconf import DictConfig
+from utils import DDPMLinearNoiseScheduler, load_data_regressor
 
-from physicsnemo.models.topodiff import Diffusion
 from physicsnemo.models.topodiff import UNetEncoder
 from physicsnemo.utils.logging import PythonLogger
-from utils import load_data_regressor
 
 
 @hydra.main(version_base="1.3", config_path="conf", config_name="config")
@@ -54,7 +51,7 @@ def main(cfg: DictConfig) -> None:
     in_channels = 6
 
     regressor = UNetEncoder(in_channels=in_channels, out_channels=1).to(device)
-    diffusion = Diffusion(n_steps=cfg.diffusion_steps, device=device)
+    noise_scheduler = DDPMLinearNoiseScheduler(n_steps=cfg.diffusion_steps)
 
     batch_size = cfg.batch_size
     """
@@ -85,8 +82,8 @@ def main(cfg: DictConfig) -> None:
 
         batch_labels = torch.tensor(labels[idx]).float().to(device).unsqueeze(1)
 
-        t = torch.randint(0, cfg.diffusion_steps, (batch.shape[0],)).to(device)
-        batch = diffusion.q_sample(batch, t)
+        t = noise_scheduler.sample_time(batch.shape[0], device=device)
+        batch = noise_scheduler.add_noise(batch, t)
 
         batch = torch.cat((batch, batch_pf, batch_load), dim=1)
 
diff --git a/examples/generative/topodiff/utils.py b/examples/generative/topodiff/utils.py
index 62ee814329..ebbcb6dc09 100644
--- a/examples/generative/topodiff/utils.py
+++ b/examples/generative/topodiff/utils.py
@@ -14,10 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+
 import numpy as np
+import torch
+import torch.nn.functional as F
+from jaxtyping import Float
 from PIL import Image
-from torch.utils.data import Dataset, DataLoader
-import os
+from torch import Tensor
+from torch.utils.data import DataLoader, Dataset
+
+from physicsnemo.diffusion.noise_schedulers import LinearGaussianNoiseScheduler
 
 
 class DiffusionDataset_topodiff(Dataset):
@@ -127,3 +134,240 @@ def load_data_classifier(root):
     return np.array(image_list).astype(np.float64), np.array(label_list).astype(
         np.float64
     )
+
+
+# =============================================================================
+# Diffusion Framework Components
+# =============================================================================
+
+
+class DDPMLinearNoiseScheduler(LinearGaussianNoiseScheduler):
+    r"""Discrete DDPM noise scheduler with a linear beta schedule.
+
+    Implements the classic DDPM forward process with
+    :math:`\beta_t = \text{linspace}(\beta_{\min}, \beta_{\max}, T)` and
+    :math:`\bar\alpha_t = \prod_{s=1}^{t}(1-\beta_s)`.
+
+    Time values are discrete integer indices in :math:`[0, T)` represented as
+    float tensors.
+
+    Parameters
+    ----------
+    n_steps : int
+        Number of diffusion steps :math:`T`.
+    min_beta : float
+        Minimum beta in the linear schedule.
+    max_beta : float
+        Maximum beta in the linear schedule.
+ """ + + def __init__( + self, + n_steps: int = 1000, + min_beta: float = 1e-4, + max_beta: float = 0.02, + ) -> None: + self.n_steps = n_steps + self.min_beta = min_beta + self.max_beta = max_beta + + betas = torch.linspace(min_beta, max_beta, n_steps) + alphas_individual = 1.0 - betas + alpha_bars = torch.cumprod(alphas_individual, dim=0) + + # Store precomputed tables (registered as plain tensors, not parameters) + self._betas = betas + self._alphas_individual = alphas_individual + self._alpha_bars = alpha_bars + self._sqrt_alpha_bars = alpha_bars.sqrt() + self._sqrt_one_minus_alpha_bars = (1.0 - alpha_bars).sqrt() + + def _index(self, t: Tensor) -> Tensor: + """Clamp and convert time to integer indices.""" + return t.long().clamp(0, self.n_steps - 1) + + def alpha(self, t: Float[Tensor, " *shape"]) -> Float[Tensor, " *shape"]: + r"""Signal coefficient :math:`\sqrt{\bar\alpha_t}`.""" + return self._sqrt_alpha_bars.to(t.device)[self._index(t)] + + def alpha_dot(self, t: Float[Tensor, " *shape"]) -> Float[Tensor, " *shape"]: + r"""Finite-difference approximation of :math:`\dot\alpha(t)`.""" + idx = self._index(t) + table = self._sqrt_alpha_bars.to(t.device) + cur = table[idx] + prev = table[(idx - 1).clamp(min=0)] + return cur - prev + + def sigma(self, t: Float[Tensor, " *shape"]) -> Float[Tensor, " *shape"]: + r"""Noise level :math:`\sqrt{1-\bar\alpha_t}`.""" + return self._sqrt_one_minus_alpha_bars.to(t.device)[self._index(t)] + + def sigma_dot(self, t: Float[Tensor, " *shape"]) -> Float[Tensor, " *shape"]: + r"""Finite-difference approximation of :math:`\dot\sigma(t)`.""" + idx = self._index(t) + table = self._sqrt_one_minus_alpha_bars.to(t.device) + cur = table[idx] + prev = table[(idx - 1).clamp(min=0)] + return cur - prev + + def sigma_inv(self, sigma: Float[Tensor, " *shape"]) -> Float[Tensor, " *shape"]: + r"""Inverse mapping: find closest discrete timestep for a given sigma.""" + table = self._sqrt_one_minus_alpha_bars.to(sigma.device) + # For each sigma value, find the closest index + diffs = (table.unsqueeze(0) - sigma.reshape(-1, 1)).abs() + indices = diffs.argmin(dim=1) + return indices.to(sigma.dtype).reshape(sigma.shape) + + def timesteps( + self, + num_steps: int, + *, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> Float[Tensor, " N+1"]: + r"""Generate discrete DDPM timesteps from ``n_steps-1`` down to ``0``.""" + if num_steps >= self.n_steps: + indices = torch.arange(self.n_steps - 1, -1, -1, device=device, dtype=dtype) + else: + step_indices = torch.arange(num_steps, device=device, dtype=dtype) + scale = (self.n_steps - 1) / (num_steps - 1) + indices = (scale * (num_steps - 1 - step_indices)).round() + if dtype is not None: + indices = indices.to(dtype) + zero = torch.zeros(1, device=device, dtype=dtype) + return torch.cat([indices, zero]) + + def sample_time( + self, + N: int, + *, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> Float[Tensor, " N"]: + r"""Sample N random discrete timesteps uniformly in :math:`[0, T)`.""" + t = torch.randint(0, self.n_steps, (N,), device=device) + if dtype is not None: + t = t.to(dtype) + else: + t = t.float() + return t + + def loss_weight(self, t: Float[Tensor, " N"]) -> Float[Tensor, " N"]: + r"""Loss weight: :math:`\alpha(t)^2/\sigma(t)^2`. + + This ensures that ``loss_weight * ||x0_hat - x0||^2`` is equivalent to + the uniform-weighted epsilon-MSE loss ``||eps_hat - eps||^2``. 
+ """ + a = self.alpha(t) + s = self.sigma(t) + return (a / s) ** 2 + + def betas_at(self, t: Tensor) -> Tensor: + """Return individual beta values at discrete timestep ``t``.""" + return self._betas.to(t.device)[self._index(t)] + + def alphas_individual_at(self, t: Tensor) -> Tensor: + """Return individual alpha values (1 - beta) at timestep ``t``.""" + return self._alphas_individual.to(t.device)[self._index(t)] + + +class DDPMSolver: + r"""Discrete DDPM reverse-step solver. + + Implements the DDPM posterior mean update: + + .. math:: + \mu_\theta(\mathbf{x}_t, t) + = \frac{1}{\sqrt{\alpha_t}} + \left(\mathbf{x}_t + - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\hat\epsilon_\theta + \right) + + with optional posterior noise: + + .. math:: + \mathbf{x}_{t-1} = \mu_\theta + \sqrt{\tilde\beta_t}\,\mathbf{z} + + Parameters + ---------- + score_predictor : callable + A score predictor ``(x, t) -> score``. The solver converts score to + epsilon internally via ``scheduler.score_to_epsilon``. + noise_scheduler : DDPMLinearNoiseScheduler + Scheduler providing alpha/sigma/beta values. + stochastic : bool + If True, inject posterior noise at each step (except the final one). + """ + + def __init__(self, score_predictor, noise_scheduler, stochastic=True): + self.score_predictor = score_predictor + self.scheduler = noise_scheduler + self.stochastic = stochastic + + def step( + self, + x: Float[Tensor, " B *dims"], + t_cur: Float[Tensor, " B"], + t_next: Float[Tensor, " B"], + ) -> Float[Tensor, " B *dims"]: + """Perform one discrete DDPM reverse step from ``t_cur`` to ``t_next``.""" + score = self.score_predictor(x, t_cur) + eps = self.scheduler.score_to_epsilon(score, t_cur) + + # Individual (non-cumulative) alpha and beta at t_cur + alpha_ind = self.scheduler.alphas_individual_at(t_cur) + beta = self.scheduler.betas_at(t_cur) + sigma_t = self.scheduler.sigma(t_cur) + + # Reshape for broadcasting + def _bc(v): + return v.reshape(-1, *([1] * (x.ndim - 1))) + + # DDPM posterior mean + mu = _bc(1.0 / alpha_ind.sqrt()) * (x - _bc(beta / sigma_t) * eps) + + # Posterior noise (skip at t=0) + if self.stochastic: + noise_mask = (t_next > 0).float() + if noise_mask.any(): + mu = mu + _bc(noise_mask) * _bc(beta.sqrt()) * torch.randn_like(x) + + return mu + + +class ClassifierGuidance: + r"""DPS-compatible classifier guidance on noisy samples. + + Computes :math:`\gamma\,\nabla_{\mathbf{x}}\log p(y|\mathbf{x}_t, t)` by + running a time-aware classifier on the noisy sample and differentiating. + + Implements the :class:`~physicsnemo.diffusion.guidance.DPSGuidance` protocol. + + Parameters + ---------- + classifier : callable + Classifier ``(x, time_steps=t) -> logits``. + labels : Tensor + Target class labels of shape ``(B,)``. + scale : float + Guidance scale :math:`\gamma`. 
+ """ + + def __init__(self, classifier, labels, scale=1.0): + self.classifier = classifier + self.labels = labels + self.scale = scale + + def __call__( + self, + x: Float[Tensor, " B *dims"], + t: Float[Tensor, " B"], + x_0: Float[Tensor, " B *dims"], + ) -> Float[Tensor, " B *dims"]: + with torch.enable_grad(): + x_grad = x.detach().requires_grad_(True) + logits = self.classifier(x_grad, time_steps=t.long()) + loss = F.cross_entropy(logits, self.labels[: x.shape[0]]) + grad = torch.autograd.grad(loss, x_grad)[0] + # Negate: grad(CE) = -grad(log p(y|x)), and we want +grad(log p(y|x)) + return -self.scale * grad diff --git a/physicsnemo/models/topodiff/__init__.py b/physicsnemo/models/topodiff/__init__.py index f0fb9890e7..19d297b482 100644 --- a/physicsnemo/models/topodiff/__init__.py +++ b/physicsnemo/models/topodiff/__init__.py @@ -14,16 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""TopoDiff models and utilities.""" +"""TopoDiff models.""" -from .diffusion import Diffusion from .topodiff import TopoDiff, UNetEncoder -from .utils import DatasetTopoDiff, load_data_topodiff __all__ = [ "TopoDiff", "UNetEncoder", - "Diffusion", - "DatasetTopoDiff", - "load_data_topodiff", ] diff --git a/physicsnemo/models/topodiff/diffusion.py b/physicsnemo/models/topodiff/diffusion.py deleted file mode 100644 index 060e8986ff..0000000000 --- a/physicsnemo/models/topodiff/diffusion.py +++ /dev/null @@ -1,145 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Diffusion: - r"""Diffusion scheduler for TopoDiff.""" - - def __init__( - self, - n_steps: int = 1000, - min_beta: float = 10**-4, - max_beta: float = 0.02, - device: str = "cpu", - ): - r"""Initialize the diffusion schedule. - - Parameters - ---------- - n_steps : int, optional, default=1000 - Number of diffusion steps. - min_beta : float, optional, default=1e-4 - Minimum beta in the linear schedule. - max_beta : float, optional, default=0.02 - Maximum beta in the linear schedule. - device : str, optional, default="cpu" - Target device string for tensors. - """ - self.n_steps = n_steps - self.device = device - - self.betas = torch.linspace(min_beta, max_beta, self.n_steps).to(device) - - self.alphas = 1 - self.betas - - self.alpha_bars = torch.cumprod(self.alphas, 0).to(device) - - self.alpha_bars_prev = F.pad(self.alpha_bars[:-1], [1, 0], "constant", 0) - - self.posterior_variance = ( - self.betas * (1.0 - self.alpha_bars_prev) / (1.0 - self.alpha_bars) - ) - - self.loss = nn.MSELoss() - - def q_sample( - self, x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor | None = None - ) -> torch.Tensor: - r"""Diffuse the input data (forward process). 
-
-        Parameters
-        ----------
-        x0 : torch.Tensor
-            Clean samples :math:`(N, C, H, W)`.
-        t : torch.Tensor
-            Timestep indices :math:`(N,)`.
-        noise : torch.Tensor, optional
-            Optional noise tensor; if ``None`` sampled from standard normal.
-
-        Returns
-        -------
-        torch.Tensor
-            Noised samples :math:`x_t`.
-        """
-
-        if noise is None:
-            noise = torch.rand_like(x0).to(self.device)
-
-        alpha_bars = self.alpha_bars[t]
-
-        x = (
-            alpha_bars.sqrt()[:, None, None, None] * x0
-            + (1 - alpha_bars).sqrt()[:, None, None, None] * noise
-        )
-
-        return x
-
-    def p_sample(
-        self, model, xt: torch.Tensor, t: torch.Tensor, cons: torch.Tensor
-    ) -> torch.Tensor:
-        r"""Predict noise using the model (reverse process).
-
-        Parameters
-        ----------
-        model : torch.nn.Module
-            Denoiser that predicts noise given ``(x_t, cons, t)``.
-        xt : torch.Tensor
-            Noised samples :math:`x_t`.
-        t : torch.Tensor
-            Timestep indices :math:`(N,)`.
-        cons : torch.Tensor
-            Constraint tensor concatenated in the model.
-
-        Returns
-        -------
-        torch.Tensor
-            Predicted noise tensor.
-        """
-
-        return model(xt, cons, t)
-
-    def train_loss(self, model, x0: torch.Tensor, cons: torch.Tensor) -> torch.Tensor:
-        r"""Compute training loss for diffusion denoiser.
-
-        Parameters
-        ----------
-        model : torch.nn.Module
-            Denoiser model.
-        x0 : torch.Tensor
-            Clean inputs :math:`(N, C, H, W)`.
-        cons : torch.Tensor
-            Constraint tensor :math:`(N, C_{cons}, H, W)`.
-
-        Returns
-        -------
-        torch.Tensor
-            Scalar loss tensor.
-        """
-
-        b, c, w, h = x0.shape
-        noise = torch.randn_like(x0).to(self.device)
-
-        t = torch.randint(0, self.n_steps, (b,)).to(self.device)
-
-        xt = self.q_sample(x0, t, noise)
-
-        pred_noise = self.p_sample(model, xt, t, cons)
-
-        return self.loss(pred_noise, noise)
diff --git a/physicsnemo/models/topodiff/utils.py b/physicsnemo/models/topodiff/utils.py
deleted file mode 100644
index 79902f8ba6..0000000000
--- a/physicsnemo/models/topodiff/utils.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
-# SPDX-FileCopyrightText: All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Generator, Tuple
-
-import numpy as np
-from torch.utils.data import DataLoader, Dataset
-
-
-class DatasetTopoDiff(Dataset):
-    r"""Dataset wrapper for TopoDiff training.
-
-    Parameters
-    ----------
-    topologies : np.ndarray
-        Array of binary topology images of shape ``(N, H, W)`` in ``[0,1]``.
-    stress : np.ndarray
-        Array of scalar stress fields of shape ``(N, H, W)``.
-    strain : np.ndarray
-        Array of scalar strain fields of shape ``(N, H, W)``.
-    load_im : np.ndarray
-        Array of load images of shape ``(N, H, W, 2)`` representing load vectors.
-    constraints : list[dict]
-        List of dictionaries containing per-sample constraints such as
-        ``{"VOL_FRAC": float}``.
- """ - - def __init__(self, topologies, stress, strain, load_im, constraints): - self.topologies = topologies - self.constraints = constraints - self.image_size = topologies.shape[1] - - self.stress = stress - self.strain = strain - self.load_im = load_im - - def __len__(self) -> int: - return self.topologies.shape[0] - - def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]: - r"""Return a single sample for TopoDiff training. - - Returns - ------- - np.ndarray - Topology tensor of shape ``(1, H, W)`` scaled to ``[-1, 1]``. - np.ndarray - Constraint tensor of shape ``(5, H, W)`` composed of - ``[stress, strain, load_x, load_y, vol_frac]``. - """ - cons = self.constraints[idx] - - vol_frac = cons["VOL_FRAC"] - - cons = np.zeros((5, self.image_size, self.image_size)) - - cons[0] = self.stress[idx] - cons[1] = self.strain[idx] - cons[2] = self.load_im[idx][:, :, 0] - cons[3] = self.load_im[idx][:, :, 1] - cons[4] = np.ones((self.image_size, self.image_size)) * vol_frac - - return np.expand_dims(self.topologies[idx], 0) * 2 - 1, cons - - -def load_data_topodiff( - topologies, - constraints, - stress, - strain, - load_img, - batch_size: int, - deterministic: bool = False, -) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]: - r"""Build an iterator over the TopoDiff dataset. - - Parameters - ---------- - topologies : np.ndarray - Topology images ``(N, H, W)``. - constraints : list[dict] - Per-sample constraints dicts (expects key ``"VOL_FRAC"``). - stress : np.ndarray - Stress fields ``(N, H, W)``. - strain : np.ndarray - Strain fields ``(N, H, W)``. - load_img : np.ndarray - Load images ``(N, H, W, 2)``. - batch_size : int - Mini-batch size. - deterministic : bool, optional, default=False - If ``True``, disables shuffling. - - Returns - ------- - Iterator[Tuple[np.ndarray, np.ndarray]] - Iterator over batches of ``(topology, constraints)`` for training. - """ - dataset = DatasetTopoDiff(topologies, stress, strain, load_img, constraints) - - if deterministic: - loader = DataLoader( - dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True - ) - else: - loader = DataLoader( - dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True - ) - - while True: - yield from loader diff --git a/pyproject.toml b/pyproject.toml index f4e06a2bda..2ce187394d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -319,6 +319,7 @@ exclude = ["docs", "physicsnemo/experimental"] # Ignore `S101` (assertions) in all `test` files. "test/*.py" = ["S101"] +"examples/**/test_*.py" = ["S101"]