From e62b33e8cd1540fd1005e307babd068671240b86 Mon Sep 17 00:00:00 2001 From: Omar Alterkait Date: Tue, 14 Apr 2026 01:33:57 +0900 Subject: [PATCH 1/4] Add multimodal detector datasets (LArTPC + Water Cherenkov) Add LArTPCDataset and WCDataset for loading multimodal detector simulation data through pimm's existing pipeline. Each detector type has dedicated readers that produce flat dicts consumed by the standard transform/collation/Point infrastructure. LArTPC (JAXTPC production): - LArTPCSegReader: 3D truth deposits from seg files - LArTPCRespReader: sparse wire signals from resp files - LArTPCLablReader: per-volume track_id->label lookup from labl files - LArTPCCorrReader: 3D->2D correspondence from corr files (vectorized) - Modality-driven coord ownership: seg->3D, corr+labl->2D labeled, resp->2D merged - Both resp and corr available as separate point clouds (resp_coord, corr_coord) - Volume filter for single-volume loading Water Cherenkov (PMT-based): - WCSegReader: 3D track segments from flat CSR format - WCSensorReader: PMT response + per-particle sparse decomposition - Output modes: response (per-sensor), labels (per-particle sparse), separate Minimal changes to existing pimm code (3 files, 19 lines): - index_valid_keys extended for LArTPC keys - collate_fn skips _-prefixed metadata keys - Dataset imports added 70 tests across both detector types verify all modality combinations, transform pipelines, collation, DataLoader workers, and toy model forward/backward passes. 
--- configs/detector/_base_/jaxtpc_seg.py | 87 ++++ .../semseg/semseg-pt-v3m2-jaxtpc-5cls.py | 106 ++++ docs/DETECTOR_DATASET.md | 119 +++++ pimm/datasets/__init__.py | 3 + pimm/datasets/detector_transforms.py | 112 +++++ pimm/datasets/lartpc_dataset.py | 455 ++++++++++++++++++ pimm/datasets/readers/.gitignore | 1 + pimm/datasets/readers/__init__.py | 6 + pimm/datasets/readers/lartpc_corr_reader.py | 190 ++++++++ pimm/datasets/readers/lartpc_labl_reader.py | 150 ++++++ pimm/datasets/readers/lartpc_resp_reader.py | 180 +++++++ pimm/datasets/readers/lartpc_seg_reader.py | 281 +++++++++++ pimm/datasets/readers/wc_seg_reader.py | 201 ++++++++ pimm/datasets/readers/wc_sensor_reader.py | 223 +++++++++ pimm/datasets/transform.py | 15 + pimm/datasets/utils.py | 1 + pimm/datasets/wc_dataset.py | 245 ++++++++++ tools/test_detector_dataset.py | 231 +++++++++ tools/test_wc_dataset.py | 197 ++++++++ 19 files changed, 2803 insertions(+) create mode 100644 configs/detector/_base_/jaxtpc_seg.py create mode 100644 configs/detector/semseg/semseg-pt-v3m2-jaxtpc-5cls.py create mode 100644 docs/DETECTOR_DATASET.md create mode 100644 pimm/datasets/detector_transforms.py create mode 100644 pimm/datasets/lartpc_dataset.py create mode 100644 pimm/datasets/readers/.gitignore create mode 100644 pimm/datasets/readers/__init__.py create mode 100644 pimm/datasets/readers/lartpc_corr_reader.py create mode 100644 pimm/datasets/readers/lartpc_labl_reader.py create mode 100644 pimm/datasets/readers/lartpc_resp_reader.py create mode 100644 pimm/datasets/readers/lartpc_seg_reader.py create mode 100644 pimm/datasets/readers/wc_seg_reader.py create mode 100644 pimm/datasets/readers/wc_sensor_reader.py create mode 100644 pimm/datasets/wc_dataset.py create mode 100644 tools/test_detector_dataset.py create mode 100644 tools/test_wc_dataset.py diff --git a/configs/detector/_base_/jaxtpc_seg.py b/configs/detector/_base_/jaxtpc_seg.py new file mode 100644 index 0000000..f8021a5 --- /dev/null +++ 
b/configs/detector/_base_/jaxtpc_seg.py @@ -0,0 +1,87 @@ +# Base dataset config for JAXTPC 3D seg data. +# +# Set JAXTPC_DATA_ROOT environment variable or override data_root in child config. +# Expected directory layout: +# {data_root}/seg/{split}/{dataset_name}_seg_NNNN.h5 +# or: {data_root}/seg/{dataset_name}_seg_NNNN.h5 (flat, split ignored) + +import os + +_data_root = os.environ.get("JAXTPC_DATA_ROOT", "/path/to/jaxtpc/production") + +# Coordinate normalization center and scale. +# Default is for SBND-scale dual-TPC: x in [-2160, 2160], y/z in [-2160, 2160] mm. +_center = [0.0, 0.0, 0.0] +_scale = 2160.0 * 3 ** 0.5 # ~3741 mm — normalizes to roughly [-1, 1] + +grid_size = 0.001 # after normalization + +transform = [ + dict(type="PDGToSemantic", scheme="motif_5cls"), + dict(type="NormalizeCoord", center=_center, scale=_scale), + dict(type="LogTransform", min_val=0.01, max_val=20.0), + dict( + type="GridSample", + grid_size=grid_size, + hash_type="fnv", + mode="train", + return_grid_coord=True, + ), + dict(type="RandomRotate", angle=[-1, 1], axis="z", center=[0, 0, 0], p=0.8), + dict(type="RandomRotate", angle=[-1, 1], axis="x", center=[0, 0, 0], p=0.8), + dict(type="RandomRotate", angle=[-1, 1], axis="y", center=[0, 0, 0], p=0.8), + dict(type="RandomFlip", p=0.5), + dict(type="Copy", keys_dict={"segment_motif": "segment"}), + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "segment"), + feat_keys=("coord", "energy"), + ), +] + +test_transform = [ + dict(type="PDGToSemantic", scheme="motif_5cls"), + dict(type="NormalizeCoord", center=_center, scale=_scale), + dict(type="LogTransform", min_val=0.01, max_val=20.0), + dict( + type="GridSample", + grid_size=grid_size, + hash_type="fnv", + mode="train", + return_grid_coord=True, + ), + dict(type="Copy", keys_dict={"segment_motif": "segment"}), + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "segment"), + feat_keys=("coord", "energy"), + ), +] + +data = 
dict( + num_classes=5, + ignore_index=-1, + names=["shower", "track", "michel", "delta", "led"], + train=dict( + type="LArTPCDataset", + data_root=_data_root, + split="train", + dataset_name="sim", + modalities=("seg",), + transform=transform, + min_deposits=1024, + max_len=-1, + ), + val=dict( + type="LArTPCDataset", + data_root=_data_root, + split="val", + dataset_name="sim", + modalities=("seg",), + transform=test_transform, + min_deposits=1024, + max_len=1000, + ), +) diff --git a/configs/detector/semseg/semseg-pt-v3m2-jaxtpc-5cls.py b/configs/detector/semseg/semseg-pt-v3m2-jaxtpc-5cls.py new file mode 100644 index 0000000..6f1be17 --- /dev/null +++ b/configs/detector/semseg/semseg-pt-v3m2-jaxtpc-5cls.py @@ -0,0 +1,106 @@ +""" +PTv3 semantic segmentation on JAXTPC 3D data. + +Drop-in replacement for PILArNet semseg — same model, different data source. +""" + +_base_ = [ + "../../../configs/_base_/default_runtime.py", + "../_base_/jaxtpc_seg.py", +] + +# --- training --- +batch_size = 48 +num_worker = 24 +mix_prob = 0.0 +clip_grad = None +empty_cache = False +enable_amp = True +amp_dtype = "bfloat16" +matmul_precision = "high" +seed = 0 +evaluate = True + +use_wandb = True +wandb_project = "SemSeg-JAXTPC" + +class_weights = None # set from data statistics if needed + +# --- model --- +model = dict( + type="DefaultSegmentorV2", + num_classes=5, + backbone_out_channels=64, + backbone=dict( + type="PT-v3m2", + in_channels=4, # [xyz, energy] + order=("hilbert", "hilbert-trans", "z", "z-trans"), + stride=(2, 2, 2, 2), + enc_depths=(3, 3, 3, 9, 3), + enc_channels=(48, 96, 192, 384, 512), + enc_num_head=(3, 6, 12, 24, 32), + enc_patch_size=(256, 256, 256, 256, 256), + dec_depths=(2, 2, 2, 2), + dec_channels=(64, 96, 192, 384), + dec_num_head=(4, 6, 12, 24), + dec_patch_size=(256, 256, 256, 256), + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + layer_scale=0.0, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.3, + shuffle_orders=True, + pre_norm=True, + 
enable_rpe=False, + enable_flash=True, + upcast_attention=False, + upcast_softmax=False, + traceable=False, + mask_token=False, + enc_mode=False, + freeze_encoder=False, + ), + criteria=[ + dict(type="CrossEntropyLoss", loss_weight=1.0, ignore_index=-1), + dict( + type="LovaszLoss", + mode="multiclass", + loss_weight=1.0 / 20.0, + ignore_index=-1, + ), + ], + freeze_backbone=False, +) + +# --- scheduler --- +epoch = 20 +eval_epoch = 20 +base_lr = 0.0026 +optimizer = dict(type="AdamW", lr=base_lr, weight_decay=0.04) +param_dicts = None + +scheduler = dict( + type="OneCycleLR", + max_lr=[base_lr], + pct_start=0.05, + anneal_strategy="cos", + div_factor=10.0, + final_div_factor=1000.0, +) + +# --- hooks --- +hooks = [ + dict(type="CheckpointLoader"), + dict(type="WeightDecayExclusion", + exclude_bias_from_wd=True, exclude_norm_from_wd=True, + exclude_gamma_from_wd=True, exclude_token_from_wd=True, + exclude_ndim_1_from_wd=True), + dict(type="GradientNormLogger", log_frequency=10), + dict(type="IterationTimer", warmup_iter=2), + dict(type="InformationWriter"), + dict(type="SemSegEvaluator", every_n_steps=1000, write_cls_iou=True), + dict(type="CheckpointSaver", save_freq=None, evaluator_every_n_steps=1000), + dict(type="PreciseEvaluator", test_last=False), +] diff --git a/docs/DETECTOR_DATASET.md b/docs/DETECTOR_DATASET.md new file mode 100644 index 0000000..8b9545f --- /dev/null +++ b/docs/DETECTOR_DATASET.md @@ -0,0 +1,119 @@ +# Detector Datasets + +pimm supports multiple detector types through dedicated dataset classes. Each loads from co-indexed HDF5 files and produces flat dicts that flow through pimm's standard pipeline (transforms → Collect → collate → Point → model). + +## LArTPCDataset + +For Liquid Argon TPC detectors (JAXTPC production output). 
+ +### Data Layout +``` +dataset_root/ +├── seg/ sim_seg_0000.h5 — 3D truth deposits +├── resp/ sim_resp_0000.h5 — sparse wire signals per plane +├── corr/ sim_corr_0000.h5 — 3D→2D correspondence +└── labl/ sim_labl_0000.h5 — per-volume track_id→label lookup +``` + +### How coord is assigned + +Who owns `coord`/`energy` depends on which modalities are loaded: + +- **seg present** → coord is 3D (N,3) from deposits. Resp is available as `resp_coord`; corr is available as `corr_coord` when labl is also loaded, otherwise as raw `corr.*` keys. +- **seg absent, corr+labl present** → coord is 2D (E,2) from corr entries with labels. +- **seg absent, resp present** → coord is 2D (M,2) from all planes merged. + +When both resp and corr are loaded, each gets its own point cloud (the corr point cloud is built only when labl is also loaded): +- `resp_coord`, `resp_energy`, `resp_plane_id` — from resp signal +- `corr_coord`, `corr_energy`, `corr_segment`, `corr_instance`, `corr_plane_id` — from correspondence + +Raw per-plane `plane.*` keys are always passed through for per-plane access; raw `corr.*` keys are passed through only when labl is not loaded (with labl, corr is exposed as the `corr_*` point cloud instead). + +### Task → Config + +| Task | `modalities` | What you get | +|------|-------------|-------------| +| 3D segmentation | `('seg', 'labl')` | `coord (N,3)`, `segment (N,)` | +| 3D seg (PDG fallback) | `('seg',)` | `coord (N,3)`, `pdg (N,)` — use PDGToSemantic | +| 2D segmentation | `('resp', 'corr', 'labl')` | `coord (E,2)` from corr + labels. `resp_coord` also available. | +| 2D self-supervised | `('resp',)` | `coord (M,2)` merged planes | +| Resp→corr denoising | `('resp', 'corr')` | `coord (M,2)` from resp. `corr.*` namespaced keys. | +| Everything | `('seg', 'resp', 'corr', 'labl')` | 3D `coord` + `resp_coord` + `corr_coord` + raw `plane.*` keys | + +**Note:** `modalities=('resp', 'labl')` without `'corr'` will NOT produce labels — labl provides track_id→label tables but resp pixels can't be mapped to track_ids without corr. 
+ +### Config Parameters +```python +data = dict(train=dict( + type="LArTPCDataset", + data_root="/path/to/dataset", + split="", + dataset_name="sim", + modalities=("seg", "labl"), + volume=None, # None=all, 0=volume_0 only + label_key="particle", # 'particle', 'cluster', 'interaction' + min_deposits=1024, + transform=[...], +)) +``` + +### Label Chain +- **3D**: `deposit.track_id → labl[track_id] → label` +- **2D**: `corr.group_id → g2t → track_id → labl[track_id] → label` + +### Transforms Safe for 2D +GridSample, ToTensor, Copy, Collect, RandomDropout, ShufflePoint, RandomJitter, RandomScale, RandomFlip, PositiveShift. + +**3D-only** (crash on 2D coords): RandomRotate, NormalizeCoord, SphereCrop. + +--- + +## WCDataset + +For Water Cherenkov detectors (PMT-based). + +### Data Layout +``` +dataset_root/ +├── seg/ wc_seg_0000.h5 — 3D track segments +└── sensor/ wc_sensor_0000.h5 — PMT response + per-particle labels +``` + +### Task → Config + +| Task | `modalities` | `output_mode` | Output | +|------|-------------|--------------|--------| +| Event classification | `('sensor',)` | `'response'` | `coord (N_pmt,3)`, `energy (N_pmt,1)`, `time (N_pmt,1)` | +| Per-sensor instance separation | `('sensor',)` | `'labels'` | `coord (E,3)`, `segment (E,)`, `instance (E,)` | +| 3D track reconstruction | `('seg',)` | any | `coord (N_seg,3)`, `energy (N_seg,1)`, `track_ids`, `pdg` | +| Joint 3D + sensor | `('seg', 'sensor')` | `'separate'` | `seg3d.*` + `pmt_*` + `pp_*` keys | + +### Config Parameters +```python +data = dict(train=dict( + type="WCDataset", + data_root="/path/to/dataset_wc", + dataset_name="wc", + modalities=("sensor",), + output_mode="response", # 'response', 'labels', 'separate' + include_labels=True, + transform=[...], +)) +``` + +--- + +## Adding a New Detector + +1. Write reader(s) in `pimm/datasets/readers/` — implement `__init__`, `_find_files`, `_build_index`, `h5py_worker_init`, `read_event(idx) → dict`, `__len__`, `close`. +2. 
Write a dataset class in `pimm/datasets/` — inherit `HEPDataset`, register via `@DATASETS.register_module()`. +3. The dataset's `get_data()` calls readers and maps output to `coord`/`energy`/`segment`. +4. Add imports in `pimm/datasets/__init__.py` and `pimm/datasets/readers/__init__.py`. + +No changes needed to transforms, collation, models, or training infrastructure. + +## Running Tests +```bash +/usr/bin/python3 tools/test_detector_dataset.py # LArTPC (38 tests) +/usr/bin/python3 tools/test_wc_dataset.py # Water Cherenkov (37 tests) +``` diff --git a/pimm/datasets/__init__.py b/pimm/datasets/__init__.py index f6e2807..78419e3 100644 --- a/pimm/datasets/__init__.py +++ b/pimm/datasets/__init__.py @@ -5,5 +5,8 @@ # physics from .pilarnet import PILArNetH5Dataset +from .lartpc_dataset import LArTPCDataset +from .wc_dataset import WCDataset +from . import detector_transforms # register PDGToSemantic # dataloader from .dataloader import MultiDatasetDataloader diff --git a/pimm/datasets/detector_transforms.py b/pimm/datasets/detector_transforms.py new file mode 100644 index 0000000..d321b5b --- /dev/null +++ b/pimm/datasets/detector_transforms.py @@ -0,0 +1,112 @@ +""" +Detector-specific transforms for multimodal detector data. + +PDGToSemantic: derives semantic labels from PDG codes in seg data (3D tasks). + Used when no label file is available (fallback). + For production training, labels come from the labl file via + LArTPCDataset._apply_labl_to_3d() or _build_corr_pointcloud(). +""" + +import numpy as np +from .transform import TRANSFORMS + + +@TRANSFORMS.register_module() +class PDGToSemantic: + """Fallback: derive approximate semantic labels from PDG codes. + + Use this only when no labl file is available. For production training, + use modalities=('seg', 'labl') which applies labels via LArTPCLablReader. 
+ + Schemes + ------- + motif_5cls : shower(0), track(1), michel(2), delta(3), led(4) + pid_6cls : photon(0), electron(1), muon(2), pion(3), proton(4), other(5) + custom : user-provided {pdg_code: class_index} dict + """ + + _MOTIF = { + 22: 0, 11: 0, -11: 0, # shower + 13: 1, -13: 1, # track (muon) + 211: 1, -211: 1, # track (pion) + 2212: 1, # track (proton) + 321: 1, -321: 1, # track (kaon) + } + + _PID = { + 22: 0, # photon + 11: 1, -11: 1, # electron + 13: 2, -13: 2, # muon + 211: 3, -211: 3, # pion + 2212: 4, # proton + } + + def __init__(self, scheme='motif_5cls', custom_map=None): + self.scheme = scheme + if scheme == 'motif_5cls': + self.mapping = self._MOTIF + self.default = 4 + elif scheme == 'pid_6cls': + self.mapping = self._PID + self.default = 5 + elif scheme == 'custom': + assert custom_map is not None + self.mapping = custom_map + self.default = max(custom_map.values()) + 1 + elif scheme == 'none': + self.mapping = None + self.default = -1 + else: + raise ValueError(f"Unknown label scheme: {scheme}") + + def __call__(self, data_dict): + if self.mapping is None or 'pdg' not in data_dict: + return data_dict + + # Skip if labels already loaded from label file + if 'segment' in data_dict or 'segment_motif' in data_dict: + return data_dict + + pdg = data_dict['pdg'] + n = len(pdg) + labels = np.full(n, self.default, dtype=np.int32) + for code, cls in self.mapping.items(): + labels[pdg == code] = cls + + data_dict['segment_motif'] = labels[:, None] + + # Also produce PID labels + if self.scheme == 'motif_5cls': + pid = np.full(n, 5, dtype=np.int32) + for code, cls in self._PID.items(): + pid[pdg == code] = cls + data_dict['segment_pid'] = pid[:, None] + elif self.scheme == 'pid_6cls': + data_dict['segment_pid'] = labels[:, None] + + # Derive instance from track_ids (simple contiguous remapping) + if 'instance_particle' not in data_dict and 'track_ids' in data_dict: + track_ids = data_dict['track_ids'] + mask = track_ids >= 0 + if mask.any(): + _, inverse 
= np.unique(track_ids[mask], return_inverse=True) + out = np.full(n, -1, dtype=np.int32) + out[mask] = inverse + data_dict['instance_particle'] = out[:, None] + else: + data_dict['instance_particle'] = np.full((n, 1), -1, dtype=np.int32) + + if 'instance_interaction' not in data_dict and 'interaction_ids' in data_dict: + iids = data_dict['interaction_ids'] + mask = iids >= 0 + if mask.any(): + _, inverse = np.unique(iids[mask], return_inverse=True) + out = np.full(n, -1, dtype=np.int32) + out[mask] = inverse + data_dict['instance_interaction'] = out[:, None] + else: + data_dict['instance_interaction'] = np.full((n, 1), -1, dtype=np.int32) + + data_dict['segment_interaction'] = (iids[:, None] != -1).astype(np.int32) + + return data_dict diff --git a/pimm/datasets/lartpc_dataset.py b/pimm/datasets/lartpc_dataset.py new file mode 100644 index 0000000..84a9083 --- /dev/null +++ b/pimm/datasets/lartpc_dataset.py @@ -0,0 +1,455 @@ +""" +LArTPCDataset — multimodal dataset for LArTPC detector simulation output. + +Loads from co-indexed HDF5 files produced by JAXTPC's production pipeline: +seg (3D deposits), resp (2D wire signals), corr (3D→2D correspondence), +labl (track_id→label lookup tables). + +Who owns ``coord``/``energy`` is determined by which modalities are loaded: + +- seg present → coord is 3D (N,3) from deposits. Resp/corr stay namespaced. +- seg absent, resp present → all planes merged into coord (M,2) with plane_id. +- seg absent, corr+labl present → corr entries become coord (E,2) with labels. 
+ +Example configs:: + + # 3D segmentation + data = dict(train=dict(type="LArTPCDataset", + modalities=("seg", "labl"), label_key="particle", ...)) + + # 2D segmentation (all planes) + data = dict(train=dict(type="LArTPCDataset", + modalities=("resp", "corr", "labl"), label_key="particle", ...)) + + # Mixed 3D + 2D + data = dict(train=dict(type="LArTPCDataset", + modalities=("seg", "resp", "corr", "labl"), ...)) +""" + +import os +import numpy as np +from copy import deepcopy + +from pimm.utils.logger import get_root_logger +from .builder import DATASETS +from .transform import Compose, TRANSFORMS +from .hepdataset import HEPDataset +from .readers.lartpc_seg_reader import LArTPCSegReader +from .readers.lartpc_resp_reader import LArTPCRespReader +from .readers.lartpc_labl_reader import LArTPCLablReader +from .readers.lartpc_corr_reader import LArTPCCorrReader + + +@DATASETS.register_module() +class LArTPCDataset(HEPDataset): + """Multimodal dataset for LArTPC detector simulation output. + + Parameters + ---------- + data_root : str + Root directory with seg/, resp/, corr/, labl/ subdirectories. + split : str + Split name for file discovery. + transform : list[dict] + Transform pipeline config. + modalities : tuple[str] + Which to load: 'seg', 'resp', 'labl', 'corr'. + dataset_name : str + File prefix (e.g., 'sim' for 'sim_seg_0000.h5'). + volume : int or None + Load only this volume index. None = all volumes. + label_key : str + Which label to use as 'segment': 'particle', 'cluster', 'interaction'. + min_deposits : int + Minimum 3D deposits per event (seg reader filter). + max_len : int + Cap on dataset length (-1 = unlimited). + loop : int + Dataset repetition per epoch. + include_physics : bool + Whether seg reader loads dx, theta, phi, charge, photons, etc. + label_keys : list or None + Which label datasets to load from labl files. + test_mode : bool + Enable test-time transforms. + test_cfg : object + Test config (voxelize, crop, post_transform, aug_transform). 
+ """ + + def __init__( + self, + data_root, + split='train', + transform=None, + modalities=('seg',), + dataset_name='sim', + volume=None, + label_key='particle', + min_deposits=0, + max_len=-1, + loop=1, + include_physics=True, + label_keys=None, + test_mode=False, + test_cfg=None, + ignore_index=-1, + ): + super().__init__() + self.data_root = data_root + self.split = split + self.modalities = tuple(modalities) + self.dataset_name = dataset_name + self.volume = volume + self.label_key = label_key + self.min_deposits = min_deposits + self.max_len = max_len + self.loop = loop if not test_mode else 1 + self.ignore_index = ignore_index + self.test_mode = test_mode + self.test_cfg = test_cfg if test_mode else None + + self.transform = Compose(transform) + if test_mode and test_cfg is not None: + self.test_voxelize = TRANSFORMS.build(self.test_cfg.voxelize) + self.test_crop = ( + TRANSFORMS.build(self.test_cfg.crop) + if self.test_cfg.crop else None) + self.post_transform = Compose(self.test_cfg.post_transform) + self.aug_transform = [ + Compose(aug) for aug in self.test_cfg.aug_transform] + + # Build readers + self.seg_reader = None + self.resp_reader = None + self.labl_reader = None + self.corr_reader = None + + # Plane filter: if volume is set, only load that volume's planes + planes = 'all' + if volume is not None: + planes = [f'volume_{volume}_U', f'volume_{volume}_V', + f'volume_{volume}_Y'] + + if 'seg' in self.modalities: + self.seg_reader = LArTPCSegReader( + data_root=self._modality_root('seg'), split=split, + dataset_name=dataset_name, min_deposits=min_deposits, + include_physics=include_physics, volume=volume) + + if 'resp' in self.modalities: + self.resp_reader = LArTPCRespReader( + data_root=self._modality_root('resp'), split=split, + dataset_name=dataset_name, planes=planes) + + if 'labl' in self.modalities: + self.labl_reader = LArTPCLablReader( + data_root=self._modality_root('labl'), split=split, + dataset_name=dataset_name, label_keys=label_keys) + 
+ if 'corr' in self.modalities: + self.corr_reader = LArTPCCorrReader( + data_root=self._modality_root('corr'), split=split, + dataset_name=dataset_name, planes=planes) + + # Canonical reader and length + active_readers = [r for r in (self.seg_reader, self.resp_reader, + self.labl_reader, self.corr_reader) + if r is not None] + if not active_readers: + raise ValueError(f"Need at least one modality, got {self.modalities}") + self._canonical_reader = (self.seg_reader or self.resp_reader + or self.corr_reader or self.labl_reader) + self._n_events = min(len(r) for r in active_readers) + + logger = get_root_logger() + + # Warn about modality combinations that won't produce labels + if (self.resp_reader and self.labl_reader + and not self.corr_reader and not self.seg_reader): + logger.warning( + "modalities=('resp','labl') without 'corr': labl provides " + "track_id→label tables but resp pixels can't be mapped to " + "track_ids without corr. No 'segment' will be produced. " + "Add 'corr' for 2D labels or 'seg' for 3D labels.") + + logger.info( + f"LArTPCDataset: {self._n_events} events, " + f"modalities={self.modalities}, " + f"volume={volume}, split={split}") + + def _modality_root(self, modality): + """Resolve root directory for a modality.""" + mod_dir = os.path.join(self.data_root, modality) + if os.path.isdir(mod_dir): + return mod_dir + split_dir = os.path.join(self.data_root, self.split) + if os.path.isdir(split_dir): + return self.data_root + return self.data_root + + def get_data(self, idx): + """Load one event. Who owns coord depends on modalities: + + - seg present: coord = 3D deposits. Resp/corr as namespaced keys. + - seg absent, corr+labl present: coord = 2D corr entries with labels. + - seg absent, resp present (no corr): coord = 2D resp merged. 
+ """ + data_dict = {} + + # --- Seg (3D point cloud) → owns coord if present --- + if self.seg_reader is not None: + data_dict.update(self.seg_reader.read_event(idx)) + + # --- Labl (track_id → label lookup) --- + labl_data = {} + if self.labl_reader is not None: + labl_data = self.labl_reader.read_event(idx) + + # Apply labels to 3D seg data + if self.seg_reader is not None and labl_data: + self._apply_labl_to_3d(data_dict, labl_data) + + # --- Resp (2D wire planes) --- + resp_data = {} + if self.resp_reader is not None: + resp_data = self.resp_reader.read_event(idx) + + # --- Corr (correspondence) --- + corr_data = {} + if self.corr_reader is not None: + corr_data = self.corr_reader.read_event(idx) + + # --- Build point clouds for each spatial modality --- + # Each gets its own prefixed keys. When only one spatial source + # exists, its keys are also copied to the standard coord/energy + # so the default pipeline (GridSample, Collect, etc.) works. + + has_seg = self.seg_reader is not None + has_resp = bool(resp_data) + has_corr = bool(corr_data) + + # Resp → resp_coord/resp_energy/resp_plane_id + if has_resp: + self._merge_resp_planes(data_dict, resp_data, prefix='resp_') + # Also keep raw namespaced keys for per-plane access + data_dict.update(resp_data) + + # Corr+labl → corr_coord/corr_energy/corr_segment/corr_instance + if has_corr and labl_data: + self._build_corr_pointcloud(data_dict, corr_data, labl_data, prefix='corr_') + elif has_corr: + # corr without labl — keep as namespaced keys only + data_dict.update(corr_data) + + # --- Set standard coord/energy from the primary spatial source --- + if has_seg: + # seg already set coord/energy + pass + elif has_corr and labl_data: + # corr is primary (has labels) + data_dict['coord'] = data_dict['corr_coord'] + data_dict['energy'] = data_dict['corr_energy'] + data_dict['segment'] = data_dict['corr_segment'] + data_dict['instance'] = data_dict['corr_instance'] + data_dict['plane_id'] = data_dict['corr_plane_id'] + 
elif has_resp: + # resp is primary + data_dict['coord'] = data_dict['resp_coord'] + data_dict['energy'] = data_dict['resp_energy'] + data_dict['plane_id'] = data_dict['resp_plane_id'] + + # Pass through labl lookup tables (for downstream use) + if labl_data: + for k, v in labl_data.items(): + if k not in data_dict: + data_dict[k] = v + + # Metadata + data_dict['name'] = self.get_data_name(idx) + data_dict['split'] = self.split if isinstance(self.split, str) else 'custom' + return data_dict + + def _apply_labl_to_3d(self, data_dict, labl_data): + """Map 3D deposits' track_ids to labels via labl lookup. Vectorized.""" + track_ids = data_dict.get('track_ids') + volume_ids = data_dict.get('volume_id') + if track_ids is None: + return + + n = len(track_ids) + labels = np.full(n, -1, dtype=np.int32) + + vol_indices = sorted(set( + k.split('_')[1] for k in labl_data + if k.startswith('labl_v') and k.endswith('_track_ids') + )) + + for vi in vol_indices: + tids_key = f'labl_{vi}_track_ids' + label_key = f'labl_{vi}_{self.label_key}' + if tids_key not in labl_data or label_key not in labl_data: + continue + + vol_tids = labl_data[tids_key] + vol_labels = labl_data[label_key] + vol_num = int(vi[1:]) + + if volume_ids is not None: + vol_mask = volume_ids.ravel() == vol_num + else: + vol_mask = np.ones(n, dtype=bool) + + sort_idx = np.argsort(vol_tids) + sorted_tids = vol_tids[sort_idx] + sorted_labels = vol_labels[sort_idx] + + deposit_tids = track_ids[vol_mask] + insert_pos = np.searchsorted(sorted_tids, deposit_tids) + insert_pos = np.clip(insert_pos, 0, len(sorted_tids) - 1) + matched = sorted_tids[insert_pos] == deposit_tids + labels[vol_mask] = np.where(matched, sorted_labels[insert_pos], -1) + + data_dict['segment'] = labels + + def _merge_resp_planes(self, data_dict, resp_data, prefix=''): + """Merge all planes into {prefix}coord (M,2), {prefix}energy (M,1), {prefix}plane_id (M,1).""" + planes = sorted(set( + k.split('.')[1] for k in resp_data if k.endswith('.wire') + 
)) + + all_coord, all_energy, all_plane_id = [], [], [] + for i, plane in enumerate(planes): + wire = resp_data[f'plane.{plane}.wire'] + time = resp_data[f'plane.{plane}.time'] + value = resp_data[f'plane.{plane}.value'] + n = len(wire) + all_coord.append(np.stack([wire, time], axis=1).astype(np.float32)) + all_energy.append(value[:, None].astype(np.float32)) + all_plane_id.append(np.full((n, 1), i, dtype=np.int32)) + + data_dict[f'{prefix}coord'] = np.concatenate(all_coord, axis=0) + data_dict[f'{prefix}energy'] = np.concatenate(all_energy, axis=0) + data_dict[f'{prefix}plane_id'] = np.concatenate(all_plane_id, axis=0) + + def _build_corr_pointcloud(self, data_dict, corr_data, labl_data, prefix=''): + """Build 2D labeled point cloud from corr + labl. + + Each corr entry is a point: coord=(wire,time), feature=charge, + instance=group_id, segment from g2t+labl chain. + Overlapping instances at the same pixel are separate points. + """ + planes = sorted(set( + k.split('.')[1] for k in corr_data if k.endswith('.wire') + )) + + all_coord, all_charge, all_gid, all_segment, all_plane_id = [], [], [], [], [] + + for pi, plane in enumerate(planes): + wire_key = f'corr.{plane}.wire' + if wire_key not in corr_data: + continue + + wire = corr_data[f'corr.{plane}.wire'] + time = corr_data[f'corr.{plane}.time'] + gid = corr_data[f'corr.{plane}.group_id'] + charge = corr_data[f'corr.{plane}.charge'] + n = len(wire) + + all_coord.append(np.stack([wire, time], axis=1).astype(np.float32)) + all_charge.append(charge[:, None].astype(np.float32)) + all_gid.append(gid.astype(np.int32)) + all_plane_id.append(np.full((n, 1), pi, dtype=np.int32)) + + # group_id → g2t → track_id → labl → label + vol_idx = plane.split('_')[1] # 'volume_0_U' → '0' + g2t = corr_data.get(f'g2t_v{vol_idx}') + + labels = np.full(n, -1, dtype=np.int32) + if g2t is not None: + valid_gid = (gid >= 0) & (gid < len(g2t)) + track_ids = np.where(valid_gid, g2t[gid], -1) + + tids_key = f'labl_v{vol_idx}_track_ids' + 
def get_data_name(self, idx):
    """Return a sample name of the form ``"<shard file>_evt<NNN>"``.

    The name is resolved through ``self._canonical_reader``: its
    ``cumulative_lengths`` select the shard and ``indices`` give the
    event number inside that shard.

    Parameters
    ----------
    idx : int
        Sample index. Indices past the reader's event count (as produced
        when the dataset is looped) are wrapped instead of overrunning
        the shard table.
    """
    reader = self._canonical_reader
    cumulative = reader.cumulative_lengths
    n_total = int(cumulative[-1]) if len(cumulative) > 0 else 0
    if n_total > 0:
        idx = idx % n_total
    file_idx = int(np.searchsorted(cumulative, idx, side='right'))
    first = int(cumulative[file_idx - 1]) if file_idx > 0 else 0
    event_num = reader.indices[file_idx][idx - first]
    fname = os.path.basename(reader.h5_files[file_idx])
    return f"{fname}_evt{event_num:03d}"


