-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·421 lines (367 loc) · 22.7 KB
/
Copy pathtrain.py
File metadata and controls
executable file
·421 lines (367 loc) · 22.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
import argparse
import os
import torch
import wandb
import json
from torch.utils.data import DataLoader, random_split
# Import both dataset types
from flowlet.data import (
create_brain_dataset_and_split,
collate_fn,
train_transform,
val_transform,
TransformedSubset
)
from flowlet.data.dataset_csv import BrainMRIDatasetCSV
from flowlet.models import WaveletFlowMatching
from flowlet.training import train_wavelet_flow_matching
from flowlet.evaluation import visualize_flow_generation, visualize_multi_condition_samples
from flowlet.generation import generate_conditioned_brains
from flowlet.utils import setup_logging, set_seed, get_logger
# Module-level logger for this script. main() calls setup_logging() later to
# configure log output (file/console) for the run; messages emitted before that
# call (e.g. from parse_args) use whatever logging configuration is active then.
logger = get_logger(__name__)
def parse_args():
    """Parse and validate the command-line arguments for a FlowLet training run.

    Arguments are grouped into data, flow-matching training, U-Net architecture,
    post-training generation and system options. After parsing, cross-argument
    constraints are checked: hard violations abort through ``parser.error`` (which
    prints the usage message and raises ``SystemExit``), softer issues are logged
    as warnings.

    Note: ``main`` calls this before ``setup_logging``, so warnings emitted here go
    through whatever logging configuration is active at that moment.

    Returns:
        argparse.Namespace: the parsed and validated arguments.
    """
    parser = argparse.ArgumentParser(description="Train Wavelet Flow Matching (FlowLet) Model")
    # --- Data Args ---
    parser.add_argument("--data_folder", type=str, required=False,
                        help="Path to folder containing .nii.gz files (used *only* if --metadata_csv is NOT provided).")
    parser.add_argument("--metadata_csv", type=str, default=None,
                        help="Path to the metadata CSV file. If provided, this overrides --data_folder and uses the CSV for file paths and conditions.")
    parser.add_argument("--condition_vars", nargs="+", default=["Age"],
                        help="List of conditions to use. If using --metadata_csv, these must be column names in the CSV. If not, they are parsed from filenames.")
    parser.add_argument("--require_conditions", action=argparse.BooleanOptionalAction, default=True,
                        help="If using filename parsing (no --metadata_csv), only use images where ALL specified condition_vars are found.")
    parser.add_argument("--model_input_size", type=int, nargs=3, default=[112, 112, 112], metavar=('D', 'H', 'W'), help="Spatial size images are padded to before DWT.")
    parser.add_argument("--val_split", type=float, default=0.2, help="Fraction of data for validation (0.0 to 1.0).")
    # --- CSV Specific Data Args (Optional Filtering) ---
    parser.add_argument("--csv_filter_col", type=str, default=None,
                        help="[CSV Mode Only] Column name in the CSV to filter by (e.g., 'Condition').")
    parser.add_argument("--csv_filter_value", type=str, default=None,
                        help="[CSV Mode Only] Value to keep in the --csv_filter_col.")
    # --- Flow Matching Training Args ---
    parser.add_argument("--epochs", type=int, default=200, help="Number of epochs for Flow Matching training.")
    parser.add_argument("--lr", type=float, default=3e-6, help="Learning rate for Flow Matching AdamW optimizer.")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for Flow Matching training.")
    parser.add_argument("--early_stop_patience", type=int, default=50, help="Epochs with no val loss improvement before stopping. <= 0 disables.")
    parser.add_argument("--grad_clip_norm", type=float, default=1.0, help="Maximum norm for gradient clipping.")
    parser.add_argument("--num_flow_steps", type=int, default=100, help="Number of integration steps for sampling.")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to a specific checkpoint (.pth file) to resume training from.")
    parser.add_argument("--flow_type", type=str, default="rectified",
                        choices=["rectified", "cfm", "trigonometric", "vp_diffusion"],
                        help="The type of flow matching to use for the training loss.")
    ### Additional VP Diffusion parameters ###
    parser.add_argument("--vp_beta_min", type=float, default=0.1, help="[VP-Diffusion Only] Minimum beta value for the schedule.")
    parser.add_argument("--vp_beta_max", type=float, default=20.0, help="[VP-Diffusion Only] Maximum beta value for the schedule.")
    # --- U-Net Architecture Args ---
    parser.add_argument("--unet_model_channels", type=int, default=128, help="Base number of channels in the U-Net.")
    parser.add_argument("--unet_num_res_blocks", type=int, default=2, help="Number of residual blocks per U-Net level.")
    parser.add_argument("--unet_channel_mult", type=str, default="1,2,3,4", help="Channel multipliers (comma-sep string, e.g., '1,2,3,4').")
    parser.add_argument("--unet_attention_res", type=str, default="16,8", help="Resolutions (relative to initial feature map size) for attention (comma-sep string, e.g., '16,8').")
    parser.add_argument("--unet_dropout", type=float, default=0.1, help="Dropout rate in U-Net ResBlocks/Attention.")
    parser.add_argument("--condition_embedding_dim", type=int, default=512, help="Dimension of the projected condition embeddings.")
    parser.add_argument("--unet_num_heads", type=int, default=8, help="Number of attention heads.")
    parser.add_argument("--unet_num_head_channels", type=int, default=-1, help="Number of channels per head (-1 means calculate based on num_heads).")
    parser.add_argument("--unet_norm_num_groups", type=int, default=32, help="Number of groups for GroupNorm.")
    parser.add_argument("--use_checkpointing", action=argparse.BooleanOptionalAction, default=True, help="Enable gradient checkpointing in U-Net.")
    parser.add_argument("--use_xformers", action=argparse.BooleanOptionalAction, default=True, help="Enable xformers memory-efficient attention if available.")
    parser.add_argument("--unet_disable_cross_attn", action="store_true", help="Disable cross-attention in SpatialTransformer (model becomes unconditional to context).")
    parser.add_argument("--lll_loss_weight", type=float, default=1, help="Weight for the LLL (approximation) subband loss. Default: 1")
    parser.add_argument("--detail_loss_weight", type=float, default=1, help="Weight for the combined detail subbands (LH, HL, HH) loss. Default: 1")
    # --- Generation Args (Optional post-training generation) ---
    parser.add_argument("--generate_after_train", action=argparse.BooleanOptionalAction, default=False, help="Generate samples after training finishes.")
    parser.add_argument("--num_synthetic", type=int, default=10, help="Number of synthetic samples per condition if generating.")
    # Keys must name one of --condition_vars: main() drops unknown keys. The
    # default therefore uses the same spelling as the --condition_vars default
    # ("Age"); a lowercase "age" would never match it.
    parser.add_argument("--generation_conditions", nargs='*', default=['Age=45', 'Age=75'], help="Conditions for generation ('key=value' strings).")
    parser.add_argument("--save_size", type=int, nargs=3, default=[91, 109, 91], metavar=('D', 'H', 'W'), help="Spatial size to crop generated images to before saving.")
    parser.add_argument("--generation_output_dir", type=str, default="generated_samples", help="Subdirectory within checkpoint_dir for generated samples.")
    # --- System Args ---
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for DataLoader.")
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints_flowlet", help="Directory for saving checkpoints and logs.")
    parser.add_argument("--run_name", type=str, default="flowlet_run", help="A name for this training run (used for logging/checkpoints).")
    parser.add_argument("--viz_every", type=int, default=1, help="Log validation/sample visualizations every N epochs (W&B only). Default: 1 (every epoch).")
    parser.add_argument("--wandb", action=argparse.BooleanOptionalAction, default=True, help="Enable Weights & Biases logging.")
    parser.add_argument("--wandb_project", type=str, default="FlowLet_training", help="Wandb project name.")
    parser.add_argument("--wandb_entity", type=str, default=None, help="Wandb entity (username or team).")
    parser.add_argument("--compile", action=argparse.BooleanOptionalAction, default=False, help="Enable torch.compile for the U-Net (experimental).")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use ('cuda' or 'cpu').")
    args = parser.parse_args()

    # Validate argument dependencies
    if args.metadata_csv is None and args.data_folder is None:
        parser.error("Either --metadata_csv or --data_folder must be provided.")
    if args.metadata_csv and args.data_folder:
        logger.warning("Both --metadata_csv and --data_folder provided. --metadata_csv will be used.")
    if args.csv_filter_col and args.csv_filter_value is None:
        parser.error("--csv_filter_value must be provided if --csv_filter_col is set.")

    # The U-Net list options stay strings (main() converts them), but a malformed
    # value would otherwise only surface after run setup / dataset loading. Reject
    # it here with a proper usage error instead.
    def _is_int_list(text):
        try:
            tuple(map(int, text.split(",")))
        except ValueError:
            return False
        return True

    for option in ("unet_channel_mult", "unet_attention_res"):
        value = getattr(args, option)
        if not _is_int_list(value):
            parser.error(f"--{option} must be a comma-separated list of integers, got {value!r}.")

    if args.lll_loss_weight < 0 or args.detail_loss_weight < 0:
        logger.warning("Loss weights should be non-negative. The model will use their absolute values.")
    if args.lll_loss_weight == 0 and args.detail_loss_weight == 0:
        logger.warning("Both lll_loss_weight and detail_loss_weight are 0. This will result in zero loss and no training. The model will internally default to 1 each if both are zero during loss calculation to prevent collapse.")
    return args
