Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions configs/detector/_base_/jaxtpc_seg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Base dataset config for JAXTPC 3D seg data.
#
# Set JAXTPC_DATA_ROOT environment variable or override data_root in child config.
# Expected directory layout:
# {data_root}/seg/{split}/{dataset_name}_seg_NNNN.h5
# or: {data_root}/seg/{dataset_name}_seg_NNNN.h5 (flat, split ignored)

import os

# Dataset root: read from the JAXTPC_DATA_ROOT environment variable, falling
# back to a placeholder path that a child config is expected to override.
_data_root = os.getenv("JAXTPC_DATA_ROOT", "/path/to/jaxtpc/production")

# Coordinate normalization: the origin and divisor handed to NormalizeCoord.
# Defaults target an SBND-scale dual-TPC volume in which x, y and z each span
# [-2160, 2160] mm. The divisor is the half-diagonal of that cube (~3741 mm),
# so each axis lands in roughly [-0.58, 0.58] and a cube corner sits at unit
# distance from the center (assuming NormalizeCoord computes
# (coord - center) / scale -- TODO confirm).
_center = [0.0, 0.0, 0.0]
_scale = 3 ** 0.5 * 2160.0

# Voxel edge length used by GridSample, expressed in normalized units.
grid_size = 1e-3

# Training pipeline: label mapping -> normalization -> energy transform ->
# voxelization -> random augmentation -> tensor conversion -> key collection.
transform = [
    {"type": "PDGToSemantic", "scheme": "motif_5cls"},
    {"type": "NormalizeCoord", "center": _center, "scale": _scale},
    {"type": "LogTransform", "min_val": 0.01, "max_val": 20.0},
    {
        "type": "GridSample",
        "grid_size": grid_size,
        "hash_type": "fnv",
        "mode": "train",
        "return_grid_coord": True,
    },
    # One RandomRotate stage per axis, in z, x, y order; every stage gets its
    # own freshly built angle/center lists.
    *[
        {
            "type": "RandomRotate",
            "angle": [-1, 1],
            "axis": rot_axis,
            "center": [0, 0, 0],
            "p": 0.8,
        }
        for rot_axis in ("z", "x", "y")
    ],
    {"type": "RandomFlip", "p": 0.5},
    {"type": "Copy", "keys_dict": {"segment_motif": "segment"}},
    {"type": "ToTensor"},
    {
        "type": "Collect",
        "keys": ("coord", "grid_coord", "segment"),
        "feat_keys": ("coord", "energy"),
    },
]

# Validation pipeline: the training pipeline minus the random rotate/flip
# augmentations.
# NOTE(review): GridSample still runs with mode="train" here -- presumably
# intentional for validation; confirm.
test_transform = [
    {"type": "PDGToSemantic", "scheme": "motif_5cls"},
    {"type": "NormalizeCoord", "center": _center, "scale": _scale},
    {"type": "LogTransform", "min_val": 0.01, "max_val": 20.0},
    {
        "type": "GridSample",
        "grid_size": grid_size,
        "hash_type": "fnv",
        "mode": "train",
        "return_grid_coord": True,
    },
    {"type": "Copy", "keys_dict": {"segment_motif": "segment"}},
    {"type": "ToTensor"},
    {
        "type": "Collect",
        "keys": ("coord", "grid_coord", "segment"),
        "feat_keys": ("coord", "energy"),
    },
]

# Dataset definitions. The train and val splits differ only in `split`,
# `transform` and `max_len`.
data = {
    "num_classes": 5,
    "ignore_index": -1,
    "names": ["shower", "track", "michel", "delta", "led"],
    "train": {
        "type": "LArTPCDataset",
        "data_root": _data_root,
        "split": "train",
        "dataset_name": "sim",
        "modalities": ("seg",),
        "transform": transform,
        "min_deposits": 1024,
        "max_len": -1,
    },
    "val": {
        "type": "LArTPCDataset",
        "data_root": _data_root,
        "split": "val",
        "dataset_name": "sim",
        "modalities": ("seg",),
        "transform": test_transform,
        "min_deposits": 1024,
        "max_len": 1000,
    },
}
106 changes: 106 additions & 0 deletions configs/detector/semseg/semseg-pt-v3m2-jaxtpc-5cls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
PTv3 semantic segmentation on JAXTPC 3D data.

