
Danesed/WaveDiT


WaveDiT: Distribution-Aware Wavelet Flow Matching for Efficient 3D Brain MRI Synthesis

⚙️ Development fork. This is the active development version of WaveDiT, maintained by Danilo Danese (@danesed). It may move ahead of or diverge from the official, stable release, which lives at sisinflab/WaveDiT; use that one to reproduce the paper.

Demo on Hugging Face Spaces · WaveDiT Studio for macOS · Models on Hugging Face · Paper on Hugging Face · arXiv

🤗 Try WaveDiT in your browser: pick an age, generate a synthetic 3D brain MRI, and explore it interactively (triplane + 3D viewer with clip-plane slicing). No install needed → huggingface.co/spaces/danesed/WaveDiT-demo

Official PyTorch implementation of "WaveDiT: Distribution-Aware Wavelet Flow Matching for Efficient 3D Brain MRI Synthesis" (MICCAI 2026).

WaveDiT synthesises full-resolution, high-fidelity, conditional 3D brain MRIs by performing flow matching in the 3D Haar wavelet domain with a slice-wise HDiT backbone. It is guided by Morpheus, a state-aware uncertainty scheduler that adaptively weights the loss and sampling across frequency bands.

WaveDiT architecture

Links: 🤗 Live demo · 🤗 Models · Project page · HF paper · arXiv

Key features

  • Wavelet flow matching: operates on the 8-channel 3D Haar latent (1 LLL + 7 HF bands).
  • Morpheus uncertainty scheduler: Bayesian heteroscedastic loss weighting + uncertainty-minimising sampling guidance.
  • HDiT backbone: neighbourhood + spatio-depth factorised attention for efficient 3D modelling.
  • Multiple flow formulations: cfm, rectified, ot_fm.
  • Conditional synthesis: numeric and categorical metadata (e.g. age), with classifier-free guidance.
  • Single-file configs: one YAML fully describes a run; checkpoints are self-contained for generation.
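For orientation, a single-level 3D Haar transform turns every 2×2×2 voxel block into one low-pass coefficient (LLL) and seven high-pass ones, which is where the 8 channels come from. The NumPy sketch below assumes orthonormal filters, even-sized volumes and a band order with LLL first; the differentiable DWT/IDWT in wavedit/wavelets/ may use a different normalisation or channel order.

```python
import itertools

import numpy as np


def haar_dwt3d(x: np.ndarray) -> np.ndarray:
    """Single-level orthonormal 3D Haar DWT.

    x: (D, H, W) volume with even D, H, W.
    Returns (8, D/2, H/2, W/2): channel 0 is LLL, channels 1..7 are the HF bands.
    """
    D, H, W = x.shape
    assert D % 2 == 0 and H % 2 == 0 and W % 2 == 0, "dimensions must be even"
    # Split each axis into (coarse index, offset inside the 2-voxel pair).
    b = x.reshape(D // 2, 2, H // 2, 2, W // 2, 2)
    bands = []
    # Filter choice per axis: 0 = low (a + b), 1 = high (a - b).
    for fd, fh, fw in itertools.product((0, 1), repeat=3):
        coeff = np.zeros((D // 2, H // 2, W // 2), dtype=np.float64)
        for od, oh, ow in itertools.product((0, 1), repeat=3):
            sign = (-1) ** (fd * od + fh * oh + fw * ow)
            coeff += sign * b[:, od, :, oh, :, ow]
        bands.append(coeff / (2 * np.sqrt(2)))  # (1/sqrt(2))**3 normalisation
    return np.stack(bands)


def haar_idwt3d(c: np.ndarray) -> np.ndarray:
    """Inverse of haar_dwt3d: the orthonormal 8x8 Haar matrix is symmetric, so it is its own inverse."""
    _, d, h, w = c.shape
    b = np.zeros((d, 2, h, 2, w, 2), dtype=np.float64)
    for k, (fd, fh, fw) in enumerate(itertools.product((0, 1), repeat=3)):
        for od, oh, ow in itertools.product((0, 1), repeat=3):
            sign = (-1) ** (fd * od + fh * oh + fw * ow)
            b[:, od, :, oh, :, ow] += sign * c[k]
    return (b / (2 * np.sqrt(2))).reshape(2 * d, 2 * h, 2 * w)


vol = np.random.default_rng(0).standard_normal((4, 6, 8))
coeffs = haar_dwt3d(vol)  # shape (8, 2, 3, 4)
assert np.allclose(haar_idwt3d(coeffs), vol)  # perfect reconstruction
```

Because the transform is orthonormal, it also preserves energy, so statistics such as L2 energy can be compared directly across bands.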

Developing (dev fork)

This fork differs from main in its token-representation defaults. Two knobs in the model: section of the config control them:

| Knob | main | dev default | Cost | Effect |
|---|---|---|---|---|
| patch_embed / patch_decode (TokenRep) | linear / linear | overlap_conv / pixel_shuffle | free at train time | band-grouped overlapping conv stem + depth-to-space decode head, as zero-init residuals |
| hf_decode (DeCo) | none | none (opt-in) | ~2x training time + extra inference cost | per-sub-voxel coordinate MLP that decodes the full-resolution input and adds an HF velocity residual |
  • TokenRep is ON by default. Both overlap_conv and pixel_shuffle are added as zero-gated residual branches, so they cost nothing at init and are warm-start safe: a linear paper-baseline checkpoint loads with an identical forward output (it just reports 4 benign missing keys for the zero residuals). Set both back to linear for the exact paper baseline.
  • DeCo is opt-in. Enable it with model.hf_decode: deco. build_model logs a warning that it roughly doubles training time (and adds inference cost). In a paired ablation (same seed/data/epochs, only DeCo on/off) DeCo gave a small but statistically real sharpness gain concentrated on high-frequency edges; for most uses TokenRep alone is enough, so DeCo is left off by default.
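The warm-start property follows from the zero gate: the residual branch is computed but multiplied by 0, so the output equals the linear baseline until training moves the gate. The NumPy sketch below illustrates this with a 3D depth-to-space (pixel-shuffle) head. The names (pixel_shuffle_3d, local_mix, w_shuffle, gate), the periodic neighbour averaging that stands in for a convolution, and all sizes are hypothetical, not the actual TokenRep modules; this toy reports 2 missing keys where the real model reports 4.

```python
import numpy as np


def pixel_shuffle_3d(x: np.ndarray, r: int) -> np.ndarray:
    """Depth-to-space: (C*r**3, D, H, W) -> (C, D*r, H*r, W*r)."""
    cr3, D, H, W = x.shape
    assert cr3 % r**3 == 0, "channels must be divisible by r**3"
    C = cr3 // r**3
    x = x.reshape(C, r, r, r, D, H, W)
    # Move each sub-voxel offset next to its coarse axis, then merge the pair.
    x = x.transpose(0, 4, 1, 5, 2, 6, 3)
    return x.reshape(C, D * r, H * r, W * r)


def project(tokens: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Per-token linear map: tokens (E, D, H, W), w (E, K) -> (K, D, H, W)."""
    return np.einsum("edhw,ek->kdhw", tokens, w)


def local_mix(tokens: np.ndarray) -> np.ndarray:
    """Stand-in for a conv: average each token with its +/-1 neighbours on each axis (periodic)."""
    out = tokens.copy()
    for axis in (1, 2, 3):
        out += np.roll(tokens, 1, axis=axis) + np.roll(tokens, -1, axis=axis)
    return out / 7.0


def decode(tokens, params, r):
    # Baseline head: per-token linear projection + unpatchify (itself a depth-to-space).
    out = pixel_shuffle_3d(project(tokens, params["w_linear"]), r)
    if "w_shuffle" in params:
        # Residual head with spatial context, scaled by a gate that starts at zero.
        res = pixel_shuffle_3d(project(local_mix(tokens), params["w_shuffle"]), r)
        out = out + params["gate"] * res
    return out


rng = np.random.default_rng(0)
E, C, r = 16, 8, 2  # hypothetical embed dim, wavelet channels, patch size
tokens = rng.standard_normal((E, 3, 3, 3))
baseline = {"w_linear": rng.standard_normal((E, C * r**3))}

dev = {"w_shuffle": rng.standard_normal((E, C * r**3)), "gate": 0.0}
missing = sorted(set(dev) - set(baseline))  # params the baseline checkpoint cannot fill
dev.update(baseline)  # warm start from the baseline weights
assert np.allclose(decode(tokens, dev, r), decode(tokens, baseline, r))
```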

Morpheus: state-aware uncertainty

Wavelet subbands are not statistically equal: the low-frequency approximation stays close to Gaussian, while the high-frequency bands are sparse and heavy-tailed, and these statistics shift along the flow trajectory. Morpheus is a lightweight network that, at each step, reads the statistical signature of the current noisy state (per-band mean, standard deviation, max amplitude, L2 energy, skewness and kurtosis) and predicts a per-band log-variance. That prediction plays two roles:

  • Weighting the loss: it forms a Bayesian heteroscedastic objective (0.5 * exp(-s) * ||v - v_target||^2 + 0.5 * s) that down-weights inherently unpredictable high-frequency content, while the 0.5 * s term prevents trivial variance inflation. The result is state-dependent precision instead of a uniform MSE.
  • Conditioning the backbone: the projected log-variances become a frequency hint, injected alongside the time, slice and age embeddings, so the transformer adapts its prediction to the current reliability of each band, during both training and sampling.
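Both roles hinge on the per-band statistics and the loss above. A NumPy sketch, assuming the squared error is averaged per band before weighting, "L2 energy" means the mean of squares, and kurtosis is reported as excess kurtosis. The Morpheus network itself is replaced by a hypothetical linear map (predict_log_var) purely to show the data flow.

```python
import numpy as np


def band_signature(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Per-band statistics of a noisy state x of shape (bands, ...).

    Returns (bands, 6): mean, std, max |amplitude|, L2 energy (here: mean of
    squares, an assumption), skewness and excess kurtosis.
    """
    flat = x.reshape(x.shape[0], -1)
    mean = flat.mean(axis=1)
    std = flat.std(axis=1)
    z = (flat - mean[:, None]) / (std[:, None] + eps)
    return np.stack(
        [
            mean,
            std,
            np.abs(flat).max(axis=1),
            (flat ** 2).mean(axis=1),
            (z ** 3).mean(axis=1),
            (z ** 4).mean(axis=1) - 3.0,
        ],
        axis=1,
    )


def predict_log_var(sig, w, b):
    """Hypothetical stand-in for Morpheus: one linear map shared across bands."""
    return sig @ w + b  # (bands, 6) @ (6,) -> (bands,)


def heteroscedastic_loss(v_pred, v_target, log_var):
    """Mean over bands of 0.5 * exp(-s) * err + 0.5 * s, with err the per-band MSE."""
    err = ((v_pred - v_target) ** 2).reshape(v_pred.shape[0], -1).mean(axis=1)
    return float(np.mean(0.5 * np.exp(-log_var) * err + 0.5 * log_var))


rng = np.random.default_rng(0)
x_t = rng.standard_normal((8, 4, 4, 4))  # 8 wavelet bands
sig = band_signature(x_t)  # shape (8, 6)
s = predict_log_var(sig, w=np.zeros(6), b=0.0)  # s = 0 for every band
loss = heteroscedastic_loss(x_t, np.zeros_like(x_t), s)  # equals 0.5 * MSE when s = 0
```

For a fixed per-band error e, setting the derivative of 0.5 * e * exp(-s) + 0.5 * s to zero gives s = log e, so the predicted log-variance is pulled towards the actual error rather than growing without bound.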

Installation

conda create -n wavedit_env python=3.11 && conda activate wavedit_env
pip install -r requirements.txt
# Optional but recommended: fused neighbourhood-attention CUDA kernels (match your build):
pip install natten -f https://whl.natten.org
# Optional, faster global attention:
# pip install -U xformers

Developed for Python 3.11 and PyTorch 2.6 (CUDA recommended).

NATTEN is optional. It provides the fastest reference implementation of the neighbourhood attention used in the default config, but WaveDiT also ships an equivalent pure-PyTorch fallback, so the model runs without NATTEN, including on CPU. The backend is chosen automatically; override it with WAVEDIT_NA_BACKEND=auto|natten|torch.
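The selection rule and the idea behind the fallback can be sketched in NumPy. select_na_backend is a hypothetical helper mirroring the described behaviour (honour the override, otherwise prefer NATTEN when it is importable); raising an error when natten is forced but missing is an assumption. neighbourhood_attention_1d is a 1D illustration of neighbourhood attention, where each query attends only to a fixed-size window of nearby keys that is shifted at the borders; the real fallback is PyTorch and works on the model's 3D token grid.

```python
import importlib.util
import os

import numpy as np

VALID_BACKENDS = ("auto", "natten", "torch")


def select_na_backend(env=None) -> str:
    """Resolve WAVEDIT_NA_BACKEND (auto|natten|torch) to 'natten' or 'torch'."""
    env = os.environ if env is None else env
    choice = env.get("WAVEDIT_NA_BACKEND", "auto").strip().lower()
    if choice not in VALID_BACKENDS:
        raise ValueError(f"WAVEDIT_NA_BACKEND must be one of {VALID_BACKENDS}, got {choice!r}")
    has_natten = importlib.util.find_spec("natten") is not None
    if choice == "natten" and not has_natten:
        raise RuntimeError("WAVEDIT_NA_BACKEND=natten but natten is not installed")
    if choice == "auto":
        return "natten" if has_natten else "torch"
    return choice


def neighbourhood_attention_1d(q, k, v, window):
    """Each query attends to `window` consecutive keys centred on it.

    Near the borders the window is shifted rather than truncated, so every
    query sees exactly `window` keys. q, k: (L, d); v: (L, dv); window odd, <= L.
    """
    L, d = q.shape
    assert window % 2 == 1 and window <= L
    half = window // 2
    out = np.empty((L, v.shape[1]), dtype=float)
    for i in range(L):
        start = min(max(i - half, 0), L - window)
        logits = k[start:start + window] @ q[i] / np.sqrt(d)
        p = np.exp(logits - logits.max())  # numerically stable softmax
        out[i] = (p / p.sum()) @ v[start:start + window]
    return out
```

When the window covers the whole sequence this reduces to ordinary softmax attention; smaller windows keep the cost linear in sequence length.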

Repository layout

configs/            One YAML per experiment (cfm, rectified, ot_fm)
train.sh            bash train.sh [config.yaml]      -> launches training
generate.sh         bash generate.sh <ckpt> [outdir] -> generates samples
scripts/
  train.py          config-driven training entry point
  generate.py       generation (specific condition sets or linear interpolation)
  prepare_metadata.py  build the metadata CSV from NIfTI folders
tools/
  slim_checkpoint.py   strip optimiser state for release/inference
wavedit/
  config.py         typed config loaded from YAML
  data/             unified dataset (CSV / filename), augmentation, collation
  wavelets/         differentiable 3D Haar DWT/IDWT
  models/           WaveletFlowMatching, DiT3D backbone, Morpheus, sampling, hdit/
  training/         Trainer + checkpoint I/O
  generation/       sample generation
  evaluation/       metrics + W&B visualisation
  utils/            logging + seeding

Data

See data/README.md. In short, build a catalog once:

python scripts/prepare_metadata.py --input-dirs /path/to/scans --output-csv ./data/dataset.csv

then point data.metadata_csv in your config at it. Raw scans and catalogs are git-ignored and must be obtained from the original dataset providers.
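If you want to see what such a catalog amounts to, the core is a directory walk that records one row per scan. This is a hypothetical sketch, not prepare_metadata.py: the columns (subject_id, path) are assumptions, and the real script also has to record condition metadata such as age, so check data/README.md for the actual schema.

```python
import csv
from pathlib import Path

# Check the double extension first so "x.nii.gz" yields the stem "x".
NIFTI_SUFFIXES = (".nii.gz", ".nii")


def build_catalog(input_dirs, output_csv) -> int:
    """Write one CSV row per NIfTI file found under input_dirs (hypothetical schema)."""
    rows = []
    for root in input_dirs:
        for path in sorted(Path(root).rglob("*")):
            suffix = next((s for s in NIFTI_SUFFIXES if path.name.endswith(s)), None)
            if path.is_file() and suffix:
                rows.append({"subject_id": path.name[: -len(suffix)], "path": str(path.resolve())})
    with open(output_csv, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["subject_id", "path"])
        writer.writeheader()
        writer.writerows(rows)
    return len(rows)
```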

Training

Edit a config (data paths, architecture, hyper-parameters) and launch:

bash train.sh configs/cfm.yaml

Or run the entry point directly:

PYTHONPATH=. python scripts/train.py configs/cfm.yaml

Each run writes to <checkpoint_dir>/<run_name>/: best.pth, last.pth, a copy of the resolved config.yaml, and logs. Set logging.wandb: true for W&B metrics and visualisations. Switch the objective with model.flow (cfm | rectified | ot_fm).
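The three model.flow options differ in how noise and data are paired and interpolated. The sketch below uses common textbook definitions, not necessarily WaveDiT's exact parameterisation: conditional flow matching with a small noise floor sigma_min (1e-4 is an assumed value), rectified flow's straight-line path, and minibatch optimal-transport pairing solved exactly with SciPy's linear_sum_assignment. Time running from noise (t = 0) to data (t = 1) is also an assumption.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

FLOWS = ("cfm", "rectified", "ot_fm")


def make_training_pair(x1, flow, rng, sigma_min=1e-4):
    """Build (t, x_t, v_target) for one batch x1 of shape (B, ...).

    Convention (assumed): t = 0 is Gaussian noise, t = 1 is data.
    """
    if flow not in FLOWS:
        raise ValueError(f"flow must be one of {FLOWS}, got {flow!r}")
    B = x1.shape[0]
    x0 = rng.standard_normal(x1.shape)
    if flow == "ot_fm":
        # Minibatch OT: re-pair noise and data to minimise total squared distance.
        cost = ((x0.reshape(B, 1, -1) - x1.reshape(1, B, -1)) ** 2).sum(axis=-1)
        rows, cols = linear_sum_assignment(cost)
        x0 = x0[rows[np.argsort(cols)]]  # x0[j] is now the noise matched to x1[j]
    t = rng.uniform(size=(B,) + (1,) * (x1.ndim - 1))
    if flow == "cfm":
        # Gaussian path whose noise scale shrinks to sigma_min at t = 1.
        x_t = (1 - (1 - sigma_min) * t) * x0 + t * x1
        v_target = x1 - (1 - sigma_min) * x0
    else:
        # Rectified flow and OT-FM share the straight-line interpolant.
        x_t = (1 - t) * x0 + t * x1
        v_target = x1 - x0
    return t, x_t, v_target
```

In all three cases the network regresses v_target from (x_t, t, conditions); OT pairing leaves the target formula unchanged and only decides which noise sample each volume is matched with.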

Generation

Checkpoints are self-contained (they embed the config and condition metadata), so generation needs only the checkpoint and your sampling choices.

# Specific condition sets (N samples each)
# NOTE: global flags (--cfg-scale, --num-flow-steps, --sampler, --save-size, ...) go BEFORE the subcommand.
PYTHONPATH=. python scripts/generate.py checkpoints/WaveDiT_CFM/best.pth out/ \
    --cfg-scale 1.5 --num-flow-steps 10 --sampler heun --save-size 182 218 182 \
    specific --conditions "age=45.0" "age=70.5" --num-samples 10

# Linearly interpolate one condition (one sample per step)
PYTHONPATH=. python scripts/generate.py checkpoints/WaveDiT_CFM/best.pth out/ \
    linear --condition age --min 6 --max 95 --num 100

Or use the launcher: bash generate.sh checkpoints/WaveDiT_CFM/best.pth.
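The two subcommands differ only in which condition values they request. A hypothetical sketch of that mapping (parse_conditions and linear_sweep are illustrative names, not the script's internals), assuming one numeric key=value per --conditions entry and evenly spaced, endpoint-inclusive values for linear:

```python
import numpy as np


def parse_conditions(specs):
    """['age=45.0', 'age=70.5'] -> [{'age': 45.0}, {'age': 70.5}] (numeric values only)."""
    conds = []
    for spec in specs:
        key, sep, value = spec.partition("=")
        if not sep or not key.strip():
            raise ValueError(f"expected key=value, got {spec!r}")
        conds.append({key.strip(): float(value)})
    return conds


def linear_sweep(name, lo, hi, num):
    """--condition age --min 6 --max 95 --num 100 -> num evenly spaced values, endpoints included."""
    return [{name: float(v)} for v in np.linspace(lo, hi, num)]


specific = parse_conditions(["age=45.0", "age=70.5"])
sweep = linear_sweep("age", 6, 95, 100)
```

Under these assumptions, the specific example above yields 2 × 10 = 20 volumes and the linear example yields 100, one per age value.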

| Argument | Meaning |
|---|---|
| --cfg-scale | Classifier-free guidance scale (1.0 = none). |
| --num-flow-steps | ODE integration steps (overrides the checkpoint default). |
| --sampler | heun (2nd order) or euler. |
| --morpheus-scale | Uncertainty-guidance scale (0 disables it). |
| --save-size | Center-crop saved volumes to D H W (default: full model output). |
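The first three flags combine in the usual way: classifier-free guidance mixes a conditional and an unconditional velocity, and the sampler integrates the resulting ODE on a uniform time grid. A NumPy sketch under assumed conventions (t = 0 noise, t = 1 data, uniform steps), with a toy velocity field standing in for the network. --morpheus-scale is left out because its guidance term is not specified here, and center_crop is a hypothetical helper for --save-size.

```python
import numpy as np


def guided_velocity(model, x, t, cond, cfg_scale):
    """Classifier-free guidance: v_u + w * (v_c - v_u); w = 1.0 gives the plain conditional velocity."""
    v_cond = model(x, t, cond)
    if cfg_scale == 1.0:
        return v_cond
    v_uncond = model(x, t, None)  # None = condition dropped
    return v_uncond + cfg_scale * (v_cond - v_uncond)


def sample(model, x0, cond, num_steps=10, sampler="heun", cfg_scale=1.0):
    """Integrate dx/dt = v(x, t) from t = 0 (noise) to t = 1 (data)."""
    if sampler not in ("euler", "heun"):
        raise ValueError(f"unknown sampler {sampler!r}")
    x = np.array(x0, dtype=float)
    ts = np.linspace(0.0, 1.0, num_steps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        h = t_next - t
        v = guided_velocity(model, x, t, cond, cfg_scale)
        if sampler == "euler":
            x = x + h * v
        else:
            # Heun: Euler predictor, then average the slopes at both ends of the step.
            v_next = guided_velocity(model, x + h * v, t_next, cond, cfg_scale)
            x = x + 0.5 * h * (v + v_next)
    return x


def center_crop(vol, size):
    """Center-crop a (D, H, W) volume to size, e.g. (182, 218, 182)."""
    idx = []
    for n, s in zip(vol.shape, size):
        if s > n:
            raise ValueError(f"crop size {s} exceeds dimension {n}")
        start = (n - s) // 2
        idx.append(slice(start, start + s))
    return vol[tuple(idx)]


def toy(x, t, cond):
    """dx/dt = x, whose exact solution at t = 1 is e * x0."""
    return x


x_euler = sample(toy, np.ones(3), cond=None, sampler="euler")
x_heun = sample(toy, np.ones(3), cond=None, sampler="heun")
```

In this sketch Heun costs two velocity evaluations per step (four with guidance, since each one runs the model with and without conditions), which is the price of its second-order accuracy: on the toy field it lands much closer to e than Euler with the same number of steps.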

Citation

@misc{danese2026waveditdistributionawarewaveletflow,
      title={WaveDiT: Distribution-Aware Wavelet Flow Matching for Efficient 3D Brain MRI Synthesis},
      author={Danilo Danese and Angela Lombardi and Giuseppe Fasano and Matteo Attimonelli and Tommaso Di Noia},
      year={2026},
      eprint={2606.08670},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2606.08670},
}

Acknowledgements

WaveDiT builds on the wavelet-domain analysis and multi-level evaluation protocol of our previous work, FlowLet.

The HDiT backbone is adapted from k-diffusion.

The invertible 3D wavelet transform builds on the great work of WDM.

See LICENSE.

About

[Development] - WaveDiT: Distribution-Aware Wavelet Flow Matching for Efficient 3D Brain MRI Synthesis - Accepted at MICCAI 2026