def _save_config(args, config_save_path):
    """Write the parsed CLI arguments to ``config_save_path`` as sorted, indented JSON.

    A copy of ``vars(args)`` is serialized: ``vars`` returns the namespace's own
    ``__dict__``, so converting values in place would silently mutate ``args``.
    Tuples are converted to lists for JSON compatibility. Failures are logged
    rather than raised, so an unserializable value never aborts the run.
    """
    try:
        config_to_save = {
            key: list(value) if isinstance(value, tuple) else value
            for key, value in vars(args).items()
        }
        with open(config_save_path, 'w') as f:
            json.dump(config_to_save, f, indent=4, sort_keys=True)
        logger.info(f"Configuration saved to {config_save_path}")
    except Exception as e:
        logger.error(f"Failed to save configuration to {config_save_path}: {e}", exc_info=True)


def _resolve_device(requested):
    """Return the ``torch.device`` to run on for the ``--device`` string ``requested``.

    Any valid device string is honoured, including indexed CUDA devices such as
    ``"cuda:1"``. Falls back to CPU (with a warning) when the string is not a
    valid torch device, or when a CUDA device is requested but CUDA is unavailable.
    """
    try:
        device = torch.device(requested)
    except (RuntimeError, ValueError) as e:
        logger.warning(f"Invalid device '{requested}' ({e}). Falling back to CPU.")
        return torch.device("cpu")
    if device.type == "cuda" and not torch.cuda.is_available():
        logger.warning(f"Device '{requested}' requested but CUDA is not available. Falling back to CPU.")
        return torch.device("cpu")
    return device


