Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions wan/modules/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
from functools import partial

import torch
import torch.nn as nn
Expand All @@ -11,6 +12,15 @@
__all__ = ['WanModel']


def _gc_autocast(dtype=torch.bfloat16):
"""Context function for gradient checkpointing that provides autocast
for both forward and recompute phases, ensuring dtype consistency."""
return (
torch.amp.autocast('cuda', dtype=dtype),
torch.amp.autocast('cuda', dtype=dtype),
)


def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
Expand Down Expand Up @@ -235,27 +245,21 @@ def forward(
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
assert e.dtype == torch.float32
with torch.amp.autocast('cuda', dtype=torch.float32):
e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
compute_dtype = x.dtype
e = (self.modulation.float().unsqueeze(0) + e).chunk(6, dim=2)
assert e[0].dtype == torch.float32

# self-attention
y = self.self_attn(
self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
(self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2)).to(dtype=compute_dtype),
seq_lens, grid_sizes, freqs)
with torch.amp.autocast('cuda', dtype=torch.float32):
x = x + y * e[2].squeeze(2)

# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(
self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))
with torch.amp.autocast('cuda', dtype=torch.float32):
x = x + y * e[5].squeeze(2)
return x

x = cross_attn_ffn(x, context, context_lens, e)
x = (x.float() + y.float() * e[2].squeeze(2)).to(dtype=compute_dtype)

# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(
(self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)).to(dtype=compute_dtype))
x = (x.float() + y.float() * e[5].squeeze(2)).to(dtype=compute_dtype)
return x


Expand Down Expand Up @@ -358,6 +362,7 @@ def __init__(self,

assert model_type in ['t2v', 'i2v', 'ti2v', 's2v']
self.model_type = model_type
self.gradient_checkpointing = False

self.patch_size = patch_size
self.text_len = text_len
Expand Down Expand Up @@ -487,7 +492,12 @@ def forward(
context_lens=context_lens)

for block in self.blocks:
x = block(x, **kwargs)
if self.gradient_checkpointing and self.training:
x = torch.utils.checkpoint.checkpoint(
block, x, use_reentrant=False,
context_fn=partial(_gc_autocast, dtype=x.dtype), **kwargs)
else:
x = block(x, **kwargs)

# head
x = self.head(x, e)
Expand Down