Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 63 additions & 53 deletions examples/generative/topodiff/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from tqdm import trange
import numpy as np
import matplotlib.pyplot as plt


import hydra
import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import DictConfig

from physicsnemo.models.topodiff import TopoDiff, Diffusion
from physicsnemo.models.topodiff import UNetEncoder
from utils import (
ClassifierGuidance,
DDPMLinearNoiseScheduler,
DDPMSolver,
load_data,
load_data_topodiff,
)

from physicsnemo.diffusion.guidance import DPSScorePredictor
from physicsnemo.diffusion.samplers import sample
from physicsnemo.models.topodiff import TopoDiff, UNetEncoder
from physicsnemo.utils.logging import PythonLogger
from utils import load_data_topodiff, load_data


@hydra.main(version_base="1.3", config_path="conf", config_name="config")
Expand Down Expand Up @@ -60,7 +63,9 @@ def main(cfg: DictConfig) -> None:
classifier.load_state_dict(torch.load(cfg.model_path_classifier))
classifier.to(device)

diffusion = Diffusion(n_steps=1000, device=device)
n_steps = cfg.diffusion_steps
scheduler = DDPMLinearNoiseScheduler(n_steps=n_steps)

batch_size = cfg.batch_size
data = load_data_topodiff(
topologies,
Expand All @@ -71,59 +76,64 @@ def main(cfg: DictConfig) -> None:
)

_, cons = next(data)

cons = cons.float().to(device)

n_steps = 1000
# Epsilon predictor (TopoDiff model with fixed conditions)
def eps_predictor(x, t):
    """Predict the noise epsilon for state ``x`` at diffusion time ``t``.

    The conditioning tensor ``cons`` and the TopoDiff ``model`` are captured
    from the enclosing scope; gradients are disabled since this is inference.
    NOTE(review): assumes ``model`` follows the (x, condition, t) call
    signature with integer timesteps — confirm against TopoDiff.forward.
    """
    with torch.no_grad():
        eps = model(x, cons, t.long())
    return eps

xt = torch.randn(batch_size, 1, 64, 64).to(device)
# X0 predictor (convert epsilon -> x0)
def x0_predictor(x, t):
    """Estimate the clean sample x0 from noisy state ``x`` at time ``t``.

    Delegates to ``eps_predictor`` for the noise estimate, then converts
    epsilon to x0 via the scheduler's closed-form DDPM relation.
    """
    return scheduler.epsilon_to_x0(eps_predictor(x, t), x, t)

# Classifier guidance
floating_labels = torch.tensor([1] * batch_size).long().to(device)
guidance = ClassifierGuidance(classifier, floating_labels, scale=0.2)

for i in reversed(trange(n_steps)):
with torch.no_grad():
t = torch.tensor([i] * batch_size, device=device)
noisy = diffusion.p_sample(model, xt, t, cons)

with torch.enable_grad():
xt.requires_grad_(True)
logits = classifier(xt, time_steps=t)
loss = F.cross_entropy(logits, floating_labels)

grad = torch.autograd.grad(loss, xt)[0]

xt = (
1
/ diffusion.alphas[i].sqrt()
* (
xt
- noisy
* (1 - diffusion.alphas[i])
/ (1 - diffusion.alpha_bars[i]).sqrt()
)
)
# DPS guided score predictor (framework component)
dps_score = DPSScorePredictor(
x0_predictor=x0_predictor,
x0_to_score_fn=scheduler.x0_to_score,
guidances=guidance,
)

# DDPM solver (no stochastic noise, matching original)
solver = DDPMSolver(dps_score, scheduler, stochastic=False)

if i > 0:
z = torch.zeros_like(xt).to(device)
xt = xt + diffusion.betas[i].sqrt() * (z * 0.8 + 0.2 * grad.float())
# Generate samples
xt = torch.randn(batch_size, 1, 64, 64).to(device)