def _init_wandb(args):
    """Log in to and start a Weights & Biases run; on failure set ``args.wandb = False``."""
    try:
        wandb.login()
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.run_name,
            config=vars(args),
        )
        logger.info("Weights & Biases initialized.")
    except ImportError:
        logger.warning("wandb not installed, disabling logging.")
        args.wandb = False
    except Exception as e:
        logger.error(f"Wandb initialization failed: {e}. Disabling wandb.", exc_info=True)
        args.wandb = False


def _parse_int_tuple(text):
    """Parse a comma-separated string such as ``"1,2,3,4"`` into a tuple of ints (ValueError if malformed)."""
    return tuple(int(part) for part in text.split(','))


def _build_datasets(args):
    """Create the train/validation datasets and return ``(train, val, condition_ranges)``.

    With ``--metadata_csv`` the CSV dataset is loaded, optionally filtered, and split
    randomly (seeded by ``--seed``) according to ``--val_split`` clamped to [0, 1].
    A split of exactly 0 or 1 puts everything in one subset and uses an empty
    ``TensorDataset`` for the other. Without a CSV, conditions are parsed from the
    filenames in ``--data_folder`` by ``create_brain_dataset_and_split``.

    Raises:
        FileNotFoundError: the CSV file or the data folder does not exist.
        RuntimeError: the CSV dataset is empty after loading/filtering.
        ValueError: the random split would leave train or validation empty.
    """
    if args.metadata_csv:
        logger.info(f"Using CSV metadata from: {args.metadata_csv}")
        if not os.path.exists(args.metadata_csv):
            raise FileNotFoundError(f"Metadata CSV file not found: {args.metadata_csv}")
        full_dataset = BrainMRIDatasetCSV(
            metadata_path=args.metadata_csv,
            transform=None,
            model_input_size=tuple(args.model_input_size),
            filepath_col="FilePath",
            subject_id_col="SubjectID",
            condition_cols=args.condition_vars,
            filter_col=args.csv_filter_col,
            filter_value=args.csv_filter_value,
        )
        condition_ranges = full_dataset.condition_ranges
        if len(full_dataset) == 0:
            raise RuntimeError("Dataset is empty after loading from CSV (and filtering).")

        val_split = max(0.0, min(1.0, args.val_split))
        if val_split in (0.0, 1.0):
            logger.warning(f"Validation split is {val_split}, dataset will not be split.")
            empty = torch.utils.data.TensorDataset(torch.empty(0))
            if val_split == 0.0:
                return TransformedSubset(full_dataset, train_transform), empty, condition_ranges
            return empty, TransformedSubset(full_dataset, val_transform), condition_ranges

        train_size = int((1.0 - val_split) * len(full_dataset))
        val_size = len(full_dataset) - train_size
        logger.info(f"Splitting CSV dataset: {train_size} train, {val_size} validation samples.")
        if train_size == 0 or val_size == 0:
            raise ValueError("Dataset split resulted in zero samples for train or validation.")
        generator = torch.Generator().manual_seed(args.seed)
        train_subset, val_subset = random_split(full_dataset, [train_size, val_size], generator=generator)
        return (
            TransformedSubset(train_subset, train_transform),
            TransformedSubset(val_subset, val_transform),
            condition_ranges,
        )

    # --- Use Filename Parsing Method ---
    logger.info(f"Using filename parsing from data folder: {args.data_folder}")
    if args.data_folder is None or not os.path.isdir(args.data_folder):
        raise FileNotFoundError(f"Data folder not found or not specified: {args.data_folder}. Required when not using --metadata_csv.")
    return create_brain_dataset_and_split(
        data_folder=args.data_folder,
        metadata_path=None,
        transform_train=train_transform,
        transform_val=val_transform,
        model_input_size=tuple(args.model_input_size),
        filter_cognitive_status=None,
        condition_vars=args.condition_vars,
        require_conditions=args.require_conditions,
        val_split=args.val_split,
        seed=args.seed,
    )


