Add multimodal detector datasets (LArTPC + Water Cherenkov) #2
Merged: youngsm merged 4 commits into DeepLearnPhysics:main from OmarAlterkait:multimodal-datasets on Apr 18, 2026
Changes from 2 commits
Commits (4):
e62b33e Add multimodal detector datasets (LArTPC + Water Cherenkov) (OmarAlterkait)
b1d5e0c Delete pimm/datasets/readers/.gitignore (youngsm)
995126f Remove HEPDataset base class; dataset classes inherit torch Dataset d… (OmarAlterkait)
715c84c Rename dataset/reader classes to source-named (JAXTPC, LUCiD); move t… (OmarAlterkait)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| # Base dataset config for JAXTPC 3D seg data. | ||
| # | ||
| # Set JAXTPC_DATA_ROOT environment variable or override data_root in child config. | ||
| # Expected directory layout: | ||
| # {data_root}/seg/{split}/{dataset_name}_seg_NNNN.h5 | ||
| # or: {data_root}/seg/{dataset_name}_seg_NNNN.h5 (flat, split ignored) | ||
|
|
||
| import os | ||
|
|
||
| _data_root = os.environ.get("JAXTPC_DATA_ROOT", "/path/to/jaxtpc/production") | ||
|
|
||
| # Coordinate normalization center and scale. | ||
| # Default is for SBND-scale dual-TPC: x in [-2160, 2160], y/z in [-2160, 2160] mm. | ||
| _center = [0.0, 0.0, 0.0] | ||
| _scale = 2160.0 * 3 ** 0.5 # ~3741 mm (cube half-diagonal): normalized radius <= 1, each axis within about +/-0.58 | ||
|
|
||
| grid_size = 0.001 # after normalization | ||
|
|
||
| transform = [ | ||
| dict(type="PDGToSemantic", scheme="motif_5cls"), | ||
| dict(type="NormalizeCoord", center=_center, scale=_scale), | ||
| dict(type="LogTransform", min_val=0.01, max_val=20.0), | ||
| dict( | ||
| type="GridSample", | ||
| grid_size=grid_size, | ||
| hash_type="fnv", | ||
| mode="train", | ||
| return_grid_coord=True, | ||
| ), | ||
| dict(type="RandomRotate", angle=[-1, 1], axis="z", center=[0, 0, 0], p=0.8), | ||
| dict(type="RandomRotate", angle=[-1, 1], axis="x", center=[0, 0, 0], p=0.8), | ||
| dict(type="RandomRotate", angle=[-1, 1], axis="y", center=[0, 0, 0], p=0.8), | ||
| dict(type="RandomFlip", p=0.5), | ||
| dict(type="Copy", keys_dict={"segment_motif": "segment"}), | ||
| dict(type="ToTensor"), | ||
| dict( | ||
| type="Collect", | ||
| keys=("coord", "grid_coord", "segment"), | ||
| feat_keys=("coord", "energy"), | ||
| ), | ||
| ] | ||
|
|
||
| test_transform = [ | ||
| dict(type="PDGToSemantic", scheme="motif_5cls"), | ||
| dict(type="NormalizeCoord", center=_center, scale=_scale), | ||
| dict(type="LogTransform", min_val=0.01, max_val=20.0), | ||
| dict( | ||
| type="GridSample", | ||
| grid_size=grid_size, | ||
| hash_type="fnv", | ||
| mode="train", | ||
| return_grid_coord=True, | ||
| ), | ||
| dict(type="Copy", keys_dict={"segment_motif": "segment"}), | ||
| dict(type="ToTensor"), | ||
| dict( | ||
| type="Collect", | ||
| keys=("coord", "grid_coord", "segment"), | ||
| feat_keys=("coord", "energy"), | ||
| ), | ||
| ] | ||
|
|
||
| data = dict( | ||
| num_classes=5, | ||
| ignore_index=-1, | ||
| names=["shower", "track", "michel", "delta", "led"], | ||
| train=dict( | ||
| type="LArTPCDataset", | ||
| data_root=_data_root, | ||
| split="train", | ||
| dataset_name="sim", | ||
| modalities=("seg",), | ||
| transform=transform, | ||
| min_deposits=1024, | ||
| max_len=-1, | ||
| ), | ||
| val=dict( | ||
| type="LArTPCDataset", | ||
| data_root=_data_root, | ||
| split="val", | ||
| dataset_name="sim", | ||
| modalities=("seg",), | ||
| transform=test_transform, | ||
| min_deposits=1024, | ||
| max_len=1000, | ||
| ), | ||
| ) |
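The normalization constants in this config can be sanity-checked with a few lines of arithmetic. The sketch below assumes that `NormalizeCoord` computes `(coord - center) / scale`, which this diff does not show; `normalize` is a hypothetical stand-in, not pimm's transform.

```python
import math

# Values copied from the config above.
center = [0.0, 0.0, 0.0]
half_extent_mm = 2160.0
scale = half_extent_mm * 3 ** 0.5
grid_size = 0.001  # in normalized units


def normalize(point, center, scale):
    """Hypothetical stand-in for NormalizeCoord: (point - center) / scale."""
    return [(p - c) / scale for p, c in zip(point, center)]


# The far corner of the cube is the point with the largest radius.
corner = normalize([half_extent_mm] * 3, center, scale)
corner_radius = math.sqrt(sum(v * v for v in corner))

# One voxel edge, converted back to millimetres.
voxel_mm = grid_size * scale

print(f"scale = {scale:.1f} mm")
print(f"per-axis range = +/-{corner[0]:.3f}")
print(f"corner radius = {corner_radius:.3f}")
print(f"voxel edge = {voxel_mm:.2f} mm")
```

Under that assumption, each axis spans about ±0.577 and only the cube's corners reach radius 1, so the scale bounds the radius rather than the per-axis range. A `grid_size` of 0.001 then corresponds to voxels of roughly 3.74 mm.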
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| """ | ||
| PTv3 semantic segmentation on JAXTPC 3D data. | ||
|
|
||
| Drop-in replacement for PILArNet semseg — same model, different data source. | ||
| """ | ||
|
|
||
| _base_ = [ | ||
| "../../../configs/_base_/default_runtime.py", | ||
| "../_base_/jaxtpc_seg.py", | ||
| ] | ||
|
|
||
| # --- training --- | ||
| batch_size = 48 | ||
| num_worker = 24 | ||
| mix_prob = 0.0 | ||
| clip_grad = None | ||
| empty_cache = False | ||
| enable_amp = True | ||
| amp_dtype = "bfloat16" | ||
| matmul_precision = "high" | ||
| seed = 0 | ||
| evaluate = True | ||
|
|
||
| use_wandb = True | ||
| wandb_project = "SemSeg-JAXTPC" | ||
|
|
||
| class_weights = None # set from data statistics if needed | ||
|
|
||
| # --- model --- | ||
| model = dict( | ||
| type="DefaultSegmentorV2", | ||
| num_classes=5, | ||
| backbone_out_channels=64, | ||
| backbone=dict( | ||
| type="PT-v3m2", | ||
| in_channels=4, # [xyz, energy] | ||
| order=("hilbert", "hilbert-trans", "z", "z-trans"), | ||
| stride=(2, 2, 2, 2), | ||
| enc_depths=(3, 3, 3, 9, 3), | ||
| enc_channels=(48, 96, 192, 384, 512), | ||
| enc_num_head=(3, 6, 12, 24, 32), | ||
| enc_patch_size=(256, 256, 256, 256, 256), | ||
| dec_depths=(2, 2, 2, 2), | ||
| dec_channels=(64, 96, 192, 384), | ||
| dec_num_head=(4, 6, 12, 24), | ||
| dec_patch_size=(256, 256, 256, 256), | ||
| mlp_ratio=4, | ||
| qkv_bias=True, | ||
| qk_scale=None, | ||
| layer_scale=0.0, | ||
| attn_drop=0.0, | ||
| proj_drop=0.0, | ||
| drop_path=0.3, | ||
| shuffle_orders=True, | ||
| pre_norm=True, | ||
| enable_rpe=False, | ||
| enable_flash=True, | ||
| upcast_attention=False, | ||
| upcast_softmax=False, | ||
| traceable=False, | ||
| mask_token=False, | ||
| enc_mode=False, | ||
| freeze_encoder=False, | ||
| ), | ||
| criteria=[ | ||
| dict(type="CrossEntropyLoss", loss_weight=1.0, ignore_index=-1), | ||
| dict( | ||
| type="LovaszLoss", | ||
| mode="multiclass", | ||
| loss_weight=1.0 / 20.0, | ||
| ignore_index=-1, | ||
| ), | ||
| ], | ||
| freeze_backbone=False, | ||
| ) | ||
|
|
||
| # --- scheduler --- | ||
| epoch = 20 | ||
| eval_epoch = 20 | ||
| base_lr = 0.0026 | ||
| optimizer = dict(type="AdamW", lr=base_lr, weight_decay=0.04) | ||
| param_dicts = None | ||
|
|
||
| scheduler = dict( | ||
| type="OneCycleLR", | ||
| max_lr=[base_lr], | ||
| pct_start=0.05, | ||
| anneal_strategy="cos", | ||
| div_factor=10.0, | ||
| final_div_factor=1000.0, | ||
| ) | ||
|
|
||
| # --- hooks --- | ||
| hooks = [ | ||
| dict(type="CheckpointLoader"), | ||
| dict(type="WeightDecayExclusion", | ||
| exclude_bias_from_wd=True, exclude_norm_from_wd=True, | ||
| exclude_gamma_from_wd=True, exclude_token_from_wd=True, | ||
| exclude_ndim_1_from_wd=True), | ||
| dict(type="GradientNormLogger", log_frequency=10), | ||
| dict(type="IterationTimer", warmup_iter=2), | ||
| dict(type="InformationWriter"), | ||
| dict(type="SemSegEvaluator", every_n_steps=1000, write_cls_iou=True), | ||
| dict(type="CheckpointSaver", save_freq=None, evaluator_every_n_steps=1000), | ||
| dict(type="PreciseEvaluator", test_last=False), | ||
| ] |
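To make the scheduler settings concrete, here is a pure-Python sketch of the learning-rate curve they imply. It assumes the `OneCycleLR` entry maps onto PyTorch's `torch.optim.lr_scheduler.OneCycleLR` with its default two-phase behavior; `one_cycle_lr` and `total_steps` are illustrative names, and the real step count depends on the number of epochs and the dataset size.

```python
import math

# Values copied from the scheduler config above.
base_lr = 0.0026
pct_start = 0.05
div_factor = 10.0
final_div_factor = 1000.0


def one_cycle_lr(step, total_steps):
    """Two-phase cosine one-cycle schedule: warm-up, then decay."""
    initial_lr = base_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1  # last warm-up step
    last_step = total_steps - 1

    if step <= warmup_end:
        start, end = initial_lr, base_lr
        pct = step / warmup_end
    else:
        start, end = base_lr, min_lr
        pct = (step - warmup_end) / (last_step - warmup_end)

    # Cosine interpolation from `start` (pct=0) to `end` (pct=1).
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)


total_steps = 1000  # hypothetical value for illustration only
schedule = [one_cycle_lr(s, total_steps) for s in range(total_steps)]
print(f"start={schedule[0]:.2e} peak={max(schedule):.2e} end={schedule[-1]:.2e}")
```

With these settings the rate starts at `base_lr / 10`, peaks at `base_lr` after the first 5% of steps, and decays to `base_lr / 10 / 1000` by the last step.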
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| # Detector Datasets | ||
|
|
||
| pimm supports multiple detector types through dedicated dataset classes. Each loads from co-indexed HDF5 files and produces flat dicts that flow through pimm's standard pipeline (transforms → Collect → collate → Point → model). | ||
|
|
||
| ## LArTPCDataset | ||
|
|
||
| For Liquid Argon TPC detectors (JAXTPC production output). | ||
|
|
||
| ### Data Layout | ||
| ``` | ||
| dataset_root/ | ||
| ├── seg/ sim_seg_0000.h5 — 3D truth deposits | ||
| ├── resp/ sim_resp_0000.h5 — sparse wire signals per plane | ||
| ├── corr/ sim_corr_0000.h5 — 3D→2D correspondence | ||
| └── labl/ sim_labl_0000.h5 — per-volume track_id→label lookup | ||
| ``` | ||
|
|
||
| ### How coord is assigned | ||
|
|
||
| Who owns `coord`/`energy` depends on which modalities are loaded: | ||
|
|
||
| - **seg present** → coord is 3D (N,3) from deposits. Resp/corr available as `resp_coord`, `corr_coord`. | ||
| - **seg absent, corr+labl present** → coord is 2D (E,2) from corr entries with labels. | ||
| - **seg absent, resp present** → coord is 2D (M,2) from all planes merged. | ||
|
|
||
| When both resp and corr are loaded, each gets its own point cloud: | ||
| - `resp_coord`, `resp_energy`, `resp_plane_id` — from resp signal | ||
| - `corr_coord`, `corr_energy`, `corr_segment`, `corr_instance`, `corr_plane_id` — from correspondence | ||
|
|
||
| Raw per-plane keys (`plane.*`, `corr.*`) are always passed through for per-plane access. | ||
|
|
||
| ### Task → Config | ||
|
|
||
| | Task | `modalities` | What you get | | ||
| |------|-------------|-------------| | ||
| | 3D segmentation | `('seg', 'labl')` | `coord (N,3)`, `segment (N,)` | | ||
| | 3D seg (PDG fallback) | `('seg',)` | `coord (N,3)`, `pdg (N,)` — use PDGToSemantic | | ||
| | 2D segmentation | `('resp', 'corr', 'labl')` | `coord (E,2)` from corr + labels. `resp_coord` also available. | | ||
| | 2D self-supervised | `('resp',)` | `coord (M,2)` merged planes | | ||
| | Resp→corr denoising | `('resp', 'corr')` | `coord (M,2)` from resp. `corr.*` namespaced keys. | | ||
| | Everything | `('seg', 'resp', 'corr', 'labl')` | 3D `coord` + `resp_coord` + `corr_coord` + raw keys | | ||
|
|
||
| **Note:** `modalities=('resp', 'labl')` without `'corr'` will NOT produce labels — labl provides track_id→label tables but resp pixels can't be mapped to track_ids without corr. | ||
|
|
||
| ### Config Parameters | ||
| ```python | ||
| data = dict(train=dict( | ||
| type="LArTPCDataset", | ||
| data_root="/path/to/dataset", | ||
| split="", | ||
| dataset_name="sim", | ||
| modalities=("seg", "labl"), | ||
| volume=None, # None=all, 0=volume_0 only | ||
| label_key="particle", # 'particle', 'cluster', 'interaction' | ||
| min_deposits=1024, | ||
| transform=[...], | ||
| )) | ||
| ``` | ||
|
|
||
| ### Label Chain | ||
| - **3D**: `deposit.track_id → labl[track_id] → label` | ||
| - **2D**: `corr.group_id → g2t → track_id → labl[track_id] → label` | ||
|
|
||
| ### Transforms Safe for 2D | ||
| GridSample, ToTensor, Copy, Collect, RandomDropout, ShufflePoint, RandomJitter, RandomScale, RandomFlip, PositiveShift. | ||
|
|
||
| **3D-only** (crash on 2D coords): RandomRotate, NormalizeCoord, SphereCrop. | ||
|
|
||
| --- | ||
|
|
||
| ## WCDataset | ||
|
|
||
| For Water Cherenkov detectors (PMT-based). | ||
|
|
||
| ### Data Layout | ||
| ``` | ||
| dataset_root/ | ||
| ├── seg/ wc_seg_0000.h5 — 3D track segments | ||
| └── sensor/ wc_sensor_0000.h5 — PMT response + per-particle labels | ||
|
|
||
| ``` | ||
|
|
||
| ### Task → Config | ||
|
|
||
| | Task | `modalities` | `output_mode` | Output | | ||
| |------|-------------|--------------|--------| | ||
| | Event classification | `('sensor',)` | `'response'` | `coord (N_pmt,3)`, `energy (N_pmt,1)`, `time (N_pmt,1)` | | ||
| | Per-sensor instance separation | `('sensor',)` | `'labels'` | `coord (E,3)`, `segment (E,)`, `instance (E,)` | | ||
| | 3D track reconstruction | `('seg',)` | any | `coord (N_seg,3)`, `energy (N_seg,1)`, `track_ids`, `pdg` | | ||
| | Joint 3D + sensor | `('seg', 'sensor')` | `'separate'` | `seg3d.*` + `pmt_*` + `pp_*` keys | | ||
|
|
||
| ### Config Parameters | ||
| ```python | ||
| data = dict(train=dict( | ||
| type="WCDataset", | ||
| data_root="/path/to/dataset_wc", | ||
| dataset_name="wc", | ||
| modalities=("sensor",), | ||
| output_mode="response", # 'response', 'labels', 'separate' | ||
| include_labels=True, | ||
| transform=[...], | ||
| )) | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| ## Adding a New Detector | ||
|
|
||
| 1. Write reader(s) in `pimm/datasets/readers/` — implement `__init__`, `_find_files`, `_build_index`, `h5py_worker_init`, `read_event(idx) → dict`, `__len__`, `close`. | ||
| 2. Write a dataset class in `pimm/datasets/` — inherit `HEPDataset`, register via `@DATASETS.register_module()`. | ||
| 3. The dataset's `get_data()` calls readers and maps output to `coord`/`energy`/`segment`. | ||
| 4. Add imports in `pimm/datasets/__init__.py` and `pimm/datasets/readers/__init__.py`. | ||
|
|
||
| No changes needed to transforms, collation, models, or training infrastructure. | ||
|
|
||
| ## Running Tests | ||
| ```bash | ||
| /usr/bin/python3 tools/test_detector_dataset.py # LArTPC (38 tests) | ||
| /usr/bin/python3 tools/test_wc_dataset.py # Water Cherenkov (37 tests) | ||
| ``` | ||
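The "How coord is assigned" rules in the doc above can be summarized as a small decision function. This is a sketch of the documented behavior only, not the dataset's actual code: `coord_source` and `has_labels` are hypothetical helpers, and raising an error for combinations the doc does not cover is an assumption.

```python
def coord_source(modalities):
    """Return which modality populates `coord`, following the doc's rules."""
    mods = set(modalities)
    if "seg" in mods:
        return "seg"   # 3D deposits, shape (N, 3)
    if {"corr", "labl"} <= mods:
        return "corr"  # labelled correspondence entries, shape (E, 2)
    if "resp" in mods:
        return "resp"  # all wire planes merged, shape (M, 2)
    # Assumption: combinations the doc does not describe are rejected.
    raise ValueError(f"no modality in {sorted(mods)} can provide coord")


def has_labels(modalities):
    """labl is only a lookup table; it needs seg (3D) or corr (2D) to reach points."""
    mods = set(modalities)
    return "labl" in mods and bool(mods & {"seg", "corr"})


# Rows of the Task -> Config table, plus the case from the Note.
for mods in [("seg", "labl"), ("resp", "corr", "labl"), ("resp", "corr"), ("resp", "labl")]:
    print(mods, "->", coord_source(mods), "| labels:", has_labels(mods))
```

The last case reproduces the doc's warning: `('resp', 'labl')` yields 2D coordinates from resp but no labels, because nothing links resp pixels to track IDs without corr.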
Review conversation (marked as resolved by youngsm; outdated):
It may be better to make the name of this specific data format less generic than "detector dataset". It would be helpful to have a name for the dataset now, lol. Maybe call it a jaxtpc dataset for now, until we figure out the big name.
So the idea here is to have this for JAXTPC, LUCiD, ...
That's why I named it a generic detector dataset, but I could be more specific and change the lartpc one to jaxtpc if you think that's better. A dataset name would be convenient, though, I agree.
Yeah, that's understandable. But I think it would be a good idea to be more specific. We could try to persuade people to use this dataset format by naming this something like `DefaultLArTPCDataset` and throwing it in `datasets/defaults.py`, but we shouldn't force it, if that makes sense.
My hope is that this repo will be usable with any already-made dataset anyone brings in, without needing to remake it to fit this specific format. Instead, they would just be required to make a single dataset object in a single file whose dataloader output is the same. I think this is slightly divergent from your instructions in the "Adding a New Detector" section of this doc.
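As a rough illustration of the "single dataset object in a single file" idea, the sketch below shows a dataset whose only contract is the flat dict it returns. Everything here is hypothetical: `ToyDetectorDataset` generates random points instead of reading files, uses plain lists in place of arrays, and assumes an `(N, 1)` shape for `energy`. In pimm it would presumably subclass `torch.utils.data.Dataset` and be registered with `@DATASETS.register_module()`, as the doc describes.

```python
import random


class ToyDetectorDataset:
    """Hypothetical single-file dataset; only the output dict layout matters."""

    def __init__(self, num_events=4, points_per_event=8, num_classes=5, seed=0):
        self.num_events = num_events
        self.points_per_event = points_per_event
        self.num_classes = num_classes
        self.seed = seed

    def __len__(self):
        return self.num_events

    def __getitem__(self, idx):
        if not 0 <= idx < self.num_events:
            raise IndexError(idx)
        rng = random.Random(self.seed + idx)  # deterministic per event
        n = self.points_per_event
        return {
            # (N, 3) positions
            "coord": [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(n)],
            # (N, 1) per-point energy (shape is an assumption)
            "energy": [[rng.uniform(0.01, 20.0)] for _ in range(n)],
            # (N,) semantic class per point
            "segment": [rng.randrange(self.num_classes) for _ in range(n)],
        }


dataset = ToyDetectorDataset()
sample = dataset[0]
print(len(dataset), sorted(sample))
```

Any existing dataset that can emit these keys could, in principle, feed the same transforms and collation without adopting the reader layout from the doc.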