def prepare_train_data(self, idx):
    """Load one sample and apply the training transform, if one is set."""
    data_dict = self.get_data(idx % len(self))
    # Mirror prepare_test_data: a dataset without a transform returns raw data.
    if self.transform is not None:
        data_dict = self.transform(data_dict)
    return data_dict


def prepare_test_data(self, idx):
    """Build the test-time sample: labels plus a list of augmented fragments.

    Returns
    -------
    dict
        ``name`` always; ``segment`` and ``origin_segment``/``inverse``
        when the loaded sample provides them; ``fragment_list`` with one
        post-transformed dict per (augmentation, voxel part, crop).
    """
    data_dict = self.get_data(idx % len(self))
    if self.transform is not None:
        data_dict = self.transform(data_dict)

    # Labels are pulled out so they are not duplicated into every fragment.
    result_dict = dict(name=data_dict.pop("name"))
    if "segment" in data_dict:
        result_dict["segment"] = data_dict.pop("segment")
    if "origin_segment" in data_dict:
        assert "inverse" in data_dict
        result_dict["origin_segment"] = data_dict.pop("origin_segment")
        result_dict["inverse"] = data_dict.pop("inverse")

    # Each augmentation works on its own copy of the remaining inputs.
    augmented = [aug(deepcopy(data_dict)) for aug in self.aug_transform]

    fragment_list = []
    for data in augmented:
        if self.test_voxelize is not None:
            parts = self.test_voxelize(data)
        else:
            parts = [data]
        for part in parts:
            if self.test_crop is not None:
                fragment_list += self.test_crop(part)
            else:
                fragment_list.append(part)

    result_dict["fragment_list"] = [
        self.post_transform(fragment) for fragment in fragment_list
    ]
    return result_dict


def __getitem__(self, idx):
    """Return the test or train sample for ``idx`` (wrapped to ``len(self)``)."""
    real_idx = idx % len(self)
    if self.test_mode:
        return self.prepare_test_data(real_idx)
    return self.prepare_train_data(real_idx)


def __len__(self):
    """Number of samples: event count, capped by ``max_len``, times ``loop``."""
    n = self._n_events
    if self.max_len > 0:
        n = min(n, self.max_len)
    return n * self.loop


def __del__(self):
    """Close every reader that was created; never raise during teardown."""
    for attr in ('seg_reader', 'resp_reader', 'labl_reader', 'corr_reader'):
        reader = getattr(self, attr, None)
        if reader is None:
            continue
        try:
            reader.close()
        except Exception:
            # Best effort: __del__ may run at interpreter shutdown, when
            # the modules close() relies on are already gone.
            pass