def _parse_generation_conditions(condition_strings, condition_vars):
    """Turn ``--generation_conditions`` strings into a list of ``{condition: float}`` dicts.

    Each string holds one or more whitespace-separated ``key=value`` items, e.g.
    ``"Age=70 Sex=1"``. A key is matched against ``condition_vars`` exactly first,
    then case-insensitively (so ``age=45`` maps to a trained ``Age`` condition); the
    stored key is always the trained spelling. Unknown keys are warned about and
    dropped. A malformed string (missing ``=``, empty key, non-numeric value) is
    logged and skipped entirely. Sets that end up empty are omitted.
    """
    known = list(condition_vars or [])
    by_lower = {}
    for var in known:
        by_lower.setdefault(var.lower(), var)

    parsed_conditions_list = []
    for cond_set_str in condition_strings:
        cond_dict = {}
        try:
            items = cond_set_str.split() if ' ' in cond_set_str else [cond_set_str]
            for item in items:
                if '=' not in item:
                    raise ValueError(f"Condition item '{item}' missing '=' separator.")
                key, value_str = item.split('=', 1)
                key = key.strip()
                value_str = value_str.strip()
                if not key:
                    raise ValueError("Condition key cannot be empty.")
                canonical = key if key in known else by_lower.get(key.lower())
                if canonical is None:
                    logger.warning(f"Generation condition '{key}' provided but not in trained condition_vars ({condition_vars}). Ignoring.")
                    continue
                cond_dict[canonical] = float(value_str)
        except ValueError as e:
            logger.error(f"Invalid format in condition string: '{cond_set_str}'. Skipping. Error: {e}")
            continue
        if cond_dict:
            parsed_conditions_list.append(cond_dict)
    return parsed_conditions_list