Drop-in replacement for PILArNet semseg — same model, different data source.
"""

# Parent configs this file builds on: the shared runtime defaults and the
# JAXTPC 3D segmentation dataset definition.
_base_ = [
    "../../../configs/_base_/default_runtime.py",
    "../_base_/jaxtpc_seg.py",
]

# --- training ---
# Loader sizing.
batch_size = 48
num_worker = 24

# Optimization-loop switches: sample mixing is off and no gradient-clipping
# value is set.
mix_prob = 0.0
clip_grad = None
empty_cache = False

# Numeric precision: AMP enabled with bfloat16.
# NOTE(review): matmul_precision is presumably forwarded to
# torch.set_float32_matmul_precision -- confirm against the trainer.
enable_amp = True
amp_dtype = "bfloat16"
matmul_precision = "high"

# Reproducibility and evaluation toggle.
seed = 0
evaluate = True

# Weights & Biases logging.
use_wandb = True
wandb_project = "SemSeg-JAXTPC"

# No per-class loss weighting by default; set from data statistics if needed.
class_weights = None

# --- model ---
# Segmentor definition: a PT-v3m2 backbone followed by a 5-way per-point
# classification head, trained with cross-entropy plus a down-weighted
# Lovasz loss.
model = {
    "type": "DefaultSegmentorV2",
    "num_classes": 5,
    "backbone_out_channels": 64,  # equals the first entry of dec_channels
    "backbone": {
        "type": "PT-v3m2",
        "in_channels": 4,  # [xyz, energy]
        "order": ("hilbert", "hilbert-trans", "z", "z-trans"),
        "stride": (2, 2, 2, 2),
        # Encoder: five stages.
        "enc_depths": (3, 3, 3, 9, 3),
        "enc_channels": (48, 96, 192, 384, 512),
        "enc_num_head": (3, 6, 12, 24, 32),
        "enc_patch_size": (256,) * 5,
        # Decoder: four stages.
        "dec_depths": (2, 2, 2, 2),
        "dec_channels": (64, 96, 192, 384),
        "dec_num_head": (4, 6, 12, 24),
        "dec_patch_size": (256,) * 4,
        # Attention / MLP settings.
        "mlp_ratio": 4,
        "qkv_bias": True,
        "qk_scale": None,
        "layer_scale": 0.0,
        # Regularization.
        "attn_drop": 0.0,
        "proj_drop": 0.0,
        "drop_path": 0.3,
        # Remaining backbone switches.
        "shuffle_orders": True,
        "pre_norm": True,
        "enable_rpe": False,
        "enable_flash": True,
        "upcast_attention": False,
        "upcast_softmax": False,
        "traceable": False,
        "mask_token": False,
        "enc_mode": False,
        "freeze_encoder": False,
    },
    "criteria": [
        {"type": "CrossEntropyLoss", "loss_weight": 1.0, "ignore_index": -1},
        {
            "type": "LovaszLoss",
            "mode": "multiclass",
            "loss_weight": 1.0 / 20.0,
            "ignore_index": -1,
        },
    ],
    "freeze_backbone": False,
}

# --- scheduler ---
# Schedule length; eval_epoch is set equal to epoch.
epoch = 20
eval_epoch = 20

# Peak learning rate, shared by the optimizer and the one-cycle schedule.
base_lr = 0.0026
optimizer = {"type": "AdamW", "lr": base_lr, "weight_decay": 0.04}
param_dicts = None

# One-cycle schedule with cosine annealing; max_lr is a single-entry list.
scheduler = {
    "type": "OneCycleLR",
    "max_lr": [base_lr],
    "pct_start": 0.05,
    "anneal_strategy": "cos",
    "div_factor": 10.0,
    "final_div_factor": 1000.0,
}

# --- hooks ---
# Training hooks, listed in the order they are declared to the trainer.
hooks = [
    {"type": "CheckpointLoader"},
    # Every exclusion flag of the weight-decay hook is switched on.
    {
        "type": "WeightDecayExclusion",
        "exclude_bias_from_wd": True,
        "exclude_norm_from_wd": True,
        "exclude_gamma_from_wd": True,
        "exclude_token_from_wd": True,
        "exclude_ndim_1_from_wd": True,
    },
    {"type": "GradientNormLogger", "log_frequency": 10},
    {"type": "IterationTimer", "warmup_iter": 2},
    {"type": "InformationWriter"},
    # The evaluator and the checkpoint saver are both configured with a
    # 1000-step interval.
    {"type": "SemSegEvaluator", "every_n_steps": 1000, "write_cls_iou": True},
    {"type": "CheckpointSaver", "save_freq": None, "evaluator_every_n_steps": 1000},
    {"type": "PreciseEvaluator", "test_last": False},
]
119 changes: 119 additions & 0 deletions docs/DETECTOR_DATASET.md
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be better to make the name of this specific data format less generic than "detector dataset"? Would be helpful to have a name for the dataset now, lol. Maybe call it a jaxtpc dataset? For now until we figure out the big name.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the idea here is to have this for both JAXTPC, LUCiD, ...
Hence why I named it a generic detector dataset, but I could be more specific and change the lartpc one to jaxtpc if you think that's better. But a dataset name would be convenient, I agree

Copy link
Copy Markdown
Member

@youngsm youngsm Apr 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's understandable. But I think it would be a good idea to be more specific. We could try to persuade people to use this dataset format by naming this something like DefaultLArTPCDataset and throw it in datasets/defaults.py, but we shouldn't force it, if that makes sense.

My hope is that this repo will be usable with any already-made dataset anyone brings in, without needing to remake it to fit this specific format. Instead, they would only be required to write a single dataset object in a single file whose dataloader output has the same structure. I think this is slightly divergent from your instructions in the "adding a new detector" section of this doc.

Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Detector Datasets

pimm supports multiple detector types through dedicated dataset classes. Each loads from co-indexed HDF5 files and produces flat dicts that flow through pimm's standard pipeline (transforms → Collect → collate → Point → model).

## LArTPCDataset

For Liquid Argon TPC detectors (JAXTPC production output).

### Data Layout
```
dataset_root/
├── seg/ sim_seg_0000.h5 — 3D truth deposits
├── resp/ sim_resp_0000.h5 — sparse wire signals per plane
├── corr/ sim_corr_0000.h5 — 3D→2D correspondence
└── labl/ sim_labl_0000.h5 — per-volume track_id→label lookup
```

### How coord is assigned

Who owns `coord`/`energy` depends on which modalities are loaded:

- **seg present** → coord is 3D (N,3) from deposits. Resp/corr available as `resp_coord`, `corr_coord`.
- **seg absent, corr+labl present** → coord is 2D (E,2) from corr entries with labels.
- **seg absent, resp present** → coord is 2D (M,2) from all planes merged.

When both resp and corr are loaded, each gets its own point cloud:
- `resp_coord`, `resp_energy`, `resp_plane_id` — from resp signal
- `corr_coord`, `corr_energy`, `corr_segment`, `corr_instance`, `corr_plane_id` — from correspondence

Raw per-plane keys (`plane.*`, `corr.*`) are always passed through for per-plane access.

### Task → Config

| Task | `modalities` | What you get |
|------|-------------|-------------|
| 3D segmentation | `('seg', 'labl')` | `coord (N,3)`, `segment (N,)` |
| 3D seg (PDG fallback) | `('seg',)` | `coord (N,3)`, `pdg (N,)` — use PDGToSemantic |
| 2D segmentation | `('resp', 'corr', 'labl')` | `coord (E,2)` from corr + labels. `resp_coord` also available. |
| 2D self-supervised | `('resp',)` | `coord (M,2)` merged planes |
| Resp→corr denoising | `('resp', 'corr')` | `coord (M,2)` from resp. `corr.*` namespaced keys. |
| Everything | `('seg', 'resp', 'corr', 'labl')` | 3D `coord` + `resp_coord` + `corr_coord` + raw keys |

**Note:** `modalities=('resp', 'labl')` without `'corr'` will NOT produce labels — labl provides track_id→label tables but resp pixels can't be mapped to track_ids without corr.

### Config Parameters
```python
data = dict(train=dict(
type="LArTPCDataset",
data_root="/path/to/dataset",
split="",
dataset_name="sim",
modalities=("seg", "labl"),
volume=None, # None=all, 0=volume_0 only
label_key="particle", # 'particle', 'cluster', 'interaction'
min_deposits=1024,
transform=[...],
))
```

### Label Chain
- **3D**: `deposit.track_id → labl[track_id] → label`
- **2D**: `corr.group_id → g2t → track_id → labl[track_id] → label`

### Transforms Safe for 2D
GridSample, ToTensor, Copy, Collect, RandomDropout, ShufflePoint, RandomJitter, RandomScale, RandomFlip, PositiveShift.

**3D-only** (crash on 2D coords): RandomRotate, NormalizeCoord, SphereCrop.

---

## WCDataset

For Water Cherenkov detectors (PMT-based).

### Data Layout
```
dataset_root/
├── seg/ wc_seg_0000.h5 — 3D track segments
└── sensor/ wc_sensor_0000.h5 — PMT response + per-particle labels
Comment thread
youngsm marked this conversation as resolved.
Outdated
```

### Task → Config

| Task | `modalities` | `output_mode` | Output |
|------|-------------|--------------|--------|
| Event classification | `('sensor',)` | `'response'` | `coord (N_pmt,3)`, `energy (N_pmt,1)`, `time (N_pmt,1)` |
| Per-sensor instance separation | `('sensor',)` | `'labels'` | `coord (E,3)`, `segment (E,)`, `instance (E,)` |
| 3D track reconstruction | `('seg',)` | any | `coord (N_seg,3)`, `energy (N_seg,1)`, `track_ids`, `pdg` |
| Joint 3D + sensor | `('seg', 'sensor')` | `'separate'` | `seg3d.*` + `pmt_*` + `pp_*` keys |

### Config Parameters
```python
data = dict(train=dict(
type="WCDataset",
data_root="/path/to/dataset_wc",
dataset_name="wc",
modalities=("sensor",),
output_mode="response", # 'response', 'labels', 'separate'
include_labels=True,
transform=[...],
))
```

---

## Adding a New Detector

1. Write reader(s) in `pimm/datasets/readers/` — implement `__init__`, `_find_files`, `_build_index`, `h5py_worker_init`, `read_event(idx) → dict`, `__len__`, `close`.
2. Write a dataset class in `pimm/datasets/` — inherit `HEPDataset`, register via `@DATASETS.register_module()`.
3. The dataset's `get_data()` calls readers and maps output to `coord`/`energy`/`segment`.
4. Add imports in `pimm/datasets/__init__.py` and `pimm/datasets/readers/__init__.py`.

No changes needed to transforms, collation, models, or training infrastructure.

## Running Tests
```bash
/usr/bin/python3 tools/test_detector_dataset.py # LArTPC (38 tests)
/usr/bin/python3 tools/test_wc_dataset.py # Water Cherenkov (37 tests)
```
3 changes: 3 additions & 0 deletions pimm/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,8 @@

# physics
from .pilarnet import PILArNetH5Dataset
from .lartpc_dataset import LArTPCDataset
from .wc_dataset import WCDataset
from . import detector_transforms # register PDGToSemantic
# dataloader
from .dataloader import MultiDatasetDataloader
Loading