From 218f206407614e0206b18c557929688af6220527 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E5=BC=BC?= Date: Tue, 21 Apr 2026 15:22:42 +0800 Subject: [PATCH] fix: WanAttentionBlock gradient checkpointing + FSDP2 mixed precision dtype mismatch Replace torch.amp.autocast('cuda', dtype=torch.float32) context managers inside WanAttentionBlock.forward() with explicit .float() upcasts and .to(dtype=compute_dtype) downcasts. The autocast approach is incompatible with gradient checkpointing under FSDP2 mixed precision: during recompute, the outer bf16 autocast (provided by FSDP2) is not replayed, causing saved tensors (bf16) to diverge from recomputed tensors (fp32), triggering CheckpointError. The explicit cast approach is autocast-agnostic and produces numerically identical results (bitwise identical in fp32, <1e-2 tolerance in bf16). Also adds gradient checkpointing support to WanModel via a _gc_autocast() context_fn that infers the autocast dtype from x.dtype rather than hardcoding bf16. Fixes Wan-Video/Wan2.2#353 --- wan/modules/model.py | 44 +++++++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/wan/modules/model.py b/wan/modules/model.py index 6982fa15..843a177a 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -1,5 +1,6 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import math +from functools import partial import torch import torch.nn as nn @@ -11,6 +12,15 @@ __all__ = ['WanModel'] +def _gc_autocast(dtype=torch.bfloat16): + """Context function for gradient checkpointing that provides autocast + for both forward and recompute phases, ensuring dtype consistency.""" + return ( + torch.amp.autocast('cuda', dtype=dtype), + torch.amp.autocast('cuda', dtype=dtype), + ) + + def sinusoidal_embedding_1d(dim, position): # preprocess assert dim % 2 == 0 @@ -235,27 +245,21 @@ def forward( freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ assert e.dtype == torch.float32 - with torch.amp.autocast('cuda', dtype=torch.float32): - e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2) + compute_dtype = x.dtype + e = (self.modulation.float().unsqueeze(0) + e).chunk(6, dim=2) assert e[0].dtype == torch.float32 # self-attention y = self.self_attn( - self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2), + (self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2)).to(dtype=compute_dtype), seq_lens, grid_sizes, freqs) - with torch.amp.autocast('cuda', dtype=torch.float32): - x = x + y * e[2].squeeze(2) - - # cross-attention & ffn function - def cross_attn_ffn(x, context, context_lens, e): - x = x + self.cross_attn(self.norm3(x), context, context_lens) - y = self.ffn( - self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)) - with torch.amp.autocast('cuda', dtype=torch.float32): - x = x + y * e[5].squeeze(2) - return x - - x = cross_attn_ffn(x, context, context_lens, e) + x = (x.float() + y.float() * e[2].squeeze(2)).to(dtype=compute_dtype) + + # cross-attention & ffn + x = x + self.cross_attn(self.norm3(x), context, context_lens) + y = self.ffn( + (self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)).to(dtype=compute_dtype)) + x = (x.float() + y.float() * e[5].squeeze(2)).to(dtype=compute_dtype) return x @@ -358,6 +362,7 @@ def __init__(self, assert model_type in ['t2v', 'i2v', 'ti2v', 's2v'] self.model_type = model_type + self.gradient_checkpointing = False self.patch_size = patch_size self.text_len = text_len @@ -487,7 +492,12 @@ def forward( context_lens=context_lens) for block in self.blocks: - x = block(x, **kwargs) + if self.gradient_checkpointing and self.training: + x = torch.utils.checkpoint.checkpoint( + block, x, use_reentrant=False, + context_fn=partial(_gc_autocast, dtype=x.dtype), **kwargs) + else: + x = block(x, **kwargs) # head x = self.head(x, e)