⚙️ Development fork. This is the active development version of WaveDiT, maintained by Danilo Danese (@danesed). It may move ahead of or diverge from the official, stable release, which lives at sisinflab/WaveDiT; use that one to reproduce the paper.
🤗 Try WaveDiT in your browser: pick an age, generate a synthetic 3D brain MRI, and explore it interactively (triplane + 3D viewer with clip-plane slicing). No install needed → huggingface.co/spaces/danesed/WaveDiT-demo
Official PyTorch implementation of "WaveDiT: Distribution-Aware Wavelet Flow Matching for Efficient 3D Brain MRI Synthesis" (MICCAI 2026).
WaveDiT synthesises full-resolution, high-fidelity, conditional 3D brain MRIs by performing flow matching in the 3D Haar wavelet domain with a slice-wise HDiT backbone, guided by Morpheus, a state-aware uncertainty scheduler that adaptively weights the loss and sampling across frequency bands.
Links: 🤗 Live demo · 🤗 Models · Project page · HF paper · arXiv
- Wavelet flow matching: operates on the 8-channel 3D Haar latent (1 LLL + 7 HF bands).
- Morpheus uncertainty scheduler: Bayesian heteroscedastic loss weighting + uncertainty-minimising sampling guidance.
- HDiT backbone: neighbourhood + spatio-depth factorised attention for efficient 3D modelling.
- Multiple flow formulations: `cfm`, `rectified`, `ot_fm`.
- Conditional synthesis: numeric and categorical metadata (e.g. age), with classifier-free guidance.
- Single-file configs: one YAML fully describes a run; checkpoints are self-contained for generation.
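
The 8-channel Haar latent from the first bullet can be illustrated with a minimal NumPy sketch of a single-level, orthonormal 3D Haar transform. This is not the repository's `wavedit/wavelets/` code: the function names, the band ordering, and the `1/sqrt(2)` normalisation are assumptions made for illustration, and every axis is assumed to have even length.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)

def haar_split(x, axis):
    """One orthonormal Haar step along `axis`: returns (low, high) halves."""
    x = np.moveaxis(x, axis, 0)
    even, odd = x[0::2], x[1::2]          # requires an even length on this axis
    low, high = (even + odd) / SQRT2, (even - odd) / SQRT2
    return np.moveaxis(low, 0, axis), np.moveaxis(high, 0, axis)

def haar_merge(low, high, axis):
    """Inverse of haar_split: rebuild and interleave the even/odd samples."""
    low, high = np.moveaxis(low, axis, 0), np.moveaxis(high, axis, 0)
    out = np.empty((2 * low.shape[0],) + low.shape[1:], dtype=low.dtype)
    out[0::2] = (low + high) / SQRT2
    out[1::2] = (low - high) / SQRT2
    return np.moveaxis(out, 0, axis)

def haar_dwt3d(vol):
    """(D, H, W) -> (8, D/2, H/2, W/2); index 0 is LLL, indices 1..7 are HF bands."""
    bands = [vol]
    for axis in range(3):
        bands = [part for b in bands for part in haar_split(b, axis)]
    return np.stack(bands)

def haar_idwt3d(coeffs):
    """(8, D/2, H/2, W/2) -> (D, H, W); undoes haar_dwt3d axis by axis."""
    bands = list(coeffs)
    for axis in (2, 1, 0):
        bands = [haar_merge(bands[i], bands[i + 1], axis)
                 for i in range(0, len(bands), 2)]
    return bands[0]

vol = np.random.default_rng(0).normal(size=(4, 6, 8))  # toy volume with even sizes
coeffs = haar_dwt3d(vol)
print(coeffs.shape)  # (8, 2, 3, 4)
```

Because the transform is orthonormal and invertible, `haar_idwt3d(haar_dwt3d(vol))` recovers `vol` up to floating-point error, so a model can work on the eight half-resolution bands and still produce a full-resolution volume.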
This fork differs from `main` in the token representation defaults. Two knobs under `model:` control it:

| Knob | `main` | dev default | Cost | Effect |
|---|---|---|---|---|
| `patch_embed` / `patch_decode` (TokenRep) | `linear` / `linear` | `overlap_conv` / `pixel_shuffle` | free at train time | band-grouped overlapping conv stem + depth-to-space decode head, as zero-init residuals |
| `hf_decode` (DeCo) | `none` | `none` (opt-in) | ~2x training + extra inference | per-sub-voxel coordinate-MLP that decodes the full-resolution input and adds an HF velocity residual |

- TokenRep is ON by default. Both `overlap_conv` and `pixel_shuffle` are added as zero-gated residual branches, so they cost nothing at init and are warm-start safe: a `linear` paper-baseline checkpoint loads with an identical forward output (it just reports 4 benign `missing` keys for the zero residuals). Set both back to `linear` for the exact paper baseline.
- DeCo is opt-in. Enable it with `model.hf_decode: deco`. `build_model` logs a warning that it roughly doubles training time (and adds inference cost). In a paired ablation (same seed/data/epochs, only DeCo on/off) DeCo gave a small but statistically real sharpness gain concentrated on high-frequency edges; for most uses TokenRep alone is enough, so DeCo is left off by default.
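
The warm-start property described for TokenRep can be illustrated with a toy NumPy sketch. It assumes the extra branch is multiplied by a gate initialised to zero; the classes `LinearEmbed` and `GatedEmbed`, their attributes, and the `tanh` branch are hypothetical stand-ins, not the repository's modules. The toy reports two missing keys rather than the four mentioned above.

```python
import numpy as np

class LinearEmbed:
    """Stand-in for the `linear` baseline: a single projection."""
    def __init__(self, d_in, d_out, seed=0):
        self.weight = np.random.default_rng(seed).normal(size=(d_in, d_out))

    def state(self):
        return {"weight": self.weight}

    def __call__(self, x):
        return x @ self.weight

class GatedEmbed(LinearEmbed):
    """Baseline path plus an extra branch scaled by a zero-initialised gate."""
    def __init__(self, d_in, d_out, seed=0):
        super().__init__(d_in, d_out, seed)
        # Hypothetical extra branch; it stands in for the overlapping conv stem.
        self.branch = np.random.default_rng(seed + 1).normal(size=(d_in, d_out))
        self.gate = np.zeros(d_out)  # zero gate: the branch contributes nothing at init

    def state(self):
        return {"weight": self.weight, "branch": self.branch, "gate": self.gate}

    def load(self, state):
        """Copy matching entries and return the keys the checkpoint lacked."""
        missing = [k for k in self.state() if k not in state]
        for key, value in state.items():
            setattr(self, key, value.copy())
        return missing

    def __call__(self, x):
        return super().__call__(x) + self.gate * np.tanh(x @ self.branch)

baseline = LinearEmbed(5, 3, seed=7)
upgraded = GatedEmbed(5, 3, seed=0)
missing = upgraded.load(baseline.state())   # baseline checkpoint has no branch/gate
x = np.random.default_rng(1).normal(size=(4, 5))
print(missing)                               # ['branch', 'gate']
print(np.allclose(upgraded(x), baseline(x))) # True
```

Training can then move the gate away from zero, so the new branch only changes the output once it has learned something useful.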
Wavelet subbands are not statistically equal: the low-frequency approximation stays close to Gaussian, while the high-frequency bands are sparse and heavy-tailed, and these statistics shift along the flow trajectory. Morpheus is a lightweight network that, at each step, reads the statistical signature of the current noisy state (per-band mean, standard deviation, max amplitude, L2 energy, skewness and kurtosis) and predicts a per-band log-variance. That prediction plays two roles:
- Weighting the loss: it forms a Bayesian heteroscedastic objective (`0.5 * exp(-s) * ||v - v_target||^2 + 0.5 * s`) that down-weights inherently unpredictable high-frequency content, while the `0.5 * s` term prevents trivial variance inflation. The result is state-dependent precision instead of a uniform MSE.
- Conditioning the backbone: the projected log-variances become a frequency hint, injected alongside the time, slice and age embeddings, so the transformer adapts its prediction to the current reliability of each band, during both training and sampling.
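
The two ingredients described above, the per-band statistical signature and the heteroscedastic objective, can be sketched in NumPy. This is an illustration under assumptions, not the Morpheus implementation: the function names are hypothetical, "L2 energy" is taken to be the L2 norm, the squared error is averaged over voxels within each band, and the log-variance `s` is passed in directly instead of being predicted by a network.

```python
import numpy as np

def band_signature(x, eps=1e-8):
    """State shaped (bands, D, H, W) -> (bands, 6) statistics per band."""
    flat = x.reshape(x.shape[0], -1)
    mean = flat.mean(axis=1)
    std = flat.std(axis=1)
    z = (flat - mean[:, None]) / (std[:, None] + eps)   # standardised values
    return np.stack([
        mean,
        std,
        np.abs(flat).max(axis=1),            # max amplitude
        np.sqrt((flat ** 2).sum(axis=1)),    # L2 energy (assumed: L2 norm)
        (z ** 3).mean(axis=1),               # skewness
        (z ** 4).mean(axis=1),               # kurtosis
    ], axis=1)

def heteroscedastic_loss(v_pred, v_target, log_var):
    """Mean over bands of 0.5 * exp(-s) * ||v - v_target||^2 + 0.5 * s."""
    sq_err = ((v_pred - v_target) ** 2).reshape(v_pred.shape[0], -1).mean(axis=1)
    return float(np.mean(0.5 * np.exp(-log_var) * sq_err + 0.5 * log_var))

rng = np.random.default_rng(0)
v_pred = rng.normal(size=(8, 2, 3, 4))
v_target = rng.normal(size=(8, 2, 3, 4))
print(band_signature(v_pred).shape)  # (8, 6)
```

For a fixed per-band squared error `e`, the objective is minimised at `s = log(e)`: pushing `s` higher shrinks the first term but pays for it through `0.5 * s`, which is why inflating the variance is not a free way to lower the loss.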
conda create -n wavedit_env python=3.11 && conda activate wavedit_env
pip install -r requirements.txt
# Optional but recommended: fused neighbourhood-attention CUDA kernels (match your build):
pip install natten -f https://whl.natten.org
# Optional, faster global attention:
# pip install -U xformers

Developed for Python 3.11 and PyTorch 2.6 (CUDA recommended).
NATTEN is optional. It is the fastest, ground-truth implementation of the neighbourhood attention used in the default config, but WaveDiT ships an equivalent built-in pure-PyTorch fallback, so the model runs without NATTEN, including on CPU. The backend is chosen automatically; override with `WAVEDIT_NA_BACKEND=auto|natten|torch`.
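
One plausible way such an automatic choice could work is sketched below. Only the variable name `WAVEDIT_NA_BACKEND` and its three values come from this README; the function `pick_na_backend` and its exact behaviour (for example, raising when `natten` is requested but not installed) are assumptions, not the repository's logic.

```python
import importlib.util
import os

def pick_na_backend(env=None):
    """Return "natten" or "torch" from WAVEDIT_NA_BACKEND (hypothetical helper)."""
    env = os.environ if env is None else env
    choice = env.get("WAVEDIT_NA_BACKEND", "auto").lower()
    if choice not in {"auto", "natten", "torch"}:
        raise ValueError(f"unknown backend: {choice!r}")
    # find_spec checks whether the package is importable without importing it.
    has_natten = importlib.util.find_spec("natten") is not None
    if choice == "auto":
        return "natten" if has_natten else "torch"
    if choice == "natten" and not has_natten:
        raise RuntimeError("WAVEDIT_NA_BACKEND=natten but NATTEN is not installed")
    return choice

print(pick_na_backend({"WAVEDIT_NA_BACKEND": "torch"}))  # torch
```

Forcing `torch` is useful when comparing the fallback against the fused kernels, or when debugging on a machine without a matching NATTEN wheel.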
configs/ One YAML per experiment (cfm, rectified, ot_fm)
train.sh bash train.sh [config.yaml] -> launches training
generate.sh bash generate.sh <ckpt> [outdir] -> generates samples
scripts/
train.py config-driven training entry point
generate.py generation (specific condition sets or linear interpolation)
prepare_metadata.py build the metadata CSV from NIfTI folders
tools/
slim_checkpoint.py strip optimiser state for release/inference
wavedit/
config.py typed config loaded from YAML
data/ unified dataset (CSV / filename), augmentation, collation
wavelets/ differentiable 3D Haar DWT/IDWT
models/ WaveletFlowMatching, DiT3D backbone, Morpheus, sampling, hdit/
training/ Trainer + checkpoint I/O
generation/ sample generation
evaluation/ metrics + W&B visualisation
utils/ logging + seeding
See data/README.md. In short, build a catalog once:
python scripts/prepare_metadata.py --input-dirs /path/to/scans --output-csv ./data/dataset.csv

Then point `data.metadata_csv` in your config at it. Raw scans and catalogs are git-ignored and must be obtained from the original dataset providers.
Edit a config (data paths, architecture, hyper-parameters) and launch:
bash train.sh configs/cfm.yaml

Or run the entry point directly:
PYTHONPATH=. python scripts/train.py configs/cfm.yaml

Each run writes to `<checkpoint_dir>/<run_name>/`: `best.pth`, `last.pth`, a copy of the resolved `config.yaml`, and logs. Set `logging.wandb: true` for W&B metrics and visualisations. Switch the objective with `model.flow` (`cfm` | `rectified` | `ot_fm`).
Checkpoints are self-contained (they embed the config and condition metadata), so generation needs only the checkpoint and your sampling choices.
# Specific condition sets (N samples each)
# NOTE: global flags (--cfg-scale, --num-flow-steps, --sampler, --save-size, ...) go BEFORE the subcommand.
PYTHONPATH=. python scripts/generate.py checkpoints/WaveDiT_CFM/best.pth out/ \
--cfg-scale 1.5 --num-flow-steps 10 --sampler heun --save-size 182 218 182 \
specific --conditions "age=45.0" "age=70.5" --num-samples 10
# Linearly interpolate one condition (one sample per step)
PYTHONPATH=. python scripts/generate.py checkpoints/WaveDiT_CFM/best.pth out/ \
linear --condition age --min 6 --max 95 --num 100

Or use the launcher: `bash generate.sh checkpoints/WaveDiT_CFM/best.pth`.

| Argument | Meaning |
|---|---|
| `--cfg-scale` | Classifier-free guidance scale (1.0 = none). |
| `--num-flow-steps` | ODE integration steps (overrides the checkpoint default). |
| `--sampler` | `heun` (2nd order) or `euler`. |
| `--morpheus-scale` | Uncertainty-guidance scale (0 disables it). |
| `--save-size` | Center-crop saved volumes to `D H W` (default: full model output). |

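
To make `--cfg-scale`, `--num-flow-steps` and `--sampler` concrete, here is a minimal NumPy sketch of how these three choices typically interact in a flow-matching sampler. It is not the code in `scripts/generate.py`: the function names, the integration from `t = 0` to `t = 1`, the standard classifier-free guidance blend, and the toy velocity field are all assumptions for illustration.

```python
import numpy as np

def guided_velocity(velocity_fn, x, t, cond, cfg_scale):
    """Classifier-free guidance: extrapolate from unconditional to conditional."""
    v_cond = velocity_fn(x, t, cond)
    if cfg_scale == 1.0:              # 1.0 means no guidance, so skip the extra pass
        return v_cond
    v_uncond = velocity_fn(x, t, None)
    return v_uncond + cfg_scale * (v_cond - v_uncond)

def integrate(velocity_fn, x, cond, num_steps, sampler="heun", cfg_scale=1.0):
    """Integrate dx/dt = v(x, t) over [0, 1] with Euler or Heun steps."""
    ts = np.linspace(0.0, 1.0, num_steps + 1)
    for t0, t1 in zip(ts[:-1], ts[1:]):
        dt = t1 - t0
        v0 = guided_velocity(velocity_fn, x, t0, cond, cfg_scale)
        if sampler == "euler":        # 1st order: one velocity evaluation per step
            x = x + dt * v0
        elif sampler == "heun":       # 2nd order: average start and end slopes
            v1 = guided_velocity(velocity_fn, x + dt * v0, t1, cond, cfg_scale)
            x = x + 0.5 * dt * (v0 + v1)
        else:
            raise ValueError(f"unknown sampler: {sampler!r}")
    return x

def toy_velocity(x, t, cond):
    """Toy field dx/dt = -x with a known solution x(1) = x(0) * exp(-1)."""
    return -x

x0 = np.ones(3)
exact = x0 * np.exp(-1.0)
err_euler = np.abs(integrate(toy_velocity, x0, None, 10, "euler") - exact).max()
err_heun = np.abs(integrate(toy_velocity, x0, None, 10, "heun") - exact).max()
print(err_heun < err_euler)  # True
```

Heun spends two velocity evaluations per step where Euler spends one, and any guidance scale other than 1.0 doubles the evaluations again in this sketch, so the three flags together set the cost of generating one volume.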
@misc{danese2026waveditdistributionawarewaveletflow,
title={WaveDiT: Distribution-Aware Wavelet Flow Matching for Efficient 3D Brain MRI Synthesis},
author={Danilo Danese and Angela Lombardi and Giuseppe Fasano and Matteo Attimonelli and Tommaso Di Noia},
year={2026},
eprint={2606.08670},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2606.08670},
}

WaveDiT builds on the wavelet-domain analysis and multi-level evaluation protocol of our previous work, FlowLet.
The HDiT backbone is adapted from k-diffusion.
The invertible 3D wavelet transform builds on the great work of WDM.
See LICENSE.