# Note: the denoiser arg is required by sample() but unused when a custom
# Solver is provided — the DDPMSolver uses its own score_predictor internally.
with torch.inference_mode(False):
xt = sample(
denoiser=scheduler.get_denoiser(score_predictor=dps_score),
xN=xt,
noise_scheduler=scheduler,
num_steps=n_steps,
solver=solver,
)

result = (xt.cpu().detach().numpy() + 1) * 2

np.save(cfg.generation_path + "results_topology.npy", result)

# plot images for the generated samples
fig, axes = plt.subplots(8, 8, figsize=(12, 6), dpi=300)

for i in range(8):
for j in range(8):
img = result[i * 4 + j][0]
axes[i, j].imshow(img, cmap="gray")
axes[i, j].set_xticks([])
axes[i, j].set_yticks([])

plt.xticks([]) # Remove x-axis ticks
plt.yticks([]) # Remove y-axis ticks
plt.gca().xaxis.set_visible(False) # Optionally hide x-axis
plt.gca().yaxis.set_visible(False) # Optionally hide y-axis
n_samples = result.shape[0]
ncols = min(8, n_samples)
nrows = min(8, (n_samples + ncols - 1) // ncols)
fig, axes = plt.subplots(nrows, ncols, figsize=(12, 6), dpi=300, squeeze=False)

for idx in range(min(nrows * ncols, n_samples)):
r, c = divmod(idx, ncols)
axes[r, c].imshow(result[idx][0], cmap="gray")
axes[r, c].set_xticks([])
axes[r, c].set_yticks([])
for idx in range(n_samples, nrows * ncols):
r, c = divmod(idx, ncols)
axes[r, c].axis("off")

plt.savefig(
cfg.generation_path + "grid_topology.png", bbox_inches="tight", pad_inches=0
Expand Down
26 changes: 18 additions & 8 deletions examples/generative/topodiff/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import hydra
import torch
from omegaconf import DictConfig
from torch.optim import AdamW
from tqdm import trange
from utils import DDPMLinearNoiseScheduler, load_data, load_data_topodiff


import hydra
from omegaconf import DictConfig

from physicsnemo.models.topodiff import TopoDiff, Diffusion
from physicsnemo.diffusion.metrics.losses import MSEDSMLoss
from physicsnemo.models.topodiff import TopoDiff
from physicsnemo.utils.logging import PythonLogger
from utils import load_data_topodiff, load_data


@hydra.main(version_base="1.3", config_path="conf", config_name="config")
Expand All @@ -34,7 +33,18 @@ def main(cfg: DictConfig) -> None:

device = torch.device("cuda:0")
model = TopoDiff(64, 6, 1, model_channels=128, attn_resolutions=[16, 8]).to(device)
diffusion = Diffusion(n_steps=1000, device=device)
scheduler = DDPMLinearNoiseScheduler(n_steps=1000)

# Adapt TopoDiff to DiffusionModel protocol (epsilon-predictor)
def diffusion_model(x, t, condition=None, **kwargs):
    """Adapt TopoDiff to the DiffusionModel protocol (epsilon-predictor).

    Maps the protocol's (x, t, condition, **kwargs) signature onto the
    TopoDiff forward call; extra keyword arguments are accepted for
    protocol compatibility but ignored. Timesteps are cast to long ints
    as the model's time embedding expects integer indices.
    """
    timesteps = t.long()
    return model(x, condition, timesteps)

loss_fn = MSEDSMLoss(
model=diffusion_model,
noise_scheduler=scheduler,
prediction_type="epsilon",
epsilon_to_x0_fn=scheduler.epsilon_to_x0,
)

topologies = load_data(
cfg.path_training_data_diffusion, cfg.prefix_topology_file, ".png", 0, 30000
Expand Down Expand Up @@ -67,7 +77,7 @@ def main(cfg: DictConfig) -> None:
tops = tops.float().to(device)
cons = cons.float().to(device)

losses = diffusion.train_loss(model, tops, cons)
losses = loss_fn(x0=tops, condition=cons)

optimizer.zero_grad()
losses.backward()
Expand Down
21 changes: 9 additions & 12 deletions examples/generative/topodiff/train_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import hydra
import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import DictConfig
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
import numpy as np


import hydra
from omegaconf import DictConfig
from utils import DDPMLinearNoiseScheduler, load_data_classifier

from physicsnemo.models.topodiff import Diffusion
from physicsnemo.models.topodiff import UNetEncoder
from physicsnemo.utils.logging import PythonLogger
from utils import load_data_classifier


@hydra.main(version_base="1.3", config_path="conf", config_name="config")
Expand All @@ -44,7 +41,7 @@ def main(cfg: DictConfig) -> None:

classifier = UNetEncoder(in_channels=1, out_channels=2).to(device)

diffusion = Diffusion(n_steps=cfg.diffusion_steps, device=device)
noise_scheduler = DDPMLinearNoiseScheduler(n_steps=cfg.diffusion_steps)

batch_size = cfg.batch_size

Expand All @@ -64,8 +61,8 @@ def main(cfg: DictConfig) -> None:
batch = torch.tensor(train_img[idx]).float().unsqueeze(1).to(device) * 2 - 1
Comment thread
CharlelieLrt marked this conversation as resolved.
batch_labels = torch.tensor(train_labels[idx]).long().to(device)

t = torch.randint(0, cfg.diffusion_steps, (batch.shape[0],)).to(device)
batch = diffusion.q_sample(batch, t)
t = noise_scheduler.sample_time(batch.shape[0], device=device)
batch = noise_scheduler.add_noise(batch, t)
logits = classifier(batch, time_steps=t)

loss = F.cross_entropy(logits, batch_labels)
Expand All @@ -83,8 +80,8 @@ def main(cfg: DictConfig) -> None:
batch_labels = torch.tensor(valid_labels[idx]).long().to(device)

# Sample diffusion steps and get noised images
t = torch.randint(0, cfg.diffusion_steps, (batch.shape[0],)).to(device)
batch = diffusion.q_sample(batch, t)
t = noise_scheduler.sample_time(batch.shape[0], device=device)
batch = noise_scheduler.add_noise(batch, t)

# Forward pass
logits = classifier(batch, time_steps=t)
Expand Down
17 changes: 7 additions & 10 deletions examples/generative/topodiff/train_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import hydra
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
import numpy as np


import hydra
from omegaconf import DictConfig
from utils import DDPMLinearNoiseScheduler, load_data_regressor

from physicsnemo.models.topodiff import Diffusion
from physicsnemo.models.topodiff import UNetEncoder
from physicsnemo.utils.logging import PythonLogger
from utils import load_data_regressor


@hydra.main(version_base="1.3", config_path="conf", config_name="config")
Expand All @@ -54,7 +51,7 @@ def main(cfg: DictConfig) -> None:
in_channels = 6
regressor = UNetEncoder(in_channels=in_channels, out_channels=1).to(device)

diffusion = Diffusion(n_steps=cfg.diffusion_steps, device=device)
noise_scheduler = DDPMLinearNoiseScheduler(n_steps=cfg.diffusion_steps)

batch_size = cfg.batch_size
"""
Expand Down Expand Up @@ -85,8 +82,8 @@ def main(cfg: DictConfig) -> None:

batch_labels = torch.tensor(labels[idx]).float().to(device).unsqueeze(1)

t = torch.randint(0, cfg.diffusion_steps, (batch.shape[0],)).to(device)
batch = diffusion.q_sample(batch, t)
t = noise_scheduler.sample_time(batch.shape[0], device=device)
batch = noise_scheduler.add_noise(batch, t)

batch = torch.cat((batch, batch_pf, batch_load), dim=1)

Expand Down
Loading
Loading