class LArTPCCorrReader:
    """Reads 3D→2D correspondence from JAXTPC corr HDF5 files.

    Each plane stores CSR-encoded groups of pixels; ``read_event`` expands
    them to flat per-entry arrays and also returns the per-volume
    ``group_to_track`` tables when present.

    Parameters
    ----------
    data_root : str
        Directory containing corr shard files.
    split : str
        Split name; tried first as a subdirectory of ``data_root``.
    dataset_name : str
        File prefix (e.g., 'sim' matches 'sim_corr_0000.h5').
    planes : str or list
        Which planes to load: 'all', a single label such as
        'volume_0_U', or a collection of such labels.
    **kwargs
        Accepted and ignored.
    """

    def __init__(self, data_root, split='train', dataset_name='sim',
                 planes='all', **kwargs):
        self.data_root = data_root
        self.split = split
        self.dataset_name = dataset_name
        self.planes = planes

        self.h5_files = self._find_files()
        assert len(self.h5_files) > 0, (
            f"No corr files found for '{dataset_name}' in {data_root}/{split}")

        # File handles are opened lazily, per worker process.
        self._initted = False
        self._h5data = []

        self._build_index()

    def _find_files(self):
        """Return sorted shard paths, preferring ``data_root/split``."""
        name = f'{self.dataset_name}_corr_*.h5'
        files = sorted(glob.glob(os.path.join(self.data_root, self.split, name)))
        if not files:
            files = sorted(glob.glob(os.path.join(self.data_root, name)))
        return files

    def _build_index(self):
        """Count events per shard and build the cumulative index.

        A shard that cannot be read is logged and contributes zero events.
        """
        log = get_root_logger()
        lengths = []
        self.indices = []

        for h5_path in self.h5_files:
            try:
                with h5py.File(h5_path, 'r', libver='latest', swmr=True) as f:
                    n_events = int(f['config'].attrs['n_events'])
                index = np.arange(n_events, dtype=np.int64)
            except Exception as e:
                log.warning(f"Error processing {h5_path}: {e}")
                index = np.array([], dtype=np.int64)

            lengths.append(len(index))
            self.indices.append(index)

        self.cumulative_lengths = np.cumsum(lengths)
        log.info(f"LArTPCCorrReader: {self.cumulative_lengths[-1]} events "
                 f"from {len(self.h5_files)} files")

    def h5py_worker_init(self):
        """Open one read handle per shard (called lazily, after fork)."""
        # Release handles from an earlier call so re-init does not leak them.
        self.close()
        self._h5data = [
            h5py.File(p, 'r', libver='latest', swmr=True)
            for p in self.h5_files
        ]
        self._initted = True

    def _locate_event(self, idx):
        """Map a global event index to ``(file_handle, event_key)``.

        Raises
        ------
        IndexError
            If ``idx`` is negative or not smaller than ``len(self)``.
            Without this check a negative index would silently resolve
            to an unrelated event of the first shard.
        """
        total = len(self)
        if not 0 <= idx < total:
            raise IndexError(
                f"event index {idx} out of range for {total} events")
        file_idx = int(np.searchsorted(self.cumulative_lengths, idx, side='right'))
        first = int(self.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0
        event_num = self.indices[file_idx][idx - first]
        return self._h5data[file_idx], f'event_{event_num:03d}'

    def _wants_plane(self, plane_label):
        """Return True if ``plane_label`` passes the ``planes`` filter.

        A string filter is compared for equality (or is 'all'); using
        ``in`` on a string would match substrings of the label.
        """
        planes = self.planes
        if isinstance(planes, str):
            return planes == 'all' or planes == plane_label
        return plane_label in planes

    @staticmethod
    def _decode_plane_vectorized(g):
        """Decode one plane's CSR correspondence without Python loops.

        Parameters
        ----------
        g : mapping
            Provides ``group_ids``, ``group_sizes``, ``center_wires``,
            ``center_times``, ``peak_charges`` (one value per group) and
            ``delta_wires``, ``delta_times``, ``charges_u16`` (one value
            per entry).

        Returns
        -------
        tuple of ndarray
            ``(wire, time, group_id, charge)``, each of shape (E,);
            the first three are int32, ``charge`` is float32.

        Raises
        ------
        ValueError
            If the per-entry arrays do not hold ``sum(group_sizes)`` items.
        """
        group_ids = g['group_ids'][:]
        group_sizes = g['group_sizes'][:].astype(np.int32)
        center_wires = g['center_wires'][:]
        center_times = g['center_times'][:]
        peak_charges = g['peak_charges'][:]
        delta_wires = g['delta_wires'][:]
        delta_times = g['delta_times'][:]
        charges_u16 = g['charges_u16'][:]

        if len(group_ids) == 0:
            # Separate arrays: callers may modify one result in place.
            return (np.zeros(0, dtype=np.int32),
                    np.zeros(0, dtype=np.int32),
                    np.zeros(0, dtype=np.int32),
                    np.zeros(0, dtype=np.float32))

        n_entries = int(group_sizes.sum())
        for name, arr in (('delta_wires', delta_wires),
                          ('delta_times', delta_times),
                          ('charges_u16', charges_u16)):
            if len(arr) != n_entries:
                raise ValueError(
                    f"{name} has {len(arr)} entries, "
                    f"expected sum(group_sizes)={n_entries}")

        # Broadcast group-level values to one value per entry.
        wires = (np.repeat(center_wires, group_sizes).astype(np.int32)
                 + delta_wires.astype(np.int32))
        times = (np.repeat(center_times, group_sizes).astype(np.int32)
                 + delta_times.astype(np.int32))
        gids = np.repeat(group_ids, group_sizes).astype(np.int32, copy=False)
        # charges_u16 is the fraction of the group's peak charge, in 1/65535.
        charges = (np.repeat(peak_charges, group_sizes)
                   * charges_u16.astype(np.float32) / 65535.0)

        return wires, times, gids, charges.astype(np.float32, copy=False)

    def read_event(self, idx):
        """Read one event's correspondence data.

        Returns dict with:
            corr.{vol_plane}.wire:     (E,) int32 — wire index per entry
            corr.{vol_plane}.time:     (E,) int32 — time index per entry
            corr.{vol_plane}.group_id: (E,) int32 — group ID per entry
            corr.{vol_plane}.charge:   (E,) float32 — charge per entry
            g2t_v{N}:                  (G,) int32 — group_to_track per volume
        """
        if not self._initted:
            self.h5py_worker_init()

        f, event_key = self._locate_event(idx)
        evt = f[event_key]

        data_dict = {}
        for vol_key in evt:
            vol = evt[vol_key]
            if not isinstance(vol, h5py.Group):
                continue
            if not vol_key.startswith('volume_'):
                continue

            vol_idx = vol_key.replace('volume_', '')

            # group_to_track lookup for this volume
            if 'group_to_track' in vol:
                data_dict[f'g2t_v{vol_idx}'] = vol['group_to_track'][:].astype(np.int32)

            # Per-plane correspondence
            for plane_key in vol:
                pg = vol[plane_key]
                if not isinstance(pg, h5py.Group) or 'group_ids' not in pg:
                    continue

                plane_label = f'volume_{vol_idx}_{plane_key}'
                if not self._wants_plane(plane_label):
                    continue

                wires, times, gids, charges = self._decode_plane_vectorized(pg)

                prefix = f'corr.{plane_label}'
                data_dict[f'{prefix}.wire'] = wires
                data_dict[f'{prefix}.time'] = times
                data_dict[f'{prefix}.group_id'] = gids
                data_dict[f'{prefix}.charge'] = charges

        return data_dict

    def __len__(self):
        return int(self.cumulative_lengths[-1]) if len(self.cumulative_lengths) > 0 else 0

    def close(self):
        """Close open file handles; errors while closing are ignored."""
        if self._initted:
            for f in self._h5data:
                try:
                    f.close()
                except Exception:
                    pass
            self._h5data = []
            self._initted = False
class LArTPCLablReader:
    """Reads per-volume track_id → label lookup tables.

    Parameters
    ----------
    data_root : str
        Directory containing labl shard files.
    split : str
        Split name; tried first as a subdirectory of ``data_root``.
    dataset_name : str
        File prefix (e.g., 'sim' matches 'sim_labl_0000.h5').
    label_keys : str, list of str or None
        Which label datasets to load (default: all available). A single
        name may be given as a plain string.
    """

    def __init__(self, data_root, split='train', dataset_name='sim',
                 label_keys=None):
        self.data_root = data_root
        self.split = split
        self.dataset_name = dataset_name
        self.label_keys = label_keys  # None = load all

        self.h5_files = self._find_files()
        assert len(self.h5_files) > 0, (
            f"No labl files found for '{dataset_name}' in {data_root}/{split}")

        # File handles are opened lazily, per worker process.
        self._initted = False
        self._h5data = []

        self._build_index()

    def _find_files(self):
        """Return sorted shard paths, preferring ``data_root/split``."""
        name = f'{self.dataset_name}_labl_*.h5'
        files = sorted(glob.glob(os.path.join(self.data_root, self.split, name)))
        if not files:
            files = sorted(glob.glob(os.path.join(self.data_root, name)))
        return files

    def _build_index(self):
        """Count events per shard and build the cumulative index.

        A shard that cannot be read is logged and contributes zero events.
        """
        log = get_root_logger()
        lengths = []
        self.indices = []

        for h5_path in self.h5_files:
            try:
                with h5py.File(h5_path, 'r', libver='latest', swmr=True) as f:
                    n_events = int(f['config'].attrs['n_events'])
                index = np.arange(n_events, dtype=np.int64)
            except Exception as e:
                log.warning(f"Error processing {h5_path}: {e}")
                index = np.array([], dtype=np.int64)

            lengths.append(len(index))
            self.indices.append(index)

        self.cumulative_lengths = np.cumsum(lengths)
        log.info(f"LArTPCLablReader: {self.cumulative_lengths[-1]} events "
                 f"from {len(self.h5_files)} files")

    def h5py_worker_init(self):
        """Open one read handle per shard (called lazily, after fork)."""
        # Release handles from an earlier call so re-init does not leak them.
        self.close()
        self._h5data = [
            h5py.File(p, 'r', libver='latest', swmr=True)
            for p in self.h5_files
        ]
        self._initted = True

    def _locate_event(self, idx):
        """Map a global event index to ``(file_handle, event_key)``.

        Raises
        ------
        IndexError
            If ``idx`` is negative or not smaller than ``len(self)``.
            Without this check a negative index would silently resolve
            to an unrelated event of the first shard.
        """
        total = len(self)
        if not 0 <= idx < total:
            raise IndexError(
                f"event index {idx} out of range for {total} events")
        file_idx = int(np.searchsorted(self.cumulative_lengths, idx, side='right'))
        first = int(self.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0
        event_num = self.indices[file_idx][idx - first]
        return self._h5data[file_idx], f'event_{event_num:03d}'

    def _wants_label(self, name):
        """Return True if label dataset ``name`` passes ``label_keys``.

        A string filter is compared for equality; using ``in`` on a
        string would match substrings of the name.
        """
        keys = self.label_keys
        if keys is None:
            return True
        if isinstance(keys, str):
            return keys == name
        return name in keys

    def read_event(self, idx):
        """Read one event's per-volume label lookup tables.

        Returns dict with keys like:
            labl_v0_track_ids:   (T,) int32
            labl_v0_particle:    (T,) int32
            labl_v0_cluster:     (T,) int32
            labl_v0_interaction: (T,) int32
        """
        if not self._initted:
            self.h5py_worker_init()

        f, event_key = self._locate_event(idx)
        evt = f[event_key]

        data_dict = {}
        for vol_key in evt:
            vol = evt[vol_key]
            if not isinstance(vol, h5py.Group):
                continue
            if 'track_ids' not in vol:
                continue

            # Volume index from key (e.g., 'volume_0' → '0')
            vol_idx = vol_key.replace('volume_', '')
            prefix = f'labl_v{vol_idx}'

            data_dict[f'{prefix}_track_ids'] = vol['track_ids'][:].astype(np.int32)

            for lk in vol:
                if lk == 'track_ids' or not self._wants_label(lk):
                    continue
                item = vol[lk]
                # Only datasets are label tables; a nested group cannot
                # be sliced into an array.
                if not isinstance(item, h5py.Dataset):
                    continue
                data_dict[f'{prefix}_{lk}'] = item[:].astype(np.int32)

        return data_dict

    def __len__(self):
        return int(self.cumulative_lengths[-1]) if len(self.cumulative_lengths) > 0 else 0

    def close(self):
        """Close open file handles; errors while closing are ignored."""
        if self._initted:
            for f in self._h5data:
                try:
                    f.close()
                except Exception:
                    pass
            self._h5data = []
            self._initted = False
class LArTPCRespReader:
    """Reads sparse wire signals from JAXTPC resp HDF5 files.

    Parameters
    ----------
    data_root : str
        Directory containing resp shard files.
    split : str
        Split name; tried first as a subdirectory of ``data_root``.
    dataset_name : str
        File prefix (e.g., 'sim' matches 'sim_resp_0000.h5').
    planes : str or list
        Which planes to load: 'all', a single label, or a list like
        ['east_U', 'east_V'].
    decode_digitization : bool
        If True, subtract pedestal from uint16 values.
    """

    def __init__(self, data_root, split='train', dataset_name='sim',
                 planes='all', decode_digitization=True):
        self.data_root = data_root
        self.split = split
        self.dataset_name = dataset_name
        self.planes = planes
        self.decode_digitization = decode_digitization

        self.h5_files = self._find_files()
        assert len(self.h5_files) > 0, (
            f"No resp files found for '{dataset_name}' in {data_root}/{split}")

        # File handles are opened lazily, per worker process.
        self._initted = False
        self._h5data = []

        self._build_index()

    def _find_files(self):
        """Discover resp shard files, preferring ``data_root/split``."""
        name = f'{self.dataset_name}_resp_*.h5'
        files = sorted(glob.glob(os.path.join(self.data_root, self.split, name)))
        if not files:
            files = sorted(glob.glob(os.path.join(self.data_root, name)))
        return files

    def _build_index(self):
        """Scan files, count events, build cumulative index.

        A shard that cannot be read is logged and contributes zero events.
        """
        log = get_root_logger()
        lengths = []
        self.indices = []

        for h5_path in self.h5_files:
            try:
                with h5py.File(h5_path, 'r', libver='latest', swmr=True) as f:
                    n_events = int(f['config'].attrs['n_events'])
                index = np.arange(n_events, dtype=np.int64)
            except Exception as e:
                log.warning(f"Error processing {h5_path}: {e}")
                index = np.array([], dtype=np.int64)

            lengths.append(len(index))
            self.indices.append(index)

        self.cumulative_lengths = np.cumsum(lengths)
        log.info(f"LArTPCRespReader: {self.cumulative_lengths[-1]} events "
                 f"from {len(self.h5_files)} files")

    def h5py_worker_init(self):
        """Lazily open file handles (called after DataLoader fork)."""
        # Release handles from an earlier call so re-init does not leak them.
        self.close()
        self._h5data = [
            h5py.File(p, 'r', libver='latest', swmr=True)
            for p in self.h5_files
        ]
        self._initted = True

    def _locate_event(self, idx):
        """Map global index -> (file_handle, event_key).

        Raises
        ------
        IndexError
            If ``idx`` is negative or not smaller than ``len(self)``.
            Without this check a negative index would silently resolve
            to an unrelated event of the first shard.
        """
        total = len(self)
        if not 0 <= idx < total:
            raise IndexError(
                f"event index {idx} out of range for {total} events")
        file_idx = int(np.searchsorted(self.cumulative_lengths, idx, side='right'))
        first = int(self.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0
        event_num = self.indices[file_idx][idx - first]
        return self._h5data[file_idx], f'event_{event_num:03d}'

    def _wants_plane(self, plane_label):
        """Return True if ``plane_label`` passes the ``planes`` filter.

        A string filter is compared for equality (or is 'all'); using
        ``in`` on a string would match substrings of the label.
        """
        planes = self.planes
        if isinstance(planes, str):
            return planes == 'all' or planes == plane_label
        return plane_label in planes

    def _iter_planes(self, evt):
        """Yield (plane_label, h5py.Group) for each plane in an event.

        Handles both formats:
          - Old: planes directly under event (east_U, east_V, ...)
          - New: planes under volume_N/ subgroups
        """
        for key in evt:
            obj = evt[key]
            if not isinstance(obj, h5py.Group):
                continue
            if key.startswith('volume_'):
                # New format: volume_0/U, volume_0/V, ...
                for plane_key in obj:
                    pg = obj[plane_key]
                    if isinstance(pg, h5py.Group) and 'delta_wire' in pg:
                        yield f'{key}_{plane_key}', pg
            elif 'delta_wire' in obj:
                # Old format: east_U, east_V, ...
                yield key, obj

    def _decode_plane(self, g):
        """Decode one plane's delta-encoded sparse data.

        Wire and time indices are the running sums of the stored deltas,
        offset by the ``wire_start`` / ``time_start`` attributes. uint16
        values have the ``pedestal`` attribute (default 0) subtracted
        when ``decode_digitization`` is set.

        Returns
        -------
        tuple of ndarray
            ``(wire, time, values)``: int32, int32 and float32 arrays.
        """
        wire_start = int(g.attrs['wire_start'])
        time_start = int(g.attrs['time_start'])

        wire = wire_start + np.cumsum(g['delta_wire'][:]).astype(np.int32)
        time = time_start + np.cumsum(g['delta_time'][:]).astype(np.int32)

        raw_values = g['values'][:]
        values = raw_values.astype(np.float32)
        if self.decode_digitization and raw_values.dtype == np.uint16:
            values = values - int(g.attrs.get('pedestal', 0))

        return wire, time, values

    def read_event(self, idx):
        """Read one event, return dict with plane-namespaced sparse arrays.

        Returns keys like:
            plane.east_U.wire:  (M,) int32
            plane.east_U.time:  (M,) int32
            plane.east_U.value: (M,) float32
        """
        if not self._initted:
            self.h5py_worker_init()

        f, event_key = self._locate_event(idx)
        evt = f[event_key]

        data_dict = {}
        for plane_label, pg in self._iter_planes(evt):
            if not self._wants_plane(plane_label):
                continue

            wire, time, values = self._decode_plane(pg)
            prefix = f'plane.{plane_label}'
            data_dict[f'{prefix}.wire'] = wire
            data_dict[f'{prefix}.time'] = time
            data_dict[f'{prefix}.value'] = values

        return data_dict

    def __len__(self):
        return int(self.cumulative_lengths[-1]) if len(self.cumulative_lengths) > 0 else 0

    def close(self):
        """Close open file handles; errors while closing are ignored."""
        if self._initted:
            for f in self._h5data:
                try:
                    f.close()
                except Exception:
                    pass
            self._h5data = []
            self._initted = False
class LArTPCSegReader:
    """Reads 3D truth deposits from JAXTPC seg HDF5 files.

    Concatenates volumes into a single point cloud with a volume_id feature.
    No label computation — just raw data.

    Parameters
    ----------
    data_root : str
        Directory containing seg shard files.
    split : str
        Split name; tried first as a subdirectory of ``data_root``.
    dataset_name : str
        File prefix (e.g., 'sim' matches 'sim_seg_0000.h5').
    min_deposits : int
        Minimum deposits per event to include in index.
    include_physics : bool
        Whether to load dx, theta, phi, charge, photons, etc.
    volume : int or None
        Load only this volume of volume-format events; None loads all.
    """

    # Keys returned as (N, 1) columns; everything else keeps its stored rank.
    _COLUMN_KEYS = ('energy', 'dx', 'theta', 'phi', 't0_us',
                    'charge', 'photons', 'qs_fractions', 'volume_id')

    def __init__(self, data_root, split='train', dataset_name='sim',
                 min_deposits=0, include_physics=True, volume=None):
        self.data_root = data_root
        self.split = split
        self.dataset_name = dataset_name
        self.min_deposits = min_deposits
        self.include_physics = include_physics
        self.volume = volume  # None = all volumes, int = single volume

        self.h5_files = self._find_files()
        assert len(self.h5_files) > 0, (
            f"No seg files found for '{dataset_name}' in {data_root}/{split}")

        # File handles are opened lazily, per worker process.
        self._initted = False
        self._h5data = []

        self._build_index()

    def _find_files(self):
        """Discover seg shard files, preferring ``data_root/split``."""
        name = f'{self.dataset_name}_seg_*.h5'
        files = sorted(glob.glob(os.path.join(self.data_root, self.split, name)))
        if not files:
            files = sorted(glob.glob(os.path.join(self.data_root, name)))
        return files

    @staticmethod
    def _count_deposits(evt, n_volumes):
        """Return the number of deposits stored for one event.

        Legacy events keep ``positions`` directly under the event;
        volume-format events keep an ``n_actual`` attribute on each
        ``volume_N`` group. Volume groups are also honoured when
        ``n_volumes`` is 1 or missing, matching ``read_event``.
        """
        if n_volumes <= 1 and 'positions' in evt:
            return int(evt['positions'].shape[0])
        return sum(
            int(evt[f'volume_{v}'].attrs.get('n_actual', 0))
            for v in range(max(n_volumes, 1))
            if f'volume_{v}' in evt
        )

    def _build_index(self):
        """Scan files, count events, build cumulative index.

        With ``min_deposits > 0`` only events holding at least that many
        deposits are indexed. A shard that cannot be read is logged and
        contributes zero events.
        """
        log = get_root_logger()
        lengths = []
        self.indices = []

        for h5_path in self.h5_files:
            try:
                with h5py.File(h5_path, 'r', libver='latest', swmr=True) as f:
                    n_events = int(f['config'].attrs['n_events'])
                    n_volumes = int(f['config'].attrs.get('n_volumes', 1))

                    if self.min_deposits > 0:
                        valid = []
                        for i in range(n_events):
                            evt_key = f'event_{i:03d}'
                            if evt_key not in f:
                                continue
                            total = self._count_deposits(f[evt_key], n_volumes)
                            if total >= self.min_deposits:
                                valid.append(i)
                        index = np.array(valid, dtype=np.int64)
                    else:
                        index = np.arange(n_events, dtype=np.int64)

            except Exception as e:
                log.warning(f"Error processing {h5_path}: {e}")
                index = np.array([], dtype=np.int64)

            lengths.append(len(index))
            self.indices.append(index)

        self.cumulative_lengths = np.cumsum(lengths)
        log.info(f"LArTPCSegReader: {self.cumulative_lengths[-1]} events "
                 f"from {len(self.h5_files)} files "
                 f"(min_deposits={self.min_deposits})")

    def h5py_worker_init(self):
        """Lazily open file handles (called after DataLoader fork)."""
        # Release handles from an earlier call so re-init does not leak them.
        self.close()
        self._h5data = [
            h5py.File(p, 'r', libver='latest', swmr=True)
            for p in self.h5_files
        ]
        self._initted = True

    def _locate_event(self, idx):
        """Map global index → (file_handle, event_key, n_volumes).

        Raises
        ------
        IndexError
            If ``idx`` is negative or not smaller than ``len(self)``.
            Without this check a negative index would silently resolve
            to an unrelated event of the first shard.
        """
        total = len(self)
        if not 0 <= idx < total:
            raise IndexError(
                f"event index {idx} out of range for {total} events")
        file_idx = int(np.searchsorted(self.cumulative_lengths, idx, side='right'))
        first = int(self.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0
        event_num = self.indices[file_idx][idx - first]
        f = self._h5data[file_idx]
        n_volumes = int(f['config'].attrs.get('n_volumes', 1))
        return f, f'event_{event_num:03d}', n_volumes

    def read_event(self, idx):
        """Read one event, return flat dict of numpy arrays.

        No label computation — just raw geometry, physics, and IDs.
        Events without any deposits yield the dict from ``_empty_dict``.
        """
        if not self._initted:
            self.h5py_worker_init()

        f, event_key, n_volumes = self._locate_event(idx)
        evt = f[event_key]

        vol_arrays = []

        if n_volumes > 1 or 'positions' not in evt:
            # Volume format. Also taken when n_volumes is 1 or missing but
            # the event has no flat 'positions', so that single-volume
            # files using volume_0 groups are not read as empty.
            for v in range(max(n_volumes, 1)):
                if self.volume is not None and v != self.volume:
                    continue
                vk = f'volume_{v}'
                if vk not in evt:
                    continue
                vg = evt[vk]
                n = int(vg.attrs.get('n_actual', 0))
                if n == 0:
                    continue
                vol_arrays.append(self._read_volume(vg, n, v))
        else:
            # Legacy flat format
            n = evt['positions'].shape[0]
            vol_arrays.append(self._read_volume_flat(evt, n, 0))

        if not vol_arrays:
            return self._empty_dict()

        return self._concat_volumes(vol_arrays)

    def _read_arrays(self, group, n, vol_idx):
        """Read the first ``n`` deposits stored under ``group``.

        ``group`` is either a volume group or a legacy flat event; both
        use the same dataset and attribute names. Every array is cut to
        the same length so ``volume_id`` and the filled-in ID defaults
        always line up with the data, even if the stored arrays are
        longer than ``n`` (presumably padding — TODO confirm against
        the writer).
        """
        step = float(group.attrs['pos_step_mm'])
        origin = np.array([group.attrs['pos_origin_x'],
                           group.attrs['pos_origin_y'],
                           group.attrs['pos_origin_z']], dtype=np.float32)

        positions = group['positions'][:n]
        n = positions.shape[0]

        d = {
            # Stored positions are integer steps from the volume origin.
            'coord': positions.astype(np.float32) * step + origin,
            'energy': group['de'][:n].astype(np.float32),
            'volume_id': np.full(n, vol_idx, dtype=np.int32),
            'track_ids': group['track_ids'][:n].astype(np.int32),
            'group_ids': group['group_ids'][:n].astype(np.int32),
        }

        # Optional ID fields; -1 marks "not stored".
        for key in ('pdg', 'interaction_ids', 'ancestor_track_ids'):
            if key in group:
                d[key] = group[key][:n].astype(np.int32)
            else:
                d[key] = np.full(n, -1, dtype=np.int32)

        # Optional physics
        if self.include_physics:
            for key in ('dx', 'theta', 'phi', 't0_us',
                        'charge', 'photons', 'qs_fractions'):
                if key in group:
                    d[key] = group[key][:n].astype(np.float32)

        return d

    def _read_volume(self, vg, n, vol_idx):
        """Read arrays from a volume group."""
        return self._read_arrays(vg, n, vol_idx)

    def _read_volume_flat(self, evt, n, vol_idx):
        """Read from legacy flat event format (no volume subgroups)."""
        return self._read_arrays(evt, n, vol_idx)

    def _concat_volumes(self, vol_arrays):
        """Concatenate per-volume dicts into a single flat dict.

        Only keys present in every volume are kept: concatenating an
        optional key over a subset of the volumes would produce an array
        that no longer lines up with ``coord``.
        """
        data_dict = {}
        for k in vol_arrays[0]:
            if not all(k in v for v in vol_arrays):
                continue
            combined = np.concatenate([v[k] for v in vol_arrays], axis=0)
            if k in self._COLUMN_KEYS:
                combined = combined[:, None]  # (N,1)
            data_dict[k] = combined  # coord stays (N,3), IDs stay (N,)

        return data_dict

    def _empty_dict(self):
        """Minimal valid dict for empty events."""
        ids = ('track_ids', 'group_ids', 'pdg',
               'interaction_ids', 'ancestor_track_ids')
        d = {
            'coord': np.zeros((0, 3), dtype=np.float32),
            'energy': np.zeros((0, 1), dtype=np.float32),
            'volume_id': np.zeros((0, 1), dtype=np.int32),
        }
        d.update((key, np.zeros((0,), dtype=np.int32)) for key in ids)
        return d

    def __len__(self):
        return int(self.cumulative_lengths[-1]) if len(self.cumulative_lengths) > 0 else 0

    def close(self):
        """Close open file handles; errors while closing are ignored."""
        if self._initted:
            for f in self._h5data:
                try:
                    f.close()
                except Exception:
                    pass
            self._h5data = []
            self._initted = False
class WCSegReader:
    """Reads 3D track segments from WC segment HDF5 files.

    Files hold flat CSR arrays: ``track_event_offset`` delimits the
    tracks of each event and ``segment_offset`` the segments of each
    track.

    Parameters
    ----------
    data_root : str
        Directory containing segment shard files.
    split : str
        Split name; tried as a subdirectory of ``data_root``.
    dataset_name : str
        File prefix — matches files like '{dataset_name}_seg_*.h5';
        'segment_events_*.h5' and '*segment*.h5' are tried as fallbacks.
    min_segments : int
        Minimum segments per event to include.
    **kwargs
        Accepted and ignored.
    """

    def __init__(self, data_root, split='', dataset_name='wc',
                 min_segments=0, **kwargs):
        self.data_root = data_root
        self.split = split
        self.dataset_name = dataset_name
        self.min_segments = min_segments

        self.h5_files = self._find_files()
        assert len(self.h5_files) > 0, (
            f"No WC seg files found in {data_root}/{split}")

        # File handles are opened lazily, per worker process.
        self._initted = False
        self._h5data = []
        self._build_index()

    def _find_files(self):
        """Find segment HDF5 files. Tries multiple naming patterns.

        Patterns are tried from most to least specific, each under
        ``data_root/split`` first and then ``data_root``; the first
        pattern that matches anything wins.
        """
        names = (f'{self.dataset_name}_seg_*.h5',
                 'segment_events_*.h5',
                 '*segment*.h5')
        for name in names:
            for folder in (os.path.join(self.data_root, self.split),
                           self.data_root):
                files = sorted(glob.glob(os.path.join(folder, name)))
                if files:
                    return files
        return []

    @staticmethod
    def _events_with_min_segments(track_offsets, seg_offsets, n_events,
                                  min_segments):
        """Return indices of events holding at least ``min_segments``.

        Parameters
        ----------
        track_offsets : array, (n_events+1,)
            CSR offsets of each event into the track arrays.
        seg_offsets : array, (n_tracks+1,)
            CSR offsets of each track into the segment arrays.

        Raises
        ------
        ValueError
            If ``track_offsets`` is too short for ``n_events`` events.
        """
        bounds = np.asarray(track_offsets[:n_events + 1], dtype=np.int64)
        if len(bounds) < n_events + 1:
            raise ValueError(
                f"track_event_offset has {len(bounds)} entries, "
                f"need {n_events + 1} for {n_events} events")
        seg = np.asarray(seg_offsets, dtype=np.int64)
        first, last = bounds[:-1], bounds[1:]
        # Events without tracks count as zero segments.
        n_seg = np.where(last > first, seg[last] - seg[first], 0)
        return np.flatnonzero(n_seg >= min_segments).astype(np.int64)

    def _build_index(self):
        """Count events per shard and build the cumulative index.

        The event count comes from the ``n_events`` attribute, then the
        ``event_number`` dataset, then the length of the
        ``track_event_offset`` CSR table. A shard that cannot be read is
        logged and contributes zero events.
        """
        log = get_root_logger()
        lengths = []
        self.indices = []
        self._file_n_events = []

        for h5_path in self.h5_files:
            try:
                with h5py.File(h5_path, 'r', libver='latest', swmr=True) as f:
                    n_events = int(f.attrs.get('n_events', 0))
                    if n_events == 0 and 'event_number' in f:
                        n_events = f['event_number'].shape[0]
                    if n_events == 0 and 'track_event_offset' in f:
                        # CSR table holds one more entry than there are events.
                        n_events = max(f['track_event_offset'].shape[0] - 1, 0)

                    if (self.min_segments > 0 and 'segment_offset' in f
                            and 'track_event_offset' in f):
                        index = self._events_with_min_segments(
                            f['track_event_offset'][:],
                            f['segment_offset'][:],
                            n_events, self.min_segments)
                    else:
                        index = np.arange(n_events, dtype=np.int64)
                    self._file_n_events.append(n_events)
            except Exception as e:
                log.warning(f"Error processing {h5_path}: {e}")
                index = np.array([], dtype=np.int64)
                self._file_n_events.append(0)

            lengths.append(len(index))
            self.indices.append(index)

        self.cumulative_lengths = np.cumsum(lengths)
        log.info(f"WCSegReader: {self.cumulative_lengths[-1]} events "
                 f"from {len(self.h5_files)} files")

    def h5py_worker_init(self):
        """Open one read handle per shard (called lazily, after fork)."""
        # Release handles from an earlier call so re-init does not leak them.
        self.close()
        self._h5data = [
            h5py.File(p, 'r', libver='latest', swmr=True)
            for p in self.h5_files
        ]
        self._initted = True

    def _locate_event(self, idx):
        """Map a global event index to ``(file_handle, event_number)``.

        Raises
        ------
        IndexError
            If ``idx`` is negative or not smaller than ``len(self)``.
            Without this check a negative index would silently resolve
            to an unrelated event of the first shard.
        """
        total = len(self)
        if not 0 <= idx < total:
            raise IndexError(
                f"event index {idx} out of range for {total} events")
        file_idx = int(np.searchsorted(self.cumulative_lengths, idx, side='right'))
        first = int(self.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0
        event_num = self.indices[file_idx][idx - first]
        return self._h5data[file_idx], event_num

    def read_event(self, idx):
        """Read one event as a flat dict of per-segment arrays.

        Returns
        -------
        dict
            ``coord`` (N,3) segment midpoints, ``energy`` (N,1),
            ``time`` (N,1), and ``track_ids`` / ``pdg`` / ``parent_ids``
            (N,) with each track's value repeated for its segments.
            Events without tracks or segments yield ``_empty_dict()``.
        """
        if not self._initted:
            self.h5py_worker_init()

        f, event_num = self._locate_event(idx)

        # Track range for this event
        track_offsets = f['track_event_offset']
        t0 = int(track_offsets[event_num])
        t1 = int(track_offsets[event_num + 1])
        if t1 - t0 == 0:
            return self._empty_dict()

        # Segment range for these tracks
        track_seg_bounds = f['segment_offset'][t0:t1 + 1]
        s0 = int(track_seg_bounds[0])
        s1 = int(track_seg_bounds[-1])
        if s1 - s0 == 0:
            return self._empty_dict()

        # Segment midpoints
        mid = [
            (f[f'start_{axis}'][s0:s1] + f[f'end_{axis}'][s0:s1]) / 2
            for axis in ('x', 'y', 'z')
        ]

        # Per-track data, expanded to per-segment
        n_segs_per_track = np.diff(track_seg_bounds).astype(np.int32)

        def per_segment(name):
            values = f[name][t0:t1].astype(np.int32)
            return np.repeat(values, n_segs_per_track)

        return {
            'coord': np.stack(mid, axis=1).astype(np.float32),
            'energy': f['edep'][s0:s1].astype(np.float32)[:, None],
            'time': f['time'][s0:s1].astype(np.float32)[:, None],
            'track_ids': per_segment('track_id'),
            'pdg': per_segment('pdg'),
            'parent_ids': per_segment('parent_id'),
        }

    def _empty_dict(self):
        """Minimal valid dict for events without segments."""
        return {
            'coord': np.zeros((0, 3), dtype=np.float32),
            'energy': np.zeros((0, 1), dtype=np.float32),
            'time': np.zeros((0, 1), dtype=np.float32),
            'track_ids': np.zeros((0,), dtype=np.int32),
            'pdg': np.zeros((0,), dtype=np.int32),
            'parent_ids': np.zeros((0,), dtype=np.int32),
        }

    def __len__(self):
        return int(self.cumulative_lengths[-1]) if len(self.cumulative_lengths) > 0 else 0

    def close(self):
        """Close open file handles; errors while closing are ignored."""
        if self._initted:
            for fh in self._h5data:
                try:
                    fh.close()
                except Exception:
                    pass
            self._h5data = []
            self._initted = False
+ pe_threshold : float + Minimum PE for per-particle hits (already sparse in file, this + applies additional filtering). + pmt_positions : ndarray or None + (N_sensors, 3) PMT positions. If None, coord won't be produced + (sensor_idx is still available). + pmt_positions_file : str or None + Path to .npy file with PMT positions. Alternative to pmt_positions. + """ + + def __init__(self, data_root, split='', dataset_name='wc', + include_labels=True, pe_threshold=0.0, + pmt_positions=None, pmt_positions_file=None, **kwargs): + self.data_root = data_root + self.split = split + self.dataset_name = dataset_name + self.include_labels = include_labels + self.pe_threshold = pe_threshold + + # PMT positions (optional) + if pmt_positions is not None: + self._pmt_positions = np.asarray(pmt_positions, dtype=np.float32) + elif pmt_positions_file is not None: + self._pmt_positions = np.load(pmt_positions_file).astype(np.float32) + else: + self._pmt_positions = None + + self.h5_files = self._find_files() + assert len(self.h5_files) > 0, ( + f"No WC sensor files found in {data_root}/{split}") + + self._initted = False + self._h5data = [] + self._n_sensors = None + self._build_index() + + def _find_files(self): + for pattern in [ + os.path.join(self.data_root, self.split, f'{self.dataset_name}_sensor_*.h5'), + os.path.join(self.data_root, f'{self.dataset_name}_sensor_*.h5'), + os.path.join(self.data_root, self.split, 'sensor_events_*.h5'), + os.path.join(self.data_root, 'sensor_events_*.h5'), + os.path.join(self.data_root, self.split, '*sensor*.h5'), + os.path.join(self.data_root, '*sensor*.h5'), + ]: + files = sorted(glob.glob(pattern)) + if files: + return files + return [] + + def _build_index(self): + log = get_root_logger() + self.cumulative_lengths = [] + self.indices = [] + + for h5_path in self.h5_files: + try: + with h5py.File(h5_path, 'r', libver='latest', swmr=True) as f: + n_events = int(f.attrs.get('n_events', 0)) + if n_events == 0 and 'event_number' in f: + n_events = 
f['event_number'].shape[0] + self._n_sensors = int(f.attrs.get('n_sensors', 0)) + index = np.arange(n_events, dtype=np.int64) + except Exception as e: + log.warning(f"Error processing {h5_path}: {e}") + index = np.array([], dtype=np.int64) + + self.cumulative_lengths.append(len(index)) + self.indices.append(index) + + self.cumulative_lengths = np.cumsum(self.cumulative_lengths) + log.info(f"WCSensorReader: {self.cumulative_lengths[-1]} events, " + f"{self._n_sensors} sensors from {len(self.h5_files)} files") + + def h5py_worker_init(self): + self._h5data = [ + h5py.File(p, 'r', libver='latest', swmr=True) + for p in self.h5_files + ] + # Try to load PMT positions from file config if not provided + if self._pmt_positions is None: + f = self._h5data[0] + if 'config' in f and 'pmt_positions' in f['config']: + self._pmt_positions = f['config']['pmt_positions'][:].astype(np.float32) + self._initted = True + + def _locate_event(self, idx): + file_idx = int(np.searchsorted(self.cumulative_lengths, idx, side='right')) + local_idx = idx - (int(self.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0) + event_num = self.indices[file_idx][local_idx] + return self._h5data[file_idx], event_num + + def read_event(self, idx): + if not self._initted: + self.h5py_worker_init() + + f, event_num = self._locate_event(idx) + data_dict = {} + + # --- Event-level sensor hits --- + hit_offsets = f['event_hit_offsets'] + h0 = int(hit_offsets[event_num]) + h1 = int(hit_offsets[event_num + 1]) + + sensor_idx = f['event_hit_sensor_idx'][h0:h1].astype(np.int32) + pe = f['event_hit_PE'][h0:h1].astype(np.float32) + t = f['event_hit_T'][h0:h1].astype(np.float32) + + # Build dense per-sensor arrays (sum PE, min T for sensors with hits) + n_sensors = self._n_sensors + pmt_pe = np.zeros(n_sensors, dtype=np.float32) + pmt_t = np.full(n_sensors, -1.0, dtype=np.float32) + np.add.at(pmt_pe, sensor_idx, pe) + # First-hit time: use scatter-min + for i in range(len(sensor_idx)): + s = sensor_idx[i] + 
if pmt_t[s] < 0 or t[i] < pmt_t[s]: + pmt_t[s] = t[i] + + data_dict['pmt_pe'] = pmt_pe + data_dict['pmt_t'] = pmt_t + + if self._pmt_positions is not None: + data_dict['pmt_coord'] = self._pmt_positions.copy() + + # --- Per-particle hits (sparse) --- + if self.include_labels and 'particle_hit_offsets' in f: + # Which particles belong to this event + p_offsets = f['particle_event_offset'] + p0 = int(p_offsets[event_num]) + p1 = int(p_offsets[event_num + 1]) + n_particles = p1 - p0 + + if n_particles > 0: + categories = f['particle_category'][p0:p1].astype(np.int32) + + # Per-particle hit ranges + pp_hit_offsets = f['particle_hit_offsets'] + pp_h0 = int(pp_hit_offsets[p0]) + pp_h1 = int(pp_hit_offsets[p1]) + + pp_sensor = f['particle_hit_sensor_idx'][pp_h0:pp_h1].astype(np.int32) + pp_pe = f['particle_hit_PE'][pp_h0:pp_h1].astype(np.float32) + pp_t = f['particle_hit_T'][pp_h0:pp_h1].astype(np.float32) + + # Build particle_idx for each hit + hits_per_particle = np.diff(pp_hit_offsets[p0:p1 + 1]).astype(np.int32) + pp_particle_idx = np.repeat(np.arange(n_particles, dtype=np.int32), + hits_per_particle) + pp_category = np.repeat(categories, hits_per_particle) + + # Optional threshold filter + if self.pe_threshold > 0: + mask = pp_pe > self.pe_threshold + pp_sensor = pp_sensor[mask] + pp_pe = pp_pe[mask] + pp_t = pp_t[mask] + pp_particle_idx = pp_particle_idx[mask] + pp_category = pp_category[mask] + + data_dict['pp_sensor_idx'] = pp_sensor + data_dict['pp_particle_idx'] = pp_particle_idx + data_dict['pp_pe'] = pp_pe + data_dict['pp_t'] = pp_t + data_dict['pp_category'] = pp_category + + return data_dict + + def __len__(self): + return int(self.cumulative_lengths[-1]) if len(self.cumulative_lengths) > 0 else 0 + + def close(self): + if self._initted: + for fh in self._h5data: + try: + fh.close() + except Exception: + pass + self._h5data = [] + self._initted = False diff --git a/pimm/datasets/transform.py b/pimm/datasets/transform.py index 7b5ebcf..3e63f42 100644 --- 
a/pimm/datasets/transform.py +++ b/pimm/datasets/transform.py @@ -53,6 +53,21 @@ def index_operator(data_dict, index, duplicate=False): "instance_particle", "instance_interaction", "momentum", + # LArTPCDataset keys (JAXTPC) + "track_ids", + "group_ids", + "pdg", + "volume_id", + "interaction_ids", + "ancestor_track_ids", + "charge", + "photons", + "qs_fractions", + "t0_us", + "dx", + "theta", + "phi", + "segment_interaction", ] if not duplicate: for key in data_dict["index_valid_keys"]: diff --git a/pimm/datasets/utils.py b/pimm/datasets/utils.py index b717eb4..3f7ea45 100644 --- a/pimm/datasets/utils.py +++ b/pimm/datasets/utils.py @@ -43,6 +43,7 @@ def collate_fn(batch, mix_prob=0): ) ) for key in batch[0] + if not key.startswith("_") # skip non-tensor metadata } return batch else: diff --git a/pimm/datasets/wc_dataset.py b/pimm/datasets/wc_dataset.py new file mode 100644 index 0000000..bd258ef --- /dev/null +++ b/pimm/datasets/wc_dataset.py @@ -0,0 +1,245 @@ +""" +WCDataset — dataset for Water Cherenkov detector simulation output. + +Loads PMT sensor data and/or 3D track segments from co-indexed HDF5 files. +Produces flat dicts compatible with pimm's transform/collation pipeline. 
@DATASETS.register_module()
class WCDataset(HEPDataset):
    """Water Cherenkov detector dataset.

    Parameters
    ----------
    data_root : str
        Root directory with seg/ and/or sensor/ subdirectories.
    split : str
        Split name for file discovery.
    transform : list[dict]
        Transform pipeline config.
    modalities : tuple[str]
        Which to load: 'seg', 'sensor'.
    dataset_name : str
        File prefix (e.g., 'wc' for 'wc_seg_0000.h5').
    output_mode : str
        How to format sensor data for the model:
        - 'response': PMT point cloud with total PE/T features
        - 'labels': sparse per-particle entries with instance/semantic labels
        - 'separate': keep raw reader keys (pmt_coord, pmt_pe, pp_* keys)
    include_labels : bool
        Whether sensor reader loads per-particle decomposition.
    pe_threshold : float
        Minimum PE for sparsifying PE_per_particle.
    min_segments : int
        Minimum segments per event (seg reader filter).
    max_len : int
        Cap on dataset length.
    loop : int
        Dataset repetition per epoch.
    test_mode : bool
        If True, ``__getitem__`` returns test fragments and ``loop`` is
        forced to 1.
    test_cfg : config or None
        Test-time config providing ``voxelize``, ``crop``, ``post_transform``
        and ``aug_transform``. Only used when ``test_mode`` is True.
    ignore_index : int
        Stored on the instance; not otherwise used by this class.

    Raises
    ------
    ValueError
        If ``modalities`` contains neither 'seg' nor 'sensor', or if the
        sensor modality is requested with an unknown ``output_mode``.
    """

    _OUTPUT_MODES = ('response', 'labels', 'separate')

    def __init__(
        self,
        data_root,
        split='',
        transform=None,
        modalities=('sensor',),
        dataset_name='wc',
        output_mode='response',
        include_labels=True,
        pe_threshold=0.0,
        min_segments=0,
        max_len=-1,
        loop=1,
        test_mode=False,
        test_cfg=None,
        ignore_index=-1,
    ):
        super().__init__()
        self.data_root = data_root
        self.split = split
        self.modalities = tuple(modalities)
        self.dataset_name = dataset_name
        self.output_mode = output_mode
        self.max_len = max_len
        self.loop = loop if not test_mode else 1
        self.ignore_index = ignore_index
        self.test_mode = test_mode
        self.test_cfg = test_cfg if test_mode else None

        # Readers are declared first so __del__ is safe even when a later
        # step of construction raises.
        self.seg_reader = None
        self.sensor_reader = None

        # An unknown mode would otherwise drop all sensor data silently.
        if 'sensor' in self.modalities and output_mode not in self._OUTPUT_MODES:
            raise ValueError(
                f"Unknown output_mode {output_mode!r}, "
                f"expected one of {self._OUTPUT_MODES}")

        self.transform = Compose(transform)

        # Identity defaults so prepare_test_data also works without a
        # test_cfg; Compose(None) is the same call used above when no
        # transform is given.
        self.test_voxelize = None
        self.test_crop = None
        self.post_transform = Compose(None)
        self.aug_transform = [Compose(None)]
        if test_mode and test_cfg is not None:
            self.test_voxelize = TRANSFORMS.build(self.test_cfg.voxelize)
            self.test_crop = (
                TRANSFORMS.build(self.test_cfg.crop)
                if self.test_cfg.crop else None)
            self.post_transform = Compose(self.test_cfg.post_transform)
            self.aug_transform = [
                Compose(aug) for aug in self.test_cfg.aug_transform]

        # Build readers
        if 'seg' in self.modalities:
            seg_root = self._modality_root('seg')
            self.seg_reader = WCSegReader(
                data_root=seg_root, split=split,
                dataset_name=dataset_name, min_segments=min_segments)

        if 'sensor' in self.modalities:
            sensor_root = self._modality_root('sensor')
            self.sensor_reader = WCSensorReader(
                data_root=sensor_root, split=split,
                dataset_name=dataset_name,
                include_labels=include_labels,
                pe_threshold=pe_threshold)

        # Canonical reader and length
        active_readers = [r for r in (self.seg_reader, self.sensor_reader)
                          if r is not None]
        if not active_readers:
            raise ValueError(f"Need 'seg' or 'sensor' in modalities, got {self.modalities}")
        self._canonical_reader = active_readers[0]
        self._n_events = min(len(r) for r in active_readers)

        logger = get_root_logger()
        logger.info(f"WCDataset: {self._n_events} events, "
                    f"modalities={self.modalities}, output_mode={output_mode}")

    def _modality_root(self, modality):
        """Return ``data_root/<modality>`` if it exists, else ``data_root``."""
        mod_dir = os.path.join(self.data_root, modality)
        if os.path.isdir(mod_dir):
            return mod_dir
        return self.data_root

    def _num_events(self):
        """Number of distinct events served, after ``max_len`` but before ``loop``."""
        n = self._n_events
        if self.max_len > 0:
            n = min(n, self.max_len)
        return n

    def get_data(self, idx):
        """Read event ``idx`` from the active readers into one flat dict.

        ``idx`` must be a real event index (below the reader length).
        """
        data_dict = {}

        # --- Seg (3D track segments) ---
        if self.seg_reader is not None:
            seg_data = self.seg_reader.read_event(idx)
            if self.sensor_reader is not None and self.output_mode == 'separate':
                # Prefix 3D keys to avoid collision with sensor coord
                for k, v in seg_data.items():
                    data_dict[f'seg3d.{k}'] = v
            else:
                # NOTE(review): with both modalities and a non-'separate'
                # mode, the sensor branch below overwrites 'coord', 'energy'
                # and 'time' while the remaining seg keys keep the segment
                # length. Confirm whether that combination is meant to work.
                data_dict.update(seg_data)

        # --- Sensor (PMT response + optional per-particle labels) ---
        if self.sensor_reader is not None:
            sensor_data = self.sensor_reader.read_event(idx)

            if self.output_mode == 'response':
                # PMT response — one entry per sensor
                n = len(sensor_data['pmt_pe'])
                if 'pmt_coord' in sensor_data:
                    data_dict['coord'] = sensor_data['pmt_coord']
                else:
                    # No 3D positions — use sensor index as 1D coord
                    data_dict['coord'] = np.arange(n, dtype=np.float32)[:, None]
                data_dict['energy'] = sensor_data['pmt_pe'][:, None]
                data_dict['time'] = sensor_data['pmt_t'][:, None]

            elif self.output_mode == 'labels':
                # Sparse per-particle entries
                if 'pp_sensor_idx' in sensor_data:
                    sidx = sensor_data['pp_sensor_idx']
                    if 'pmt_coord' in sensor_data:
                        data_dict['coord'] = sensor_data['pmt_coord'][sidx]
                    else:
                        data_dict['coord'] = sidx.astype(np.float32)[:, None]
                    data_dict['energy'] = sensor_data['pp_pe'][:, None]
                    data_dict['segment'] = sensor_data['pp_category']
                    data_dict['instance'] = sensor_data['pp_particle_idx']
                    if 'pp_t' in sensor_data:
                        data_dict['time'] = sensor_data['pp_t'][:, None]

            elif self.output_mode == 'separate':
                data_dict.update(sensor_data)

        # Metadata
        data_dict['name'] = self.get_data_name(idx)
        data_dict['split'] = self.split if isinstance(self.split, str) else 'custom'
        return data_dict

    def get_data_name(self, idx):
        """Return ``'<file basename>_evt<event number>'`` for event ``idx``."""
        reader = self._canonical_reader
        file_idx = int(np.searchsorted(reader.cumulative_lengths, idx, side='right'))
        local = idx - (int(reader.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0)
        event_num = reader.indices[file_idx][local]
        fname = os.path.basename(reader.h5_files[file_idx])
        return f"{fname}_evt{event_num:03d}"

    def prepare_train_data(self, idx):
        """Load an event and apply the training transform pipeline."""
        # Wrap by the event count, not len(self): len(self) includes the
        # ``loop`` factor and would index past the last event when loop > 1.
        return self.transform(self.get_data(idx % self._num_events()))

    def prepare_test_data(self, idx):
        """Load an event and build the test-time fragment list.

        Returns a dict with 'name', optionally 'segment', and
        'fragment_list' (one entry per augmentation / voxel part / crop).
        """
        data_dict = self.get_data(idx % self._num_events())
        if self.transform is not None:
            data_dict = self.transform(data_dict)
        result_dict = dict(name=data_dict.pop("name"))
        if "segment" in data_dict:
            result_dict["segment"] = data_dict.pop("segment")
        data_dict_list = []
        for aug in self.aug_transform:
            data_dict_list.append(aug(deepcopy(data_dict)))
        fragment_list = []
        for data in data_dict_list:
            if self.test_voxelize is not None:
                data_part_list = self.test_voxelize(data)
            else:
                data_part_list = [data]
            for data_part in data_part_list:
                if self.test_crop is not None:
                    data_part = self.test_crop(data_part)
                else:
                    data_part = [data_part]
                fragment_list += data_part
        for i in range(len(fragment_list)):
            fragment_list[i] = self.post_transform(fragment_list[i])
        result_dict["fragment_list"] = fragment_list
        return result_dict

    def __getitem__(self, idx):
        real_idx = idx % self._num_events()
        if self.test_mode:
            return self.prepare_test_data(real_idx)
        return self.prepare_train_data(real_idx)

    def __len__(self):
        return self._num_events() * self.loop

    def __del__(self):
        # getattr: __init__ may have raised before the readers were assigned.
        for name in ('seg_reader', 'sensor_reader'):
            reader = getattr(self, name, None)
            if reader is not None:
                reader.close()
"""
Verification script for LArTPCDataset — all modality combinations.

Run: /usr/bin/python3 tools/test_detector_dataset.py

The data location is taken from the JAXTPC_DATA_ROOT environment variable
when it is set. The script exits with status 1 if any check fails or any
test raises.
"""

import sys
import os
import traceback
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
import torch
import torch.nn as nn
from pimm.datasets.lartpc_dataset import LArTPCDataset
from pimm.datasets.utils import collate_fn
from pimm.datasets.transform import Compose

DATA_ROOT = os.environ.get(
    'JAXTPC_DATA_ROOT',
    '/home/oalterka/desktop_linux/JAXTPC/dataset_1')

MAX_LEN = 4
PASSED = 0
FAILED = 0


def check(condition, msg):
    """Record one assertion result in the global PASSED/FAILED counters."""
    global PASSED, FAILED
    if condition:
        print(f" OK: {msg}")
        PASSED += 1
    else:
        print(f" FAIL: {msg}")
        FAILED += 1


def run_test(test_fn):
    """Run one test function, counting an unexpected exception as a failure.

    Without this guard a single exception (missing key, unreadable file)
    would abort the script before the remaining tests and the summary run.
    """
    global FAILED
    try:
        test_fn()
    except Exception:
        traceback.print_exc()
        print(f" FAIL: {test_fn.__name__} raised an exception")
        FAILED += 1


def make_ds(**kwargs):
    """Build an LArTPCDataset with the script defaults, overridden by kwargs."""
    defaults = dict(data_root=DATA_ROOT, split='', dataset_name='sim', max_len=MAX_LEN)
    defaults.update(kwargs)
    return LArTPCDataset(**defaults)


def test_seg_only():
    """seg only — 3D point cloud, no labels."""
    print("\n=== seg only ===")
    ds = make_ds(modalities=('seg',))
    d = ds.get_data(0)
    check(d['coord'].shape[1] == 3, f"coord 3D: {d['coord'].shape}")
    check(d['energy'].shape[1] == 1, f"energy: {d['energy'].shape}")
    check('segment' not in d, "no segment without labl")


def test_seg_labl():
    """seg + labl — 3D with labels from lookup."""
    print("\n=== seg + labl ===")
    ds = make_ds(modalities=('seg', 'labl'), label_key='particle')
    d = ds.get_data(0)
    check(d['coord'].shape[1] == 3, f"coord 3D: {d['coord'].shape}")
    check('segment' in d, "segment present")
    check(d['segment'].shape[0] == d['coord'].shape[0], "segment matches coord")


def test_resp_only():
    """resp only — all planes merged into 2D point cloud, no labels."""
    print("\n=== resp only ===")
    ds = make_ds(modalities=('resp',))
    d = ds.get_data(0)
    check(d['coord'].shape[1] == 2, f"coord 2D: {d['coord'].shape}")
    check('plane_id' in d, "plane_id present")
    check('segment' not in d, "no segment")
    n_planes = len(np.unique(d['plane_id']))
    check(n_planes > 1, f"multiple planes: {n_planes}")


def test_resp_corr_labl():
    """resp + corr + labl — 2D labeled point cloud from corr chain."""
    print("\n=== resp + corr + labl ===")
    ds = make_ds(modalities=('resp', 'corr', 'labl'), label_key='particle')
    d = ds.get_data(0)
    check(d['coord'].shape[1] == 2, f"coord 2D: {d['coord'].shape}")
    check('segment' in d, "segment present")
    check('instance' in d, "instance present")
    check('plane_id' in d, "plane_id present")
    # Resp signal also available as namespaced keys
    resp_keys = [k for k in d if k.startswith('plane.')]
    check(len(resp_keys) > 0, f"resp namespaced keys: {len(resp_keys)}")
    # Overlapping instances
    _, counts = np.unique(d['coord'], axis=0, return_counts=True)
    check(np.sum(counts > 1) > 0, f"overlapping pixels: {np.sum(counts > 1)}")


def test_seg_resp_corr_labl():
    """All modalities — seg owns coord, resp/corr as separate point clouds."""
    print("\n=== seg + resp + corr + labl ===")
    ds = make_ds(modalities=('seg', 'resp', 'corr', 'labl'), label_key='particle')
    d = ds.get_data(0)
    check(d['coord'].shape[1] == 3, f"3D coord: {d['coord'].shape}")
    check('segment' in d, "3D segment from labl")
    # Resp as separate point cloud
    check('resp_coord' in d, f"resp_coord present: {d.get('resp_coord', 'MISSING')}")
    check(d['resp_coord'].shape[1] == 2, f"resp_coord 2D: {d['resp_coord'].shape}")
    # Corr as separate point cloud
    check('corr_coord' in d, "corr_coord present")
    check('corr_segment' in d, "corr_segment present")
    check('corr_instance' in d, "corr_instance present")
    # Raw plane keys also available
    plane_keys = [k for k in d if k.startswith('plane.')]
    check(len(plane_keys) > 0, f"raw plane keys: {len(plane_keys)}")


def test_resp_corr():
    """resp + corr (no labl) — resp merged, corr namespaced."""
    print("\n=== resp + corr (no labl) ===")
    ds = make_ds(modalities=('resp', 'corr'))
    d = ds.get_data(0)
    check(d['coord'].shape[1] == 2, f"coord 2D from resp: {d['coord'].shape}")
    check('segment' not in d, "no segment without labl")
    corr_keys = [k for k in d if k.startswith('corr.')]
    check(len(corr_keys) > 0, f"corr namespaced: {len(corr_keys)}")


def test_volume_filter():
    """volume=0 — only volume 0 data (fewer points than all volumes)."""
    print("\n=== volume filter ===")
    ds_all = make_ds(modalities=('resp',))
    ds_v0 = make_ds(modalities=('resp',), volume=0)
    d_all = ds_all.get_data(0)
    d_v0 = ds_v0.get_data(0)
    check(d_v0['coord'].shape[0] < d_all['coord'].shape[0],
          f"volume_0 ({d_v0['coord'].shape[0]}) < all ({d_all['coord'].shape[0]})")


def test_different_label_keys():
    """All label_key options."""
    print("\n=== different label_keys ===")
    for lk in ['particle', 'cluster', 'interaction']:
        ds = make_ds(modalities=('seg', 'labl'), label_key=lk)
        d = ds.get_data(0)
        n = len(np.unique(d['segment']))
        check(n > 1, f"label_key={lk}: {n} classes")


def test_pipeline_3d():
    """Full 3D pipeline: transforms → collate."""
    print("\n=== 3D pipeline ===")
    transform = [
        dict(type='NormalizeCoord', center=[0, 0, 0], scale=4000.0),
        dict(type='GridSample', grid_size=0.001, hash_type='fnv',
             mode='train', return_grid_coord=True),
        dict(type='ToTensor'),
        dict(type='Collect', keys=('coord', 'grid_coord', 'segment'),
             feat_keys=('coord', 'energy')),
    ]
    ds = make_ds(modalities=('seg', 'labl'), label_key='particle',
                 min_deposits=1024, transform=transform)
    batch = collate_fn([ds[0], ds[1]])
    check(batch['coord'].shape[1] == 3, f"3D: {batch['coord'].shape}")
    check(len(batch['offset']) == 2, "offset correct")


def test_pipeline_2d():
    """Full 2D pipeline: transforms → collate → DataLoader."""
    print("\n=== 2D pipeline ===")
    transform = [
        dict(type='GridSample', grid_size=1.0, hash_type='fnv',
             mode='train', return_grid_coord=True),
        dict(type='ToTensor'),
        dict(type='Collect', keys=('coord', 'grid_coord', 'segment', 'instance'),
             feat_keys=('coord', 'energy')),
    ]
    ds = make_ds(modalities=('resp', 'corr', 'labl'),
                 label_key='particle', transform=transform)
    batch = collate_fn([ds[0], ds[1]])
    check(batch['coord'].shape[1] == 2, f"2D: {batch['coord'].shape}")
    check(len(batch['offset']) == 2, "offset correct")

    # DataLoader
    loader = torch.utils.data.DataLoader(
        ds, batch_size=2, shuffle=False, num_workers=2,
        collate_fn=collate_fn, persistent_workers=False)
    for i, b in enumerate(loader):
        if i >= 1:
            break
    check(b['coord'].shape[1] == 2, f"DataLoader: {b['coord'].shape}")


def test_toy_model():
    """Toy model forward+backward."""
    print("\n=== toy model ===")
    transform = [
        dict(type='GridSample', grid_size=1.0, hash_type='fnv',
             mode='train', return_grid_coord=True),
        dict(type='ToTensor'),
        dict(type='Collect', keys=('coord', 'grid_coord', 'segment'),
             feat_keys=('coord', 'energy')),
    ]
    ds = make_ds(modalities=('resp', 'corr', 'labl'),
                 label_key='particle', transform=transform)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch = collate_fn([ds[0], ds[1]])
    for k in batch:
        if isinstance(batch[k], torch.Tensor):
            batch[k] = batch[k].to(device)
    model = nn.Linear(batch['feat'].shape[1], 5).to(device)
    logits = model(batch['feat'])
    loss = nn.CrossEntropyLoss(ignore_index=-1)(logits, batch['segment'].long())
    loss.backward()
    check(logits.shape[1] == 5, f"logits: {logits.shape}")
    check(model.weight.grad is not None, "gradients computed")


if __name__ == '__main__':
    print(f"Testing LArTPCDataset\nData root: {DATA_ROOT}")

    for test_fn in (
        test_seg_only,
        test_seg_labl,
        test_resp_only,
        test_resp_corr_labl,
        test_seg_resp_corr_labl,
        test_resp_corr,
        test_volume_filter,
        test_different_label_keys,
        test_pipeline_3d,
        test_pipeline_2d,
        test_toy_model,
    ):
        run_test(test_fn)

    print(f"\n{'='*50}")
    print(f"PASSED: {PASSED}, FAILED: {FAILED}")
    if FAILED > 0:
        sys.exit(1)
    print("ALL TESTS PASSED")
"""
Verification script for WCDataset — all output modes.

Run: /usr/bin/python3 tools/test_wc_dataset.py

The data location is taken from the WC_DATA_ROOT environment variable when
it is set. The script exits with status 1 if any check fails or any test
raises.
"""

import sys
import os
import traceback
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
import torch
import torch.nn as nn
from pimm.datasets.wc_dataset import WCDataset
from pimm.datasets.utils import collate_fn
from pimm.datasets.transform import Compose

# Overridable like the LArTPC script's JAXTPC_DATA_ROOT, so the script can
# run on machines other than the author's.
DATA_ROOT = os.environ.get(
    'WC_DATA_ROOT',
    '/home/oalterka/desktop_linux/JAXTPC/dataset_wc')
PASSED = 0
FAILED = 0


def check(condition, msg):
    """Record one assertion result in the global PASSED/FAILED counters."""
    global PASSED, FAILED
    if condition:
        print(f" OK: {msg}")
        PASSED += 1
    else:
        print(f" FAIL: {msg}")
        FAILED += 1


def run_test(test_fn):
    """Run one test function, counting an unexpected exception as a failure.

    Without this guard a single exception (missing key, unreadable file)
    would abort the script before the remaining tests and the summary run.
    """
    global FAILED
    try:
        test_fn()
    except Exception:
        traceback.print_exc()
        print(f" FAIL: {test_fn.__name__} raised an exception")
        FAILED += 1


def make_ds(**kwargs):
    """Build a WCDataset with the script defaults, overridden by kwargs."""
    defaults = dict(data_root=DATA_ROOT, split='', dataset_name='wc', max_len=4)
    defaults.update(kwargs)
    return WCDataset(**defaults)


def test_sensor_response():
    """Sensor response — one entry per sensor."""
    print("\n=== Sensor response ===")
    ds = make_ds(modalities=('sensor',), output_mode='response', include_labels=False)
    d = ds.get_data(0)
    check('coord' in d, "coord present")
    check('energy' in d, f"energy: {d['energy'].shape}")
    check('time' in d, "time present")
    check('segment' not in d, "no segment (SSL)")
    n = d['coord'].shape[0]
    check(n > 1000, f"many sensors: {n}")


def test_sensor_labels():
    """Sensor with labels — sparse per-particle entries."""
    print("\n=== Sensor labels ===")
    ds = make_ds(modalities=('sensor',), output_mode='labels', include_labels=True)
    d = ds.get_data(0)
    check('coord' in d, "coord present")
    check('segment' in d, "segment present")
    check('instance' in d, "instance present")
    n = d['coord'].shape[0]
    check(n > 0, f"sparse entries: {n}")
    n_inst = len(np.unique(d['instance']))
    check(n_inst > 1, f"multiple instances: {n_inst}")
    cats = np.unique(d['segment'])
    check(len(cats) >= 1, f"categories: {cats}")


def test_sensor_separate():
    """Sensor separate — raw reader keys."""
    print("\n=== Sensor separate ===")
    ds = make_ds(modalities=('sensor',), output_mode='separate')
    d = ds.get_data(0)
    check('pmt_pe' in d, "pmt_pe present")
    check('pmt_t' in d, "pmt_t present")
    check('pp_sensor_idx' in d, "pp_sensor_idx present")
    check('pp_category' in d, "pp_category present")
    check('coord' not in d, "no top-level coord")


def test_seg_only():
    """3D track segments."""
    print("\n=== Seg only ===")
    ds = make_ds(modalities=('seg',))
    d = ds.get_data(0)
    check(d['coord'].shape[1] == 3, f"coord 3D: {d['coord'].shape}")
    check(d['energy'].shape[1] == 1, f"energy: {d['energy'].shape}")
    check('track_ids' in d, "track_ids present")
    check('pdg' in d, "pdg present")


def test_mixed_separate():
    """Seg + sensor separate."""
    print("\n=== Mixed separate ===")
    ds = make_ds(modalities=('seg', 'sensor'), output_mode='separate')
    d = ds.get_data(0)
    seg_keys = [k for k in d if k.startswith('seg3d.')]
    check(len(seg_keys) > 0, f"seg3d keys: {len(seg_keys)}")
    check('pmt_pe' in d, "pmt_pe present")


def test_pipeline_response():
    """Pipeline: response → transforms → collate."""
    print("\n=== Pipeline response ===")
    transform = [
        dict(type='ToTensor'),
        dict(type='Collect', keys=('coord',), feat_keys=('coord', 'energy', 'time')),
    ]
    ds = make_ds(modalities=('sensor',), output_mode='response',
                 include_labels=False, transform=transform)
    s0 = ds[0]
    coord_dim = s0['coord'].shape[1]
    feat_dim = s0['feat'].shape[1]
    check(feat_dim == coord_dim + 2, f"feat={s0['feat'].shape} (coord_dim+2)")
    check('offset' in s0, "offset present")

    batch = collate_fn([ds[0], ds[1]])
    n0 = ds.get_data(0)['coord'].shape[0]
    check(batch['coord'].shape[0] > n0, f"batch: {batch['coord'].shape}")
    check(len(batch['offset']) == 2, f"offset: {batch['offset'].tolist()}")


def test_pipeline_labels():
    """Pipeline: labels → transforms → collate → toy model."""
    print("\n=== Pipeline labels ===")
    transform = [
        dict(type='ToTensor'),
        dict(type='Collect', keys=('coord', 'segment', 'instance'),
             feat_keys=('coord', 'energy')),
    ]
    ds = make_ds(modalities=('sensor',), output_mode='labels',
                 include_labels=True, transform=transform)

    batch = collate_fn([ds[0], ds[1]])
    check('segment' in batch, "segment in batch")
    check(len(batch['offset']) == 2, "offset correct")

    device = torch.device('cpu')
    for k in batch:
        if isinstance(batch[k], torch.Tensor):
            batch[k] = batch[k].to(device)
    n_classes = len(torch.unique(batch['segment']))
    model = nn.Linear(batch['feat'].shape[1], max(4, n_classes)).to(device)
    logits = model(batch['feat'])
    loss = nn.CrossEntropyLoss(ignore_index=-1)(logits, batch['segment'].long())
    loss.backward()
    check(model.weight.grad is not None, "gradients computed")


def test_pipeline_seg():
    """Pipeline: 3D segments → transforms → collate."""
    print("\n=== Pipeline seg ===")
    transform = [
        dict(type='ToTensor'),
        dict(type='Collect', keys=('coord',), feat_keys=('coord', 'energy')),
    ]
    ds = make_ds(modalities=('seg',), transform=transform)
    batch = collate_fn([ds[0], ds[1]])
    check(batch['coord'].shape[1] == 3, f"3D: {batch['coord'].shape}")
    check(len(batch['offset']) == 2, "offset correct")


def test_dataloader():
    """DataLoader with workers."""
    print("\n=== DataLoader ===")
    transform = [
        dict(type='ToTensor'),
        dict(type='Collect', keys=('coord',), feat_keys=('coord', 'energy')),
    ]
    ds = make_ds(modalities=('sensor',), output_mode='response',
                 include_labels=False, transform=transform)
    loader = torch.utils.data.DataLoader(
        ds, batch_size=2, shuffle=False, num_workers=2,
        collate_fn=collate_fn, persistent_workers=False)
    for i, batch in enumerate(loader):
        if i >= 1:
            break
    check(batch['coord'].shape[0] > 0, f"DataLoader batch: {batch['coord'].shape}")


if __name__ == '__main__':
    print(f"Testing WCDataset\nData root: {DATA_ROOT}")

    for test_fn in (
        test_sensor_response,
        test_sensor_labels,
        test_sensor_separate,
        test_seg_only,
        test_mixed_separate,
        test_pipeline_response,
        test_pipeline_labels,
        test_pipeline_seg,
        test_dataloader,
    ):
        run_test(test_fn)

    print(f"\n{'='*50}")
    print(f"PASSED: {PASSED}, FAILED: {FAILED}")
    if FAILED > 0:
        sys.exit(1)
    print("ALL TESTS PASSED")
{DATA_ROOT}") + + test_sensor_response() + test_sensor_labels() + test_sensor_separate() + test_seg_only() + test_mixed_separate() + test_pipeline_response() + test_pipeline_labels() + test_pipeline_seg() + test_dataloader() + + print(f"\n{'='*50}") + print(f"PASSED: {PASSED}, FAILED: {FAILED}") + if FAILED > 0: + sys.exit(1) + print("ALL TESTS PASSED") From b1d5e0c49e9e9bbf5919a3c2f4975651dbb5ec29 Mon Sep 17 00:00:00 2001 From: Sam Young Date: Tue, 14 Apr 2026 10:37:49 -0700 Subject: [PATCH 2/4] Delete pimm/datasets/readers/.gitignore --- pimm/datasets/readers/.gitignore | 1 - 1 file changed, 1 deletion(-) delete mode 100644 pimm/datasets/readers/.gitignore diff --git a/pimm/datasets/readers/.gitignore b/pimm/datasets/readers/.gitignore deleted file mode 100644 index c18dd8d..0000000 --- a/pimm/datasets/readers/.gitignore +++ /dev/null @@ -1 +0,0 @@ -__pycache__/ From 995126f44246bcaeeb467ee0d38530d2cd70d45b Mon Sep 17 00:00:00 2001 From: Omar Alterkait Date: Fri, 17 Apr 2026 22:46:49 +0900 Subject: [PATCH 3/4] Remove HEPDataset base class; dataset classes inherit torch Dataset directly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit HEPDataset was a 33-line fake-abstract class that added nothing: it claimed a dict-with-coord/energy contract that half the subclasses already violated, and inherited torch.utils.data.Dataset purely as a type decoration. The three dataset classes share ~15 LOC of trivial plumbing (__init__ storing transform/loop/max_len, __len__, __getitem__ dispatch). Extracting that into a base class costs more in abstraction overhead than it saves in duplication. The test-mode fragment_list logic is LArTPC-specific (WC's PMT arrays don't slide, PILArNet built but never used the result), so forcing it into a base imposed a contract that only one subclass actually uses. 
Changes: - Delete pimm/datasets/hepdataset.py - LArTPCDataset, WCDataset, PILArNetH5Dataset now inherit torch.utils.data.Dataset - Remove ignore_index kwarg from LArTPC/WC (dead code — never used by dataset, belongs in loss config). PILArNet unchanged (pre-existing code). - Each dataset class is now self-contained and readable top-to-bottom: open the file, see the whole data flow without jumping to a parent class. Real code reuse lives where it's justified: readers, transforms, utility functions. Dataset classes are ~300-450 LOC wrappers that orchestrate readers and apply transforms. No forced abstractions. 70 tests pass (38 LArTPC + 32 WC). --- docs/DETECTOR_DATASET.md | 41 +++++++++++++++++++++++++-------- pimm/datasets/hepdataset.py | 33 -------------------------- pimm/datasets/lartpc_dataset.py | 6 ++--- pimm/datasets/pilarnet.py | 3 +-- pimm/datasets/wc_dataset.py | 6 ++--- 5 files changed, 36 insertions(+), 53 deletions(-) delete mode 100644 pimm/datasets/hepdataset.py diff --git a/docs/DETECTOR_DATASET.md b/docs/DETECTOR_DATASET.md index 8b9545f..df87d7e 100644 --- a/docs/DETECTOR_DATASET.md +++ b/docs/DETECTOR_DATASET.md @@ -73,19 +73,30 @@ GridSample, ToTensor, Copy, Collect, RandomDropout, ShufflePoint, RandomJitter, For Water Cherenkov detectors (PMT-based). ### Data Layout + +Two HDF5 files per dataset; readers accept both naming conventions: + ``` dataset_root/ -├── seg/ wc_seg_0000.h5 — 3D track segments -└── sensor/ wc_sensor_0000.h5 — PMT response + per-particle labels +├── seg/ {dataset_name}_seg_NNNN.h5 or segment_events_NNNN.h5 +└── sensor/ {dataset_name}_sensor_NNNN.h5 or sensor_events_NNNN.h5 ``` +Format is flat CSR arrays (events indexed via `*_offsets` datasets), +matching the PhotonSim/LUCiD production output. + ### Task → Config +`coord` shape depends on whether PMT 3D positions are provided +(via `pmt_positions` / `pmt_positions_file` on WCSensorReader, or stored +in the file's `config/pmt_positions` dataset). 
Without positions, +`coord` falls back to `(N, 1)` with sensor indices. + | Task | `modalities` | `output_mode` | Output | |------|-------------|--------------|--------| -| Event classification | `('sensor',)` | `'response'` | `coord (N_pmt,3)`, `energy (N_pmt,1)`, `time (N_pmt,1)` | -| Per-sensor instance separation | `('sensor',)` | `'labels'` | `coord (E,3)`, `segment (E,)`, `instance (E,)` | -| 3D track reconstruction | `('seg',)` | any | `coord (N_seg,3)`, `energy (N_seg,1)`, `track_ids`, `pdg` | +| Event classification | `('sensor',)` | `'response'` | `coord (N_pmt, 3\|1)`, `energy (N_pmt,1)` [PE], `time (N_pmt,1)` [T] | +| Per-sensor instance separation | `('sensor',)` | `'labels'` | `coord (E, 3\|1)`, `energy (E,1)`, `segment (E,)`, `instance (E,)` | +| 3D track reconstruction | `('seg',)` | any | `coord (N_seg,3)`, `energy (N_seg,1)`, `time`, `track_ids`, `pdg`, `parent_ids` | | Joint 3D + sensor | `('seg', 'sensor')` | `'separate'` | `seg3d.*` + `pmt_*` + `pp_*` keys | ### Config Parameters @@ -97,6 +108,7 @@ data = dict(train=dict( modalities=("sensor",), output_mode="response", # 'response', 'labels', 'separate' include_labels=True, + pe_threshold=0.0, # optional: filter per-particle entries below this PE transform=[...], )) ``` @@ -105,15 +117,24 @@ data = dict(train=dict( ## Adding a New Detector -1. Write reader(s) in `pimm/datasets/readers/` — implement `__init__`, `_find_files`, `_build_index`, `h5py_worker_init`, `read_event(idx) → dict`, `__len__`, `close`. -2. Write a dataset class in `pimm/datasets/` — inherit `HEPDataset`, register via `@DATASETS.register_module()`. -3. The dataset's `get_data()` calls readers and maps output to `coord`/`energy`/`segment`. -4. Add imports in `pimm/datasets/__init__.py` and `pimm/datasets/readers/__init__.py`. +Each dataset class is self-contained — no base class. Copy the closest existing +dataset as a template and modify. + +1. **Write reader(s)** in `pimm/datasets/readers/`. 
Readers follow a lightweight + convention (not a forced ABC): `__init__` discovers files and builds an event + index; `h5py_worker_init` lazily opens HDF5 handles (DataLoader-fork safe); + `read_event(idx)` returns a `dict[str, np.ndarray]`; `__len__` returns the + event count; `close` releases handles. Copy an existing reader to start. +2. **Write a dataset class** in `pimm/datasets/`, inheriting + `torch.utils.data.Dataset` directly, registered via `@DATASETS.register_module()`. + Orchestrate readers in `get_data`; define `__init__`, `__len__`, `__getitem__`. +3. **Add imports** in `pimm/datasets/__init__.py` and + `pimm/datasets/readers/__init__.py`. No changes needed to transforms, collation, models, or training infrastructure. ## Running Tests ```bash /usr/bin/python3 tools/test_detector_dataset.py # LArTPC (38 tests) -/usr/bin/python3 tools/test_wc_dataset.py # Water Cherenkov (37 tests) +/usr/bin/python3 tools/test_wc_dataset.py # Water Cherenkov (32 tests) ``` diff --git a/pimm/datasets/hepdataset.py b/pimm/datasets/hepdataset.py deleted file mode 100644 index a0ec1d0..0000000 --- a/pimm/datasets/hepdataset.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Base HEP Dataset - -Minimal abstract class showing the required interface for HEP point cloud datasets. -""" - -from abc import ABC, abstractmethod -from torch.utils.data import Dataset - - -class HEPDataset(Dataset, ABC): - """ - Minimal interface for HEP point cloud datasets. - - Subclasses must implement: - - __getitem__(idx) -> dict with at least 'coord' (N,3) and 'energy' (N,1) - - __len__() -> int - """ - - @abstractmethod - def __len__(self) -> int: - raise NotImplementedError - - def __getitem__(self, idx: int) -> dict: - """ - Load a point cloud. 
- - Returns: - dict with at least: - - coord: (N, 3) float32 array of xyz coordinates - - energy: (N, 1) float32 array of energy deposits - """ - raise NotImplementedError \ No newline at end of file diff --git a/pimm/datasets/lartpc_dataset.py b/pimm/datasets/lartpc_dataset.py index 84a9083..f6b94a5 100644 --- a/pimm/datasets/lartpc_dataset.py +++ b/pimm/datasets/lartpc_dataset.py @@ -29,11 +29,11 @@ import os import numpy as np from copy import deepcopy +from torch.utils.data import Dataset from pimm.utils.logger import get_root_logger from .builder import DATASETS from .transform import Compose, TRANSFORMS -from .hepdataset import HEPDataset from .readers.lartpc_seg_reader import LArTPCSegReader from .readers.lartpc_resp_reader import LArTPCRespReader from .readers.lartpc_labl_reader import LArTPCLablReader @@ -41,7 +41,7 @@ @DATASETS.register_module() -class LArTPCDataset(HEPDataset): +class LArTPCDataset(Dataset): """Multimodal dataset for LArTPC detector simulation output. Parameters @@ -92,7 +92,6 @@ def __init__( label_keys=None, test_mode=False, test_cfg=None, - ignore_index=-1, ): super().__init__() self.data_root = data_root @@ -104,7 +103,6 @@ def __init__( self.min_deposits = min_deposits self.max_len = max_len self.loop = loop if not test_mode else 1 - self.ignore_index = ignore_index self.test_mode = test_mode self.test_cfg = test_cfg if test_mode else None diff --git a/pimm/datasets/pilarnet.py b/pimm/datasets/pilarnet.py index 5f471f7..e125b4c 100644 --- a/pimm/datasets/pilarnet.py +++ b/pimm/datasets/pilarnet.py @@ -15,14 +15,13 @@ from pimm.utils.logger import get_root_logger from .builder import DATASETS from .transform import Compose, TRANSFORMS -from .hepdataset import HEPDataset # priority for voxel deduplication: track (1) > shower (0) > michel (2) > delta (3) > led (4) DEFAULT_LABEL_PRIORITY = {1: 0, 0: 1, 2: 2, 3: 3, 4: 4} @DATASETS.register_module() -class PILArNetH5Dataset(HEPDataset): +class PILArNetH5Dataset(Dataset): """ PILArNet-M 
Dataset that loads directly from h5 files, avoiding the need for preprocessing to individual files. diff --git a/pimm/datasets/wc_dataset.py b/pimm/datasets/wc_dataset.py index bd258ef..2be6f0b 100644 --- a/pimm/datasets/wc_dataset.py +++ b/pimm/datasets/wc_dataset.py @@ -22,17 +22,17 @@ import os import numpy as np from copy import deepcopy +from torch.utils.data import Dataset from pimm.utils.logger import get_root_logger from .builder import DATASETS from .transform import Compose, TRANSFORMS -from .hepdataset import HEPDataset from .readers.wc_seg_reader import WCSegReader from .readers.wc_sensor_reader import WCSensorReader @DATASETS.register_module() -class WCDataset(HEPDataset): +class WCDataset(Dataset): """Water Cherenkov detector dataset. Parameters @@ -79,7 +79,6 @@ def __init__( loop=1, test_mode=False, test_cfg=None, - ignore_index=-1, ): super().__init__() self.data_root = data_root @@ -89,7 +88,6 @@ def __init__( self.output_mode = output_mode self.max_len = max_len self.loop = loop if not test_mode else 1 - self.ignore_index = ignore_index self.test_mode = test_mode self.test_cfg = test_cfg if test_mode else None From 715c84cb9eebfb7474837412ab674037caf5e149 Mon Sep 17 00:00:00 2001 From: Omar Alterkait Date: Sat, 18 Apr 2026 00:20:57 +0900 Subject: [PATCH 4/4] Rename dataset/reader classes to source-named (JAXTPC, LUCiD); move tests to /tests Datasets and readers are specific to the HDF5 schemas produced by their upstream production pipelines. Naming them by source rather than generic detector type is more honest and matches the existing PILArNetH5Dataset precedent. 
Renames: - LArTPCDataset -> JAXTPCDataset (lartpc_dataset.py -> jaxtpc_dataset.py) - WCDataset -> LUCiDDataset (wc_dataset.py -> lucid_dataset.py) - LArTPCSegReader -> JAXTPCSegReader (lartpc_seg_reader.py -> jaxtpc_seg_reader.py) - LArTPCRespReader -> JAXTPCRespReader (similar) - LArTPCLablReader -> JAXTPCLablReader (similar) - LArTPCCorrReader -> JAXTPCCorrReader (similar) - WCSegReader -> LUCiDSegReader (wc_seg_reader.py -> lucid_seg_reader.py) - WCSensorReader -> LUCiDSensorReader (wc_sensor_reader.py -> lucid_sensor_reader.py) File names use _dataset.py / _reader.py suffix so it's clear these are pimm integration modules, not the upstream projects themselves. Tests moved from tools/ to tests/ at repo root (standard Python layout): - tools/test_detector_dataset.py -> tests/test_jaxtpc_dataset.py - tools/test_wc_dataset.py -> tests/test_lucid_dataset.py Updates to all imports, configs, docs, and transform index_valid_keys comment. All 70 tests pass (38 JAXTPC + 32 LUCiD). --- configs/detector/_base_/jaxtpc_seg.py | 4 +-- docs/DETECTOR_DATASET.md | 14 +++++----- pimm/datasets/__init__.py | 4 +-- pimm/datasets/detector_transforms.py | 4 +-- .../{lartpc_dataset.py => jaxtpc_dataset.py} | 28 +++++++++---------- .../{wc_dataset.py => lucid_dataset.py} | 20 ++++++------- pimm/datasets/readers/__init__.py | 12 ++++---- ...c_corr_reader.py => jaxtpc_corr_reader.py} | 6 ++-- ...c_labl_reader.py => jaxtpc_labl_reader.py} | 6 ++-- ...c_resp_reader.py => jaxtpc_resp_reader.py} | 6 ++-- ...tpc_seg_reader.py => jaxtpc_seg_reader.py} | 8 +++--- .../{wc_seg_reader.py => lucid_seg_reader.py} | 6 ++-- ...ensor_reader.py => lucid_sensor_reader.py} | 6 ++-- pimm/datasets/transform.py | 2 +- .../test_jaxtpc_dataset.py | 10 +++---- .../test_lucid_dataset.py | 10 +++---- 16 files changed, 73 insertions(+), 73 deletions(-) rename pimm/datasets/{lartpc_dataset.py => jaxtpc_dataset.py} (96%) rename pimm/datasets/{wc_dataset.py => lucid_dataset.py} (93%) rename 
pimm/datasets/readers/{lartpc_corr_reader.py => jaxtpc_corr_reader.py} (97%) rename pimm/datasets/readers/{lartpc_labl_reader.py => jaxtpc_labl_reader.py} (96%) rename pimm/datasets/readers/{lartpc_resp_reader.py => jaxtpc_resp_reader.py} (97%) rename pimm/datasets/readers/{lartpc_seg_reader.py => jaxtpc_seg_reader.py} (97%) rename pimm/datasets/readers/{wc_seg_reader.py => lucid_seg_reader.py} (97%) rename pimm/datasets/readers/{wc_sensor_reader.py => lucid_sensor_reader.py} (97%) rename tools/test_detector_dataset.py => tests/test_jaxtpc_dataset.py (96%) rename tools/test_wc_dataset.py => tests/test_lucid_dataset.py (96%) diff --git a/configs/detector/_base_/jaxtpc_seg.py b/configs/detector/_base_/jaxtpc_seg.py index f8021a5..08fa9d4 100644 --- a/configs/detector/_base_/jaxtpc_seg.py +++ b/configs/detector/_base_/jaxtpc_seg.py @@ -65,7 +65,7 @@ ignore_index=-1, names=["shower", "track", "michel", "delta", "led"], train=dict( - type="LArTPCDataset", + type="JAXTPCDataset", data_root=_data_root, split="train", dataset_name="sim", @@ -75,7 +75,7 @@ max_len=-1, ), val=dict( - type="LArTPCDataset", + type="JAXTPCDataset", data_root=_data_root, split="val", dataset_name="sim", diff --git a/docs/DETECTOR_DATASET.md b/docs/DETECTOR_DATASET.md index df87d7e..c32af39 100644 --- a/docs/DETECTOR_DATASET.md +++ b/docs/DETECTOR_DATASET.md @@ -2,7 +2,7 @@ pimm supports multiple detector types through dedicated dataset classes. Each loads from co-indexed HDF5 files and produces flat dicts that flow through pimm's standard pipeline (transforms → Collect → collate → Point → model). -## LArTPCDataset +## JAXTPCDataset For Liquid Argon TPC detectors (JAXTPC production output). 
@@ -45,7 +45,7 @@ Raw per-plane keys (`plane.*`, `corr.*`) are always passed through for per-plane ### Config Parameters ```python data = dict(train=dict( - type="LArTPCDataset", + type="JAXTPCDataset", data_root="/path/to/dataset", split="", dataset_name="sim", @@ -68,7 +68,7 @@ GridSample, ToTensor, Copy, Collect, RandomDropout, ShufflePoint, RandomJitter, --- -## WCDataset +## LUCiDDataset For Water Cherenkov detectors (PMT-based). @@ -88,7 +88,7 @@ matching the PhotonSim/LUCiD production output. ### Task → Config `coord` shape depends on whether PMT 3D positions are provided -(via `pmt_positions` / `pmt_positions_file` on WCSensorReader, or stored +(via `pmt_positions` / `pmt_positions_file` on LUCiDSensorReader, or stored in the file's `config/pmt_positions` dataset). Without positions, `coord` falls back to `(N, 1)` with sensor indices. @@ -102,7 +102,7 @@ in the file's `config/pmt_positions` dataset). Without positions, ### Config Parameters ```python data = dict(train=dict( - type="WCDataset", + type="LUCiDDataset", data_root="/path/to/dataset_wc", dataset_name="wc", modalities=("sensor",), @@ -135,6 +135,6 @@ No changes needed to transforms, collation, models, or training infrastructure. ## Running Tests ```bash -/usr/bin/python3 tools/test_detector_dataset.py # LArTPC (38 tests) -/usr/bin/python3 tools/test_wc_dataset.py # Water Cherenkov (32 tests) +/usr/bin/python3 tests/test_jaxtpc_dataset.py # JAXTPC / LArTPC (38 tests) +/usr/bin/python3 tests/test_lucid_dataset.py # LUCiD / Water Cherenkov (32 tests) ``` diff --git a/pimm/datasets/__init__.py b/pimm/datasets/__init__.py index 78419e3..66fea89 100644 --- a/pimm/datasets/__init__.py +++ b/pimm/datasets/__init__.py @@ -5,8 +5,8 @@ # physics from .pilarnet import PILArNetH5Dataset -from .lartpc_dataset import LArTPCDataset -from .wc_dataset import WCDataset +from .jaxtpc_dataset import JAXTPCDataset +from .lucid_dataset import LUCiDDataset from . 
import detector_transforms # register PDGToSemantic # dataloader from .dataloader import MultiDatasetDataloader diff --git a/pimm/datasets/detector_transforms.py b/pimm/datasets/detector_transforms.py index d321b5b..4599c19 100644 --- a/pimm/datasets/detector_transforms.py +++ b/pimm/datasets/detector_transforms.py @@ -4,7 +4,7 @@ PDGToSemantic: derives semantic labels from PDG codes in seg data (3D tasks). Used when no label file is available (fallback). For production training, labels come from the labl file via - LArTPCDataset._apply_labl_to_3d() or _build_corr_pointcloud(). + JAXTPCDataset._apply_labl_to_3d() or _build_corr_pointcloud(). """ import numpy as np @@ -16,7 +16,7 @@ class PDGToSemantic: """Fallback: derive approximate semantic labels from PDG codes. Use this only when no labl file is available. For production training, - use modalities=('seg', 'labl') which applies labels via LArTPCLablReader. + use modalities=('seg', 'labl') which applies labels via JAXTPCLablReader. Schemes ------- diff --git a/pimm/datasets/lartpc_dataset.py b/pimm/datasets/jaxtpc_dataset.py similarity index 96% rename from pimm/datasets/lartpc_dataset.py rename to pimm/datasets/jaxtpc_dataset.py index f6b94a5..0473e33 100644 --- a/pimm/datasets/lartpc_dataset.py +++ b/pimm/datasets/jaxtpc_dataset.py @@ -1,5 +1,5 @@ """ -LArTPCDataset — multimodal dataset for LArTPC detector simulation output. +JAXTPCDataset — multimodal dataset for LArTPC detector simulation output. 
Loads from co-indexed HDF5 files produced by JAXTPC's production pipeline: seg (3D deposits), resp (2D wire signals), corr (3D→2D correspondence), @@ -14,15 +14,15 @@ Example configs:: # 3D segmentation - data = dict(train=dict(type="LArTPCDataset", + data = dict(train=dict(type="JAXTPCDataset", modalities=("seg", "labl"), label_key="particle", ...)) # 2D segmentation (all planes) - data = dict(train=dict(type="LArTPCDataset", + data = dict(train=dict(type="JAXTPCDataset", modalities=("resp", "corr", "labl"), label_key="particle", ...)) # Mixed 3D + 2D - data = dict(train=dict(type="LArTPCDataset", + data = dict(train=dict(type="JAXTPCDataset", modalities=("seg", "resp", "corr", "labl"), ...)) """ @@ -34,14 +34,14 @@ from pimm.utils.logger import get_root_logger from .builder import DATASETS from .transform import Compose, TRANSFORMS -from .readers.lartpc_seg_reader import LArTPCSegReader -from .readers.lartpc_resp_reader import LArTPCRespReader -from .readers.lartpc_labl_reader import LArTPCLablReader -from .readers.lartpc_corr_reader import LArTPCCorrReader +from .readers.jaxtpc_seg_reader import JAXTPCSegReader +from .readers.jaxtpc_resp_reader import JAXTPCRespReader +from .readers.jaxtpc_labl_reader import JAXTPCLablReader +from .readers.jaxtpc_corr_reader import JAXTPCCorrReader @DATASETS.register_module() -class LArTPCDataset(Dataset): +class JAXTPCDataset(Dataset): """Multimodal dataset for LArTPC detector simulation output. 
Parameters @@ -129,23 +129,23 @@ def __init__( f'volume_{volume}_Y'] if 'seg' in self.modalities: - self.seg_reader = LArTPCSegReader( + self.seg_reader = JAXTPCSegReader( data_root=self._modality_root('seg'), split=split, dataset_name=dataset_name, min_deposits=min_deposits, include_physics=include_physics, volume=volume) if 'resp' in self.modalities: - self.resp_reader = LArTPCRespReader( + self.resp_reader = JAXTPCRespReader( data_root=self._modality_root('resp'), split=split, dataset_name=dataset_name, planes=planes) if 'labl' in self.modalities: - self.labl_reader = LArTPCLablReader( + self.labl_reader = JAXTPCLablReader( data_root=self._modality_root('labl'), split=split, dataset_name=dataset_name, label_keys=label_keys) if 'corr' in self.modalities: - self.corr_reader = LArTPCCorrReader( + self.corr_reader = JAXTPCCorrReader( data_root=self._modality_root('corr'), split=split, dataset_name=dataset_name, planes=planes) @@ -171,7 +171,7 @@ def __init__( "Add 'corr' for 2D labels or 'seg' for 3D labels.") logger.info( - f"LArTPCDataset: {self._n_events} events, " + f"JAXTPCDataset: {self._n_events} events, " f"modalities={self.modalities}, " f"volume={volume}, split={split}") diff --git a/pimm/datasets/wc_dataset.py b/pimm/datasets/lucid_dataset.py similarity index 93% rename from pimm/datasets/wc_dataset.py rename to pimm/datasets/lucid_dataset.py index 2be6f0b..e92ea03 100644 --- a/pimm/datasets/wc_dataset.py +++ b/pimm/datasets/lucid_dataset.py @@ -1,5 +1,5 @@ """ -WCDataset — dataset for Water Cherenkov detector simulation output. +LUCiDDataset — dataset for Water Cherenkov detector simulation output. Loads PMT sensor data and/or 3D track segments from co-indexed HDF5 files. Produces flat dicts compatible with pimm's transform/collation pipeline. 
@@ -7,15 +7,15 @@ Example configs: # PMT event classification (sensor response as fixed-geometry point cloud) - data = dict(train=dict(type="WCDataset", data_root="dataset_wc", + data = dict(train=dict(type="LUCiDDataset", data_root="dataset_wc", modalities=("sensor",), dataset_name="wc", ...)) # Per-sensor instance separation (sparse per-particle entries) - data = dict(train=dict(type="WCDataset", data_root="dataset_wc", + data = dict(train=dict(type="LUCiDDataset", data_root="dataset_wc", modalities=("sensor",), include_labels=True, ...)) # 3D track reconstruction - data = dict(train=dict(type="WCDataset", data_root="dataset_wc", + data = dict(train=dict(type="LUCiDDataset", data_root="dataset_wc", modalities=("seg",), ...)) """ @@ -27,12 +27,12 @@ from pimm.utils.logger import get_root_logger from .builder import DATASETS from .transform import Compose, TRANSFORMS -from .readers.wc_seg_reader import WCSegReader -from .readers.wc_sensor_reader import WCSensorReader +from .readers.lucid_seg_reader import LUCiDSegReader +from .readers.lucid_sensor_reader import LUCiDSensorReader @DATASETS.register_module() -class WCDataset(Dataset): +class LUCiDDataset(Dataset): """Water Cherenkov detector dataset. 
Parameters @@ -107,13 +107,13 @@ def __init__( if 'seg' in self.modalities: seg_root = self._modality_root('seg') - self.seg_reader = WCSegReader( + self.seg_reader = LUCiDSegReader( data_root=seg_root, split=split, dataset_name=dataset_name, min_segments=min_segments) if 'sensor' in self.modalities: sensor_root = self._modality_root('sensor') - self.sensor_reader = WCSensorReader( + self.sensor_reader = LUCiDSensorReader( data_root=sensor_root, split=split, dataset_name=dataset_name, include_labels=include_labels, @@ -128,7 +128,7 @@ def __init__( self._n_events = min(len(r) for r in active_readers) logger = get_root_logger() - logger.info(f"WCDataset: {self._n_events} events, " + logger.info(f"LUCiDDataset: {self._n_events} events, " f"modalities={self.modalities}, output_mode={output_mode}") def _modality_root(self, modality): diff --git a/pimm/datasets/readers/__init__.py b/pimm/datasets/readers/__init__.py index 4227466..67d35a6 100644 --- a/pimm/datasets/readers/__init__.py +++ b/pimm/datasets/readers/__init__.py @@ -1,6 +1,6 @@ -from .lartpc_seg_reader import LArTPCSegReader -from .lartpc_resp_reader import LArTPCRespReader -from .lartpc_labl_reader import LArTPCLablReader -from .lartpc_corr_reader import LArTPCCorrReader -from .wc_seg_reader import WCSegReader -from .wc_sensor_reader import WCSensorReader +from .jaxtpc_seg_reader import JAXTPCSegReader +from .jaxtpc_resp_reader import JAXTPCRespReader +from .jaxtpc_labl_reader import JAXTPCLablReader +from .jaxtpc_corr_reader import JAXTPCCorrReader +from .lucid_seg_reader import LUCiDSegReader +from .lucid_sensor_reader import LUCiDSensorReader diff --git a/pimm/datasets/readers/lartpc_corr_reader.py b/pimm/datasets/readers/jaxtpc_corr_reader.py similarity index 97% rename from pimm/datasets/readers/lartpc_corr_reader.py rename to pimm/datasets/readers/jaxtpc_corr_reader.py index 1286189..6a282c6 100644 --- a/pimm/datasets/readers/lartpc_corr_reader.py +++ b/pimm/datasets/readers/jaxtpc_corr_reader.py @@ 
-1,5 +1,5 @@ """ -LArTPCCorrReader — reads 3D→2D correspondence from JAXTPC corr files. +JAXTPCCorrReader — reads 3D→2D correspondence from JAXTPC corr files. Decodes CSR-encoded per-plane correspondence into flat arrays: per pixel entry: (wire, time, group_id, charge). @@ -16,7 +16,7 @@ from pimm.utils.logger import get_root_logger -class LArTPCCorrReader: +class JAXTPCCorrReader: """Reads 3D→2D correspondence from JAXTPC corr HDF5 files. Parameters @@ -76,7 +76,7 @@ def _build_index(self): self.indices.append(index) self.cumulative_lengths = np.cumsum(self.cumulative_lengths) - log.info(f"LArTPCCorrReader: {self.cumulative_lengths[-1]} events " + log.info(f"JAXTPCCorrReader: {self.cumulative_lengths[-1]} events " f"from {len(self.h5_files)} files") def h5py_worker_init(self): diff --git a/pimm/datasets/readers/lartpc_labl_reader.py b/pimm/datasets/readers/jaxtpc_labl_reader.py similarity index 96% rename from pimm/datasets/readers/lartpc_labl_reader.py rename to pimm/datasets/readers/jaxtpc_labl_reader.py index 10ac08f..e16d5b7 100644 --- a/pimm/datasets/readers/lartpc_labl_reader.py +++ b/pimm/datasets/readers/jaxtpc_labl_reader.py @@ -1,5 +1,5 @@ """ -LArTPCLablReader — reads per-volume track_id → label lookup tables. +JAXTPCLablReader — reads per-volume track_id → label lookup tables. The labl file stores a mapping from track_id to labels (particle, cluster, interaction) per volume. This is used for: @@ -20,7 +20,7 @@ from pimm.utils.logger import get_root_logger -class LArTPCLablReader: +class JAXTPCLablReader: """Reads per-volume track_id → label lookup tables. 
Parameters @@ -80,7 +80,7 @@ def _build_index(self): self.indices.append(index) self.cumulative_lengths = np.cumsum(self.cumulative_lengths) - log.info(f"LArTPCLablReader: {self.cumulative_lengths[-1]} events " + log.info(f"JAXTPCLablReader: {self.cumulative_lengths[-1]} events " f"from {len(self.h5_files)} files") def h5py_worker_init(self): diff --git a/pimm/datasets/readers/lartpc_resp_reader.py b/pimm/datasets/readers/jaxtpc_resp_reader.py similarity index 97% rename from pimm/datasets/readers/lartpc_resp_reader.py rename to pimm/datasets/readers/jaxtpc_resp_reader.py index 33f6ea9..23afe7f 100644 --- a/pimm/datasets/readers/lartpc_resp_reader.py +++ b/pimm/datasets/readers/jaxtpc_resp_reader.py @@ -1,5 +1,5 @@ """ -LArTPCRespReader — reads sparse wire signals from JAXTPC resp files. +JAXTPCRespReader — reads sparse wire signals from JAXTPC resp files. Decodes delta-encoded (wire, time, value) triples per plane. Output keys are dot-namespaced: plane.{plane_label}.wire/time/value @@ -15,7 +15,7 @@ from pimm.utils.logger import get_root_logger -class LArTPCRespReader: +class JAXTPCRespReader: """Reads sparse wire signals from JAXTPC resp HDF5 files. Parameters @@ -80,7 +80,7 @@ def _build_index(self): self.indices.append(index) self.cumulative_lengths = np.cumsum(self.cumulative_lengths) - log.info(f"LArTPCRespReader: {self.cumulative_lengths[-1]} events " + log.info(f"JAXTPCRespReader: {self.cumulative_lengths[-1]} events " f"from {len(self.h5_files)} files") def h5py_worker_init(self): diff --git a/pimm/datasets/readers/lartpc_seg_reader.py b/pimm/datasets/readers/jaxtpc_seg_reader.py similarity index 97% rename from pimm/datasets/readers/lartpc_seg_reader.py rename to pimm/datasets/readers/jaxtpc_seg_reader.py index f5dbec4..f844696 100644 --- a/pimm/datasets/readers/lartpc_seg_reader.py +++ b/pimm/datasets/readers/jaxtpc_seg_reader.py @@ -1,8 +1,8 @@ """ -LArTPCSegReader — reads 3D truth deposits from JAXTPC seg files. 
+JAXTPCSegReader — reads 3D truth deposits from JAXTPC seg files. Produces raw geometry, physics, and IDs. Labels come from the labl file -via LArTPCLablReader + LArTPCDataset._apply_labl_to_3d(), or from +via JAXTPCLablReader + JAXTPCDataset._apply_labl_to_3d(), or from PDGToSemantic transform as a fallback. Output dict: @@ -19,7 +19,7 @@ from pimm.utils.logger import get_root_logger -class LArTPCSegReader: +class JAXTPCSegReader: """Reads 3D truth deposits from JAXTPC seg HDF5 files. Concatenates volumes into a single point cloud with a volume_id feature. @@ -109,7 +109,7 @@ def _build_index(self): self.indices.append(index) self.cumulative_lengths = np.cumsum(self.cumulative_lengths) - log.info(f"LArTPCSegReader: {self.cumulative_lengths[-1]} events " + log.info(f"JAXTPCSegReader: {self.cumulative_lengths[-1]} events " f"from {len(self.h5_files)} files " f"(min_deposits={self.min_deposits})") diff --git a/pimm/datasets/readers/wc_seg_reader.py b/pimm/datasets/readers/lucid_seg_reader.py similarity index 97% rename from pimm/datasets/readers/wc_seg_reader.py rename to pimm/datasets/readers/lucid_seg_reader.py index 0e762c3..aa526fc 100644 --- a/pimm/datasets/readers/wc_seg_reader.py +++ b/pimm/datasets/readers/lucid_seg_reader.py @@ -1,5 +1,5 @@ """ -WCSegReader — reads 3D track segments from Water Cherenkov segment files. +LUCiDSegReader — reads 3D track segments from Water Cherenkov segment files. Format: flat CSR arrays (no per-event groups). - track_event_offset (n_events+1,) — CSR into track arrays @@ -19,7 +19,7 @@ from pimm.utils.logger import get_root_logger -class WCSegReader: +class LUCiDSegReader: """Reads 3D track segments from WC segment HDF5 files. 
Parameters @@ -104,7 +104,7 @@ def _build_index(self): self.indices.append(index) self.cumulative_lengths = np.cumsum(self.cumulative_lengths) - log.info(f"WCSegReader: {self.cumulative_lengths[-1]} events " + log.info(f"LUCiDSegReader: {self.cumulative_lengths[-1]} events " f"from {len(self.h5_files)} files") def h5py_worker_init(self): diff --git a/pimm/datasets/readers/wc_sensor_reader.py b/pimm/datasets/readers/lucid_sensor_reader.py similarity index 97% rename from pimm/datasets/readers/wc_sensor_reader.py rename to pimm/datasets/readers/lucid_sensor_reader.py index aa308f0..18d767b 100644 --- a/pimm/datasets/readers/wc_sensor_reader.py +++ b/pimm/datasets/readers/lucid_sensor_reader.py @@ -1,5 +1,5 @@ """ -WCSensorReader — reads PMT sensor data from Water Cherenkov sensor files. +LUCiDSensorReader — reads PMT sensor data from Water Cherenkov sensor files. Format: flat CSR arrays (no per-event groups). - event_hit_offsets (n_events+1,) — CSR into hit arrays @@ -28,7 +28,7 @@ from pimm.utils.logger import get_root_logger -class WCSensorReader: +class LUCiDSensorReader: """Reads PMT sensor data from WC sensor HDF5 files. 
Parameters @@ -112,7 +112,7 @@ def _build_index(self): self.indices.append(index) self.cumulative_lengths = np.cumsum(self.cumulative_lengths) - log.info(f"WCSensorReader: {self.cumulative_lengths[-1]} events, " + log.info(f"LUCiDSensorReader: {self.cumulative_lengths[-1]} events, " f"{self._n_sensors} sensors from {len(self.h5_files)} files") def h5py_worker_init(self): diff --git a/pimm/datasets/transform.py b/pimm/datasets/transform.py index 3e63f42..08797d9 100644 --- a/pimm/datasets/transform.py +++ b/pimm/datasets/transform.py @@ -53,7 +53,7 @@ def index_operator(data_dict, index, duplicate=False): "instance_particle", "instance_interaction", "momentum", - # LArTPCDataset keys (JAXTPC) + # JAXTPCDataset keys (JAXTPC) "track_ids", "group_ids", "pdg", diff --git a/tools/test_detector_dataset.py b/tests/test_jaxtpc_dataset.py similarity index 96% rename from tools/test_detector_dataset.py rename to tests/test_jaxtpc_dataset.py index 45be779..72e3637 100644 --- a/tools/test_detector_dataset.py +++ b/tests/test_jaxtpc_dataset.py @@ -1,7 +1,7 @@ """ -Verification script for LArTPCDataset — all modality combinations. +Verification script for JAXTPCDataset — all modality combinations. 
-Run: /usr/bin/python3 tools/test_detector_dataset.py +Run: /usr/bin/python3 tests/test_jaxtpc_dataset.py """ import sys @@ -11,7 +11,7 @@ import numpy as np import torch import torch.nn as nn -from pimm.datasets.lartpc_dataset import LArTPCDataset +from pimm.datasets.jaxtpc_dataset import JAXTPCDataset from pimm.datasets.utils import collate_fn from pimm.datasets.transform import Compose @@ -37,7 +37,7 @@ def check(condition, msg): def make_ds(**kwargs): defaults = dict(data_root=DATA_ROOT, split='', dataset_name='sim', max_len=MAX_LEN) defaults.update(kwargs) - return LArTPCDataset(**defaults) + return JAXTPCDataset(**defaults) def test_seg_only(): @@ -210,7 +210,7 @@ def test_toy_model(): if __name__ == '__main__': - print(f"Testing LArTPCDataset\nData root: {DATA_ROOT}") + print(f"Testing JAXTPCDataset\nData root: {DATA_ROOT}") test_seg_only() test_seg_labl() diff --git a/tools/test_wc_dataset.py b/tests/test_lucid_dataset.py similarity index 96% rename from tools/test_wc_dataset.py rename to tests/test_lucid_dataset.py index b17092c..7c3c3bd 100644 --- a/tools/test_wc_dataset.py +++ b/tests/test_lucid_dataset.py @@ -1,7 +1,7 @@ """ -Verification script for WCDataset — all output modes. +Verification script for LUCiDDataset — all output modes. 
-Run: /usr/bin/python3 tools/test_wc_dataset.py +Run: /usr/bin/python3 tests/test_lucid_dataset.py """ import sys @@ -11,7 +11,7 @@ import numpy as np import torch import torch.nn as nn -from pimm.datasets.wc_dataset import WCDataset +from pimm.datasets.lucid_dataset import LUCiDDataset from pimm.datasets.utils import collate_fn from pimm.datasets.transform import Compose @@ -33,7 +33,7 @@ def check(condition, msg): def make_ds(**kwargs): defaults = dict(data_root=DATA_ROOT, split='', dataset_name='wc', max_len=4) defaults.update(kwargs) - return WCDataset(**defaults) + return LUCiDDataset(**defaults) def test_sensor_response(): @@ -178,7 +178,7 @@ def test_dataloader(): if __name__ == '__main__': - print(f"Testing WCDataset\nData root: {DATA_ROOT}") + print(f"Testing LUCiDDataset\nData root: {DATA_ROOT}") test_sensor_response() test_sensor_labels()