def main():
    """Run a complete FlowLet training job from the command line.

    The pipeline has these steps:

    1. Parse arguments, create ``<checkpoint_dir>/<run_name>``, configure logging,
       save ``config.json``, seed RNGs, pick the device and optionally start W&B.
    2. Build the datasets and DataLoaders, and save ``condition_ranges.json``.
    3. Build the WaveletFlowMatching model (optionally ``torch.compile`` its U-Net)
       and call ``train_wavelet_flow_matching``.
    4. Produce final visualizations and, if requested, post-training samples.

    Errors during dataset creation, argument conversion or training are logged.
    W&B is then closed with ``exit_code=1`` and the function returns early.
    """
    args = parse_args()

    # --- Setup ---
    run_checkpoint_dir = os.path.join(args.checkpoint_dir, args.run_name)
    os.makedirs(run_checkpoint_dir, exist_ok=True)
    setup_logging(log_dir=run_checkpoint_dir, filename_prefix=args.run_name)
    logger.info(f"Starting training run: {args.run_name}")
    logger.info(f"Checkpoints and logs will be saved to: {run_checkpoint_dir}")
    _save_config(args, os.path.join(run_checkpoint_dir, "config.json"))
    set_seed(args.seed)
    device = _resolve_device(args.device)
    logger.info(f"Using device: {device}")
    if args.wandb:
        _init_wandb(args)

    # --- U-Net list arguments (parsed once, before any expensive work) ---
    try:
        channel_mult_list = _parse_int_tuple(args.unet_channel_mult)
        attention_res = _parse_int_tuple(args.unet_attention_res)
    except ValueError as e:
        logger.error(f"Invalid format for U-Net channel mult or attention res: {e}")
        if args.wandb:
            wandb.finish(exit_code=1)
        return

    # Validate input size vs U-Net structure: each extra channel_mult level adds a
    # 2x downsample, and the DWT halves every spatial dim before the U-Net sees it.
    required_divisor = 2 ** (len(channel_mult_list) - 1)
    unet_input_size = tuple(s // 2 for s in args.model_input_size)
    if any(s % required_divisor != 0 for s in unet_input_size):
        logger.warning(f"U-Net input size {unet_input_size} (from model_input_size {args.model_input_size}) "
                       f"is not divisible by {required_divisor} (required by channel_mult {args.unet_channel_mult}). "
                       f"Ensure padding handles this correctly.")
    else:
        logger.info(f"U-Net input size {unet_input_size} compatible with downsampling.")

    # --- Dataset & Dataloaders ---
    logger.info("Creating and splitting dataset...")
    try:
        train_dataset, val_dataset, condition_ranges = _build_datasets(args)
        if train_dataset is None or val_dataset is None:
            raise RuntimeError("Dataset creation failed, train or validation dataset is None.")
        logger.info(f"Dataset created. Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")
        logger.info(f"Condition ranges found: {condition_ranges}")
        ranges_save_path = os.path.join(run_checkpoint_dir, "condition_ranges.json")
        try:
            with open(ranges_save_path, 'w') as f:
                json.dump(condition_ranges, f, indent=4)
            logger.info(f"Condition ranges saved to {ranges_save_path}")
        except Exception as e:
            logger.error(f"Failed to save condition ranges to {ranges_save_path}: {e}", exc_info=True)
        # An empty subset is only acceptable when the split was requested to be degenerate.
        if len(train_dataset) == 0:
            logger.warning("Training dataset is empty after processing.")
            if args.val_split != 1.0:
                raise RuntimeError("Training dataset is empty.")
        if len(val_dataset) == 0:
            logger.warning("Validation dataset is empty after processing.")
            if args.val_split != 0.0:
                raise RuntimeError("Validation dataset is empty.")
    except Exception as e:
        logger.error(f"Failed during dataset creation or splitting: {e}", exc_info=True)
        if args.wandb:
            wandb.finish(exit_code=1)
        return

    # --- Create DataLoaders ---
    # Pinned host memory only speeds up host->GPU copies; on CPU it is pure overhead.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=pin_memory, drop_last=True,
        collate_fn=collate_fn,
        persistent_workers=args.num_workers > 0 and not isinstance(train_dataset, torch.utils.data.TensorDataset),
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin_memory,
        collate_fn=collate_fn,
        persistent_workers=args.num_workers > 0 and not isinstance(val_dataset, torch.utils.data.TensorDataset),
    )
    logger.info("DataLoaders created.")

    # --- Flow Matching Model ---
    # Each condition is a single scalar input to the conditioning embedding.
    condition_dims_dict = {var: 1 for var in args.condition_vars} if args.condition_vars else {}
    unet_args = {
        # 8 in/out channels: presumably the 8 subbands of a single-level 3D DWT
        # (LLL + 7 detail bands) — TODO confirm against WaveletFlowMatching.
        "in_channels": 8, "model_channels": args.unet_model_channels, "out_channels": 8,
        "num_res_blocks": args.unet_num_res_blocks,
        "attention_resolutions": attention_res,
        "dropout": args.unet_dropout,
        "channel_mult": channel_mult_list,
        "conv_resample": True, "dims": 3,
        "use_checkpoint": args.use_checkpointing,
        "num_heads": args.unet_num_heads,
        "num_head_channels": args.unet_num_head_channels,
        "use_scale_shift_norm": True,
        "resblock_updown": True,
        "condition_dims": condition_dims_dict,
        "condition_embedding_dim": args.condition_embedding_dim,
        "use_xformers": args.use_xformers,
        # Cross-attention needs at least one condition to attend to.
        "use_cross_attention": not args.unet_disable_cross_attn and bool(condition_dims_dict),
        "norm_num_groups": args.unet_norm_num_groups,
        "norm_eps": 1e-6,
    }
    wfm_model = WaveletFlowMatching(
        u_net_args=unet_args,
        num_flow_steps=args.num_flow_steps,
        lll_loss_weight=args.lll_loss_weight,
        detail_loss_weight=args.detail_loss_weight,
        flow_type=args.flow_type,
        vp_beta_min=args.vp_beta_min,
        vp_beta_max=args.vp_beta_max,
    ).to(device)
    logger.info("FlowLet model initialized.")
    logger.info(f"U-Net args: {unet_args}")

    # --- Compile Flow Model (Optional) ---
    if args.compile:
        try:
            logger.info("Attempting to compile FlowLet U-Net model...")
            wfm_model.flow_net = torch.compile(wfm_model.flow_net, mode="reduce-overhead")
            logger.info("FlowLet U-Net compiled successfully.")
        except Exception as e:
            logger.warning(f"Torch compile failed for Flowlet U-Net: {e}. Continuing without compilation.", exc_info=True)

    # --- Flow Matching Training ---
    logger.info("Starting training...")
    try:
        train_wavelet_flow_matching(
            wfm_model=wfm_model,
            train_loader=train_loader,
            val_loader=val_loader,
            num_epochs=args.epochs,
            lr=args.lr,
            use_wandb=args.wandb,
            checkpoint_dir=run_checkpoint_dir,
            resume_from_path_arg=args.resume_from_checkpoint,
            early_stop_patience=args.early_stop_patience,
            model_output_size=tuple(args.model_input_size),
            condition_ranges=condition_ranges,
            grad_clip_norm=args.grad_clip_norm,
            device=device,
            viz_every=args.viz_every,
        )
    except Exception as e:
        logger.error(f"An error occurred during training: {e}", exc_info=True)
        if args.wandb:
            wandb.finish(exit_code=1)
        return
    logger.info("Training finished.")
    if device.type == "cuda":
        torch.cuda.empty_cache()

    # --- Final Visualization ---
    logger.info("--- Generating Final Visualizations ---")
    wfm_model.eval()
    try:
        if len(val_dataset) > 0:
            visualize_flow_generation(wfm_model, val_loader, tuple(args.model_input_size), use_wandb=args.wandb, epoch_num=None)
            visualize_multi_condition_samples(wfm_model, num_samples=1, model_output_size=tuple(args.model_input_size), wandb_log=args.wandb, condition_ranges=condition_ranges, epoch_num=None)
        else:
            logger.info("Skipping final visualization as validation dataset is empty.")
    except Exception as e:
        logger.error(f"Error during final visualization: {e}", exc_info=True)

    # --- Optional Generation After Training ---
    if args.generate_after_train and args.generation_conditions:
        logger.info("--- Generating Samples Post-Training ---")
        parsed_conditions_list = _parse_generation_conditions(args.generation_conditions, args.condition_vars)
        if parsed_conditions_list:
            gen_output_path = os.path.join(run_checkpoint_dir, args.generation_output_dir)
            try:
                generate_conditioned_brains(
                    wfm_model=wfm_model,
                    conditions_list=parsed_conditions_list,
                    num_samples_per_condition=args.num_synthetic,
                    output_dir=gen_output_path,
                    save_size=tuple(args.save_size),
                    model_output_size=tuple(args.model_input_size),
                    condition_ranges=condition_ranges,
                )
                logger.info(f"Generated samples saved to: {gen_output_path}")
            except Exception as e:
                logger.error(f"Error during post-training generation: {e}", exc_info=True)
        else:
            logger.warning("No valid conditions parsed from --generation_conditions or conditions provided were not used during training, skipping post-training generation.")

    if args.wandb:
        wandb.finish()
    logger.info("Script finished successfully!")
# Script entry point: run the training pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()