diff --git a/configs/detector/_base_/jaxtpc_seg.py b/configs/detector/_base_/jaxtpc_seg.py new file mode 100644 index 0000000..08fa9d4 --- /dev/null +++ b/configs/detector/_base_/jaxtpc_seg.py @@ -0,0 +1,87 @@ +# Base dataset config for JAXTPC 3D seg data. +# +# Set JAXTPC_DATA_ROOT environment variable or override data_root in child config. +# Expected directory layout: +# {data_root}/seg/{split}/{dataset_name}_seg_NNNN.h5 +# or: {data_root}/seg/{dataset_name}_seg_NNNN.h5 (flat, split ignored) + +import os + +_data_root = os.environ.get("JAXTPC_DATA_ROOT", "/path/to/jaxtpc/production") + +# Coordinate normalization center and scale. +# Default is for SBND-scale dual-TPC: x in [-2160, 2160], y/z in [-2160, 2160] mm. +_center = [0.0, 0.0, 0.0] +_scale = 2160.0 * 3 ** 0.5 # ~3741 mm — normalizes to roughly [-1, 1] + +grid_size = 0.001 # after normalization + +transform = [ + dict(type="PDGToSemantic", scheme="motif_5cls"), + dict(type="NormalizeCoord", center=_center, scale=_scale), + dict(type="LogTransform", min_val=0.01, max_val=20.0), + dict( + type="GridSample", + grid_size=grid_size, + hash_type="fnv", + mode="train", + return_grid_coord=True, + ), + dict(type="RandomRotate", angle=[-1, 1], axis="z", center=[0, 0, 0], p=0.8), + dict(type="RandomRotate", angle=[-1, 1], axis="x", center=[0, 0, 0], p=0.8), + dict(type="RandomRotate", angle=[-1, 1], axis="y", center=[0, 0, 0], p=0.8), + dict(type="RandomFlip", p=0.5), + dict(type="Copy", keys_dict={"segment_motif": "segment"}), + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "segment"), + feat_keys=("coord", "energy"), + ), +] + +test_transform = [ + dict(type="PDGToSemantic", scheme="motif_5cls"), + dict(type="NormalizeCoord", center=_center, scale=_scale), + dict(type="LogTransform", min_val=0.01, max_val=20.0), + dict( + type="GridSample", + grid_size=grid_size, + hash_type="fnv", + mode="train", + return_grid_coord=True, + ), + dict(type="Copy", keys_dict={"segment_motif": 
"segment"}), + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "segment"), + feat_keys=("coord", "energy"), + ), +] + +data = dict( + num_classes=5, + ignore_index=-1, + names=["shower", "track", "michel", "delta", "led"], + train=dict( + type="JAXTPCDataset", + data_root=_data_root, + split="train", + dataset_name="sim", + modalities=("seg",), + transform=transform, + min_deposits=1024, + max_len=-1, + ), + val=dict( + type="JAXTPCDataset", + data_root=_data_root, + split="val", + dataset_name="sim", + modalities=("seg",), + transform=test_transform, + min_deposits=1024, + max_len=1000, + ), +) diff --git a/configs/detector/semseg/semseg-pt-v3m2-jaxtpc-5cls.py b/configs/detector/semseg/semseg-pt-v3m2-jaxtpc-5cls.py new file mode 100644 index 0000000..6f1be17 --- /dev/null +++ b/configs/detector/semseg/semseg-pt-v3m2-jaxtpc-5cls.py @@ -0,0 +1,106 @@ +""" +PTv3 semantic segmentation on JAXTPC 3D data. + +Drop-in replacement for PILArNet semseg — same model, different data source. 
+""" + +_base_ = [ + "../../../configs/_base_/default_runtime.py", + "../_base_/jaxtpc_seg.py", +] + +# --- training --- +batch_size = 48 +num_worker = 24 +mix_prob = 0.0 +clip_grad = None +empty_cache = False +enable_amp = True +amp_dtype = "bfloat16" +matmul_precision = "high" +seed = 0 +evaluate = True + +use_wandb = True +wandb_project = "SemSeg-JAXTPC" + +class_weights = None # set from data statistics if needed + +# --- model --- +model = dict( + type="DefaultSegmentorV2", + num_classes=5, + backbone_out_channels=64, + backbone=dict( + type="PT-v3m2", + in_channels=4, # [xyz, energy] + order=("hilbert", "hilbert-trans", "z", "z-trans"), + stride=(2, 2, 2, 2), + enc_depths=(3, 3, 3, 9, 3), + enc_channels=(48, 96, 192, 384, 512), + enc_num_head=(3, 6, 12, 24, 32), + enc_patch_size=(256, 256, 256, 256, 256), + dec_depths=(2, 2, 2, 2), + dec_channels=(64, 96, 192, 384), + dec_num_head=(4, 6, 12, 24), + dec_patch_size=(256, 256, 256, 256), + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + layer_scale=0.0, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.3, + shuffle_orders=True, + pre_norm=True, + enable_rpe=False, + enable_flash=True, + upcast_attention=False, + upcast_softmax=False, + traceable=False, + mask_token=False, + enc_mode=False, + freeze_encoder=False, + ), + criteria=[ + dict(type="CrossEntropyLoss", loss_weight=1.0, ignore_index=-1), + dict( + type="LovaszLoss", + mode="multiclass", + loss_weight=1.0 / 20.0, + ignore_index=-1, + ), + ], + freeze_backbone=False, +) + +# --- scheduler --- +epoch = 20 +eval_epoch = 20 +base_lr = 0.0026 +optimizer = dict(type="AdamW", lr=base_lr, weight_decay=0.04) +param_dicts = None + +scheduler = dict( + type="OneCycleLR", + max_lr=[base_lr], + pct_start=0.05, + anneal_strategy="cos", + div_factor=10.0, + final_div_factor=1000.0, +) + +# --- hooks --- +hooks = [ + dict(type="CheckpointLoader"), + dict(type="WeightDecayExclusion", + exclude_bias_from_wd=True, exclude_norm_from_wd=True, + exclude_gamma_from_wd=True, 
exclude_token_from_wd=True, + exclude_ndim_1_from_wd=True), + dict(type="GradientNormLogger", log_frequency=10), + dict(type="IterationTimer", warmup_iter=2), + dict(type="InformationWriter"), + dict(type="SemSegEvaluator", every_n_steps=1000, write_cls_iou=True), + dict(type="CheckpointSaver", save_freq=None, evaluator_every_n_steps=1000), + dict(type="PreciseEvaluator", test_last=False), +] diff --git a/docs/DETECTOR_DATASET.md b/docs/DETECTOR_DATASET.md new file mode 100644 index 0000000..c32af39 --- /dev/null +++ b/docs/DETECTOR_DATASET.md @@ -0,0 +1,140 @@ +# Detector Datasets + +pimm supports multiple detector types through dedicated dataset classes. Each loads from co-indexed HDF5 files and produces flat dicts that flow through pimm's standard pipeline (transforms → Collect → collate → Point → model). + +## JAXTPCDataset + +For Liquid Argon TPC detectors (JAXTPC production output). + +### Data Layout +``` +dataset_root/ +├── seg/ sim_seg_0000.h5 — 3D truth deposits +├── resp/ sim_resp_0000.h5 — sparse wire signals per plane +├── corr/ sim_corr_0000.h5 — 3D→2D correspondence +└── labl/ sim_labl_0000.h5 — per-volume track_id→label lookup +``` + +### How coord is assigned + +Who owns `coord`/`energy` depends on which modalities are loaded: + +- **seg present** → coord is 3D (N,3) from deposits. Resp/corr available as `resp_coord`, `corr_coord`. +- **seg absent, corr+labl present** → coord is 2D (E,2) from corr entries with labels. +- **seg absent, resp present** → coord is 2D (M,2) from all planes merged. + +When both resp and corr are loaded, each gets its own point cloud: +- `resp_coord`, `resp_energy`, `resp_plane_id` — from resp signal +- `corr_coord`, `corr_energy`, `corr_segment`, `corr_instance`, `corr_plane_id` — from correspondence + +Raw per-plane keys (`plane.*`, `corr.*`) are always passed through for per-plane access. 
+ +### Task → Config + +| Task | `modalities` | What you get | +|------|-------------|-------------| +| 3D segmentation | `('seg', 'labl')` | `coord (N,3)`, `segment (N,)` | +| 3D seg (PDG fallback) | `('seg',)` | `coord (N,3)`, `pdg (N,)` — use PDGToSemantic | +| 2D segmentation | `('resp', 'corr', 'labl')` | `coord (E,2)` from corr + labels. `resp_coord` also available. | +| 2D self-supervised | `('resp',)` | `coord (M,2)` merged planes | +| Resp→corr denoising | `('resp', 'corr')` | `coord (M,2)` from resp. `corr.*` namespaced keys. | +| Everything | `('seg', 'resp', 'corr', 'labl')` | 3D `coord` + `resp_coord` + `corr_coord` + raw keys | + +**Note:** `modalities=('resp', 'labl')` without `'corr'` will NOT produce labels — labl provides track_id→label tables but resp pixels can't be mapped to track_ids without corr. + +### Config Parameters +```python +data = dict(train=dict( + type="JAXTPCDataset", + data_root="/path/to/dataset", + split="", + dataset_name="sim", + modalities=("seg", "labl"), + volume=None, # None=all, 0=volume_0 only + label_key="particle", # 'particle', 'cluster', 'interaction' + min_deposits=1024, + transform=[...], +)) +``` + +### Label Chain +- **3D**: `deposit.track_id → labl[track_id] → label` +- **2D**: `corr.group_id → g2t → track_id → labl[track_id] → label` + +### Transforms Safe for 2D +GridSample, ToTensor, Copy, Collect, RandomDropout, ShufflePoint, RandomJitter, RandomScale, RandomFlip, PositiveShift. + +**3D-only** (crash on 2D coords): RandomRotate, NormalizeCoord, SphereCrop. + +--- + +## LUCiDDataset + +For Water Cherenkov detectors (PMT-based). + +### Data Layout + +Two HDF5 files per dataset; readers accept both naming conventions: + +``` +dataset_root/ +├── seg/ {dataset_name}_seg_NNNN.h5 or segment_events_NNNN.h5 +└── sensor/ {dataset_name}_sensor_NNNN.h5 or sensor_events_NNNN.h5 +``` + +Format is flat CSR arrays (events indexed via `*_offsets` datasets), +matching the PhotonSim/LUCiD production output. 
+ +### Task → Config + +`coord` shape depends on whether PMT 3D positions are provided +(via `pmt_positions` / `pmt_positions_file` on LUCiDSensorReader, or stored +in the file's `config/pmt_positions` dataset). Without positions, +`coord` falls back to `(N, 1)` with sensor indices. + +| Task | `modalities` | `output_mode` | Output | +|------|-------------|--------------|--------| +| Event classification | `('sensor',)` | `'response'` | `coord (N_pmt, 3\|1)`, `energy (N_pmt,1)` [PE], `time (N_pmt,1)` [T] | +| Per-sensor instance separation | `('sensor',)` | `'labels'` | `coord (E, 3\|1)`, `energy (E,1)`, `segment (E,)`, `instance (E,)` | +| 3D track reconstruction | `('seg',)` | any | `coord (N_seg,3)`, `energy (N_seg,1)`, `time`, `track_ids`, `pdg`, `parent_ids` | +| Joint 3D + sensor | `('seg', 'sensor')` | `'separate'` | `seg3d.*` + `pmt_*` + `pp_*` keys | + +### Config Parameters +```python +data = dict(train=dict( + type="LUCiDDataset", + data_root="/path/to/dataset_wc", + dataset_name="wc", + modalities=("sensor",), + output_mode="response", # 'response', 'labels', 'separate' + include_labels=True, + pe_threshold=0.0, # optional: filter per-particle entries below this PE + transform=[...], +)) +``` + +--- + +## Adding a New Detector + +Each dataset class is self-contained — no base class. Copy the closest existing +dataset as a template and modify. + +1. **Write reader(s)** in `pimm/datasets/readers/`. Readers follow a lightweight + convention (not a forced ABC): `__init__` discovers files and builds an event + index; `h5py_worker_init` lazily opens HDF5 handles (DataLoader-fork safe); + `read_event(idx)` returns a `dict[str, np.ndarray]`; `__len__` returns the + event count; `close` releases handles. Copy an existing reader to start. +2. **Write a dataset class** in `pimm/datasets/`, inheriting + `torch.utils.data.Dataset` directly, registered via `@DATASETS.register_module()`. + Orchestrate readers in `get_data`; define `__init__`, `__len__`, `__getitem__`. 
@TRANSFORMS.register_module()
class PDGToSemantic:
    """Fallback: derive approximate semantic labels from PDG codes.

    Use this only when no labl file is available. For production training,
    use modalities=('seg', 'labl') which applies labels via JAXTPCLablReader.

    Schemes
    -------
    motif_5cls : shower(0), track(1), michel(2), delta(3), led(4)
    pid_6cls   : photon(0), electron(1), muon(2), pion(3), proton(4), other(5)
    custom     : user-provided {pdg_code: class_index} dict
    none       : disable the transform (data_dict is returned untouched)

    PDG codes missing from the active mapping receive the scheme's default
    class (4 for motif_5cls, 5 for pid_6cls, ``max(values) + 1`` for custom).
    The built-in motif table only maps to classes 0 and 1, so michel(2) and
    delta(3) are never produced by this fallback.

    Keys written by ``__call__`` (all shaped (N, 1), int32):
    ``segment_motif``; ``segment_pid`` (motif_5cls / pid_6cls only);
    ``instance_particle`` (if ``track_ids`` present and not already set);
    ``instance_interaction`` and ``segment_interaction`` (if
    ``interaction_ids`` present and ``instance_interaction`` not already set).
    """

    _MOTIF = {
        22: 0, 11: 0, -11: 0,       # shower
        13: 1, -13: 1,              # track (muon)
        211: 1, -211: 1,            # track (pion)
        2212: 1,                    # track (proton)
        321: 1, -321: 1,            # track (kaon)
    }

    _PID = {
        22: 0,                      # photon
        11: 1, -11: 1,              # electron
        13: 2, -13: 2,              # muon
        211: 3, -211: 3,            # pion
        2212: 4,                    # proton
    }

    def __init__(self, scheme='motif_5cls', custom_map=None):
        """Select the PDG→class table.

        Raises
        ------
        ValueError
            If ``scheme`` is unknown, or ``scheme='custom'`` is given without
            a non-empty ``custom_map``.
        """
        self.scheme = scheme
        if scheme == 'motif_5cls':
            self.mapping = self._MOTIF
            self.default = 4
        elif scheme == 'pid_6cls':
            self.mapping = self._PID
            self.default = 5
        elif scheme == 'custom':
            # Explicit check rather than `assert`: asserts vanish under -O and
            # an empty map would otherwise fail inside max() with no context.
            if not custom_map:
                raise ValueError(
                    "scheme='custom' requires a non-empty custom_map "
                    "of {pdg_code: class_index}")
            self.mapping = dict(custom_map)
            self.default = max(self.mapping.values()) + 1
        elif scheme == 'none':
            self.mapping = None
            self.default = -1
        else:
            raise ValueError(f"Unknown label scheme: {scheme}")

    @staticmethod
    def _map_codes(pdg, mapping, default):
        """Return an int32 (N,) array: mapping[pdg[i]] or ``default``."""
        labels = np.full(len(pdg), default, dtype=np.int32)
        for code, cls in mapping.items():
            labels[pdg == code] = cls
        return labels

    @staticmethod
    def _contiguous_ids(ids, n):
        """Remap non-negative ids to 0..K-1; negative ids become -1. Shape (n, 1)."""
        ids = np.asarray(ids).ravel()
        out = np.full(n, -1, dtype=np.int32)
        valid = ids >= 0
        if valid.any():
            _, inverse = np.unique(ids[valid], return_inverse=True)
            out[valid] = inverse
        return out[:, None]

    def __call__(self, data_dict):
        if self.mapping is None or 'pdg' not in data_dict:
            return data_dict

        # Skip if labels already loaded from label file
        if 'segment' in data_dict or 'segment_motif' in data_dict:
            return data_dict

        # Flatten so both (N,) and (N, 1) inputs work with the 1-D masks below.
        pdg = np.asarray(data_dict['pdg']).ravel()
        n = len(pdg)

        labels = self._map_codes(pdg, self.mapping, self.default)
        data_dict['segment_motif'] = labels[:, None]

        # Also produce PID labels
        if self.scheme == 'motif_5cls':
            data_dict['segment_pid'] = self._map_codes(pdg, self._PID, 5)[:, None]
        elif self.scheme == 'pid_6cls':
            data_dict['segment_pid'] = labels[:, None]

        # Derive instance from track_ids (simple contiguous remapping)
        if 'instance_particle' not in data_dict and 'track_ids' in data_dict:
            data_dict['instance_particle'] = self._contiguous_ids(
                data_dict['track_ids'], n)

        if ('instance_interaction' not in data_dict
                and 'interaction_ids' in data_dict):
            iids = np.asarray(data_dict['interaction_ids']).ravel()
            data_dict['instance_interaction'] = self._contiguous_ids(iids, n)
            # Kept inside this branch so `iids` is always bound when used.
            data_dict['segment_interaction'] = (
                iids[:, None] != -1).astype(np.int32)

        return data_dict
@DATASETS.register_module()
class JAXTPCDataset(Dataset):
    """Multimodal dataset for LArTPC detector simulation output.

    Parameters
    ----------
    data_root : str
        Root directory with seg/, resp/, corr/, labl/ subdirectories.
    split : str
        Split name for file discovery.
    transform : list[dict]
        Transform pipeline config.
    modalities : tuple[str]
        Which to load: 'seg', 'resp', 'labl', 'corr'.
    dataset_name : str
        File prefix (e.g., 'sim' for 'sim_seg_0000.h5').
    volume : int or None
        Load only this volume index. None = all volumes.
    label_key : str
        Which label to use as 'segment': 'particle', 'cluster', 'interaction'.
    min_deposits : int
        Minimum 3D deposits per event (seg reader filter).
    max_len : int
        Cap on dataset length (-1 = unlimited).
    loop : int
        Dataset repetition per epoch (forced to 1 in test mode).
    include_physics : bool
        Whether seg reader loads dx, theta, phi, charge, photons, etc.
    label_keys : list or None
        Which label datasets to load from labl files.
    test_mode : bool
        Enable test-time transforms.
    test_cfg : object
        Test config (voxelize, crop, post_transform, aug_transform).

    Raises
    ------
    ValueError
        If ``modalities`` selects no known modality.
    """

    def __init__(
        self,
        data_root,
        split='train',
        transform=None,
        modalities=('seg',),
        dataset_name='sim',
        volume=None,
        label_key='particle',
        min_deposits=0,
        max_len=-1,
        loop=1,
        include_physics=True,
        label_keys=None,
        test_mode=False,
        test_cfg=None,
    ):
        super().__init__()
        self.data_root = data_root
        self.split = split
        self.modalities = tuple(modalities)
        self.dataset_name = dataset_name
        self.volume = volume
        self.label_key = label_key
        self.min_deposits = min_deposits
        self.max_len = max_len
        self.loop = loop if not test_mode else 1
        self.test_mode = test_mode
        self.test_cfg = test_cfg if test_mode else None

        self.transform = Compose(transform)
        # Always define the test-time attributes so a missing test_cfg is
        # reported clearly in prepare_test_data instead of AttributeError.
        self.test_voxelize = None
        self.test_crop = None
        self.post_transform = None
        self.aug_transform = []
        if test_mode and test_cfg is not None:
            self.test_voxelize = TRANSFORMS.build(self.test_cfg.voxelize)
            self.test_crop = (
                TRANSFORMS.build(self.test_cfg.crop)
                if self.test_cfg.crop else None)
            self.post_transform = Compose(self.test_cfg.post_transform)
            self.aug_transform = [
                Compose(aug) for aug in self.test_cfg.aug_transform]

        # Build readers
        self.seg_reader = None
        self.resp_reader = None
        self.labl_reader = None
        self.corr_reader = None

        # Plane filter: if volume is set, only load that volume's planes
        planes = 'all'
        if volume is not None:
            planes = [f'volume_{volume}_U', f'volume_{volume}_V',
                      f'volume_{volume}_Y']

        if 'seg' in self.modalities:
            self.seg_reader = JAXTPCSegReader(
                data_root=self._modality_root('seg'), split=split,
                dataset_name=dataset_name, min_deposits=min_deposits,
                include_physics=include_physics, volume=volume)

        if 'resp' in self.modalities:
            self.resp_reader = JAXTPCRespReader(
                data_root=self._modality_root('resp'), split=split,
                dataset_name=dataset_name, planes=planes)

        if 'labl' in self.modalities:
            self.labl_reader = JAXTPCLablReader(
                data_root=self._modality_root('labl'), split=split,
                dataset_name=dataset_name, label_keys=label_keys)

        if 'corr' in self.modalities:
            self.corr_reader = JAXTPCCorrReader(
                data_root=self._modality_root('corr'), split=split,
                dataset_name=dataset_name, planes=planes)

        # Canonical reader and length
        active_readers = [r for r in (self.seg_reader, self.resp_reader,
                                      self.labl_reader, self.corr_reader)
                          if r is not None]
        if not active_readers:
            raise ValueError(f"Need at least one modality, got {self.modalities}")
        self._canonical_reader = (self.seg_reader or self.resp_reader
                                  or self.corr_reader or self.labl_reader)
        # Readers are co-indexed; the shortest one bounds the usable events.
        self._n_events = min(len(r) for r in active_readers)

        logger = get_root_logger()

        # Warn about modality combinations that won't produce labels
        if (self.resp_reader and self.labl_reader
                and not self.corr_reader and not self.seg_reader):
            logger.warning(
                "modalities=('resp','labl') without 'corr': labl provides "
                "track_id→label tables but resp pixels can't be mapped to "
                "track_ids without corr. No 'segment' will be produced. "
                "Add 'corr' for 2D labels or 'seg' for 3D labels.")

        logger.info(
            f"JAXTPCDataset: {self._n_events} events, "
            f"modalities={self.modalities}, "
            f"volume={volume}, split={split}")

    def _modality_root(self, modality):
        """Return ``data_root/<modality>`` if that directory exists, else ``data_root``."""
        mod_dir = os.path.join(self.data_root, modality)
        return mod_dir if os.path.isdir(mod_dir) else self.data_root

    def _num_samples(self):
        """Number of distinct events exposed (after max_len, before loop)."""
        n = self._n_events
        if self.max_len > 0:
            n = min(n, self.max_len)
        return n

    @staticmethod
    def _lookup_labels(table_ids, table_labels, query_ids):
        """Map ``query_ids`` to labels through a (track_id, label) table.

        Returns an int32 array shaped like ``query_ids`` with -1 wherever the
        id is absent from ``table_ids``. Handles empty tables/queries.
        """
        query_ids = np.asarray(query_ids)
        out = np.full(query_ids.shape, -1, dtype=np.int32)
        table_ids = np.asarray(table_ids)
        table_labels = np.asarray(table_labels)
        if table_ids.size == 0 or query_ids.size == 0:
            return out

        order = np.argsort(table_ids)
        sorted_ids = table_ids[order]
        sorted_labels = table_labels[order]
        pos = np.searchsorted(sorted_ids, query_ids)
        # searchsorted may return len(sorted_ids) for ids above the maximum.
        pos = np.clip(pos, 0, len(sorted_ids) - 1)
        hit = sorted_ids[pos] == query_ids
        out[hit] = sorted_labels[pos[hit]]
        return out

    def get_data(self, idx):
        """Load one event. Who owns coord depends on modalities:

        - seg present: coord = 3D deposits. Resp/corr as namespaced keys.
        - seg absent, corr+labl present: coord = 2D corr entries with labels.
        - seg absent, resp present (no corr): coord = 2D resp merged.
        """
        data_dict = {}

        # --- Seg (3D point cloud) → owns coord if present ---
        if self.seg_reader is not None:
            data_dict.update(self.seg_reader.read_event(idx))

        # --- Labl (track_id → label lookup) ---
        labl_data = {}
        if self.labl_reader is not None:
            labl_data = self.labl_reader.read_event(idx)

        # Apply labels to 3D seg data
        if self.seg_reader is not None and labl_data:
            self._apply_labl_to_3d(data_dict, labl_data)

        # --- Resp (2D wire planes) ---
        resp_data = {}
        if self.resp_reader is not None:
            resp_data = self.resp_reader.read_event(idx)

        # --- Corr (correspondence) ---
        corr_data = {}
        if self.corr_reader is not None:
            corr_data = self.corr_reader.read_event(idx)

        # --- Build point clouds for each spatial modality ---
        # Each gets its own prefixed keys. When only one spatial source
        # exists, its keys are also copied to the standard coord/energy
        # so the default pipeline (GridSample, Collect, etc.) works.
        has_seg = self.seg_reader is not None
        has_resp = bool(resp_data)
        has_corr = bool(corr_data)

        # Resp → resp_coord/resp_energy/resp_plane_id
        if has_resp:
            self._merge_resp_planes(data_dict, resp_data, prefix='resp_')
            # Also keep raw namespaced keys for per-plane access
            data_dict.update(resp_data)

        # Corr+labl → corr_coord/corr_energy/corr_segment/corr_instance
        if has_corr and labl_data:
            # NOTE(review): in this branch the raw corr.* keys are not copied
            # into data_dict; only the merged corr_* arrays are.
            self._build_corr_pointcloud(data_dict, corr_data, labl_data, prefix='corr_')
        elif has_corr:
            # corr without labl — keep as namespaced keys only
            data_dict.update(corr_data)

        # --- Set standard coord/energy from the primary spatial source ---
        # The builders return without writing keys when no plane is present,
        # so promotion is conditional on the merged arrays actually existing.
        if has_seg:
            # seg already set coord/energy
            pass
        elif has_corr and labl_data and 'corr_coord' in data_dict:
            # corr is primary (has labels)
            data_dict['coord'] = data_dict['corr_coord']
            data_dict['energy'] = data_dict['corr_energy']
            data_dict['segment'] = data_dict['corr_segment']
            data_dict['instance'] = data_dict['corr_instance']
            data_dict['plane_id'] = data_dict['corr_plane_id']
        elif has_resp and 'resp_coord' in data_dict:
            # resp is primary
            data_dict['coord'] = data_dict['resp_coord']
            data_dict['energy'] = data_dict['resp_energy']
            data_dict['plane_id'] = data_dict['resp_plane_id']

        # Pass through labl lookup tables (for downstream use)
        for k, v in labl_data.items():
            if k not in data_dict:
                data_dict[k] = v

        # Metadata
        data_dict['name'] = self.get_data_name(idx)
        data_dict['split'] = self.split if isinstance(self.split, str) else 'custom'
        return data_dict

    def _apply_labl_to_3d(self, data_dict, labl_data):
        """Map 3D deposits' track_ids to labels via labl lookup. Vectorized.

        Writes ``data_dict['segment']`` as an int32 (N,) array; deposits whose
        track_id is not found in the matching volume's table stay at -1.
        Does nothing if ``track_ids`` is absent.
        """
        track_ids = data_dict.get('track_ids')
        volume_ids = data_dict.get('volume_id')
        if track_ids is None:
            return

        track_ids = np.asarray(track_ids).ravel()
        n = len(track_ids)
        labels = np.full(n, -1, dtype=np.int32)

        # Keys look like 'labl_v<k>_track_ids'; collect the 'v<k>' tokens.
        vol_indices = sorted(set(
            k.split('_')[1] for k in labl_data
            if k.startswith('labl_v') and k.endswith('_track_ids')
        ))

        for vi in vol_indices:
            tids_key = f'labl_{vi}_track_ids'
            label_key = f'labl_{vi}_{self.label_key}'
            if tids_key not in labl_data or label_key not in labl_data:
                continue

            vol_num = int(vi[1:])
            if volume_ids is not None:
                sel = np.flatnonzero(np.asarray(volume_ids).ravel() == vol_num)
            else:
                sel = np.arange(n)

            found = self._lookup_labels(
                labl_data[tids_key], labl_data[label_key], track_ids[sel])
            # Write only actual matches: without volume_id every volume sees
            # all deposits, and a miss must not erase an earlier volume's hit.
            hit = found != -1
            labels[sel[hit]] = found[hit]

        data_dict['segment'] = labels

    def _merge_resp_planes(self, data_dict, resp_data, prefix=''):
        """Merge all planes into {prefix}coord (M,2), {prefix}energy (M,1), {prefix}plane_id (M,1).

        ``plane_id`` is the plane's position in the sorted plane-name list.
        Writes nothing if ``resp_data`` contains no ``*.wire`` key.
        """
        planes = sorted(set(
            k.split('.')[1] for k in resp_data if k.endswith('.wire')
        ))
        if not planes:
            # np.concatenate([]) would raise; leave data_dict untouched.
            return

        all_coord, all_energy, all_plane_id = [], [], []
        for i, plane in enumerate(planes):
            wire = resp_data[f'plane.{plane}.wire']
            time = resp_data[f'plane.{plane}.time']
            value = resp_data[f'plane.{plane}.value']
            n = len(wire)
            all_coord.append(np.stack([wire, time], axis=1).astype(np.float32))
            all_energy.append(value[:, None].astype(np.float32))
            all_plane_id.append(np.full((n, 1), i, dtype=np.int32))

        data_dict[f'{prefix}coord'] = np.concatenate(all_coord, axis=0)
        data_dict[f'{prefix}energy'] = np.concatenate(all_energy, axis=0)
        data_dict[f'{prefix}plane_id'] = np.concatenate(all_plane_id, axis=0)

    def _build_corr_pointcloud(self, data_dict, corr_data, labl_data, prefix=''):
        """Build 2D labeled point cloud from corr + labl.

        Each corr entry is a point: coord=(wire,time), feature=charge,
        instance=group_id, segment from g2t+labl chain.
        Overlapping instances at the same pixel are separate points.
        Writes nothing if ``corr_data`` contains no ``*.wire`` key.
        """
        planes = sorted(set(
            k.split('.')[1] for k in corr_data if k.endswith('.wire')
        ))

        all_coord, all_charge, all_gid, all_segment, all_plane_id = [], [], [], [], []

        for pi, plane in enumerate(planes):
            wire_key = f'corr.{plane}.wire'
            if wire_key not in corr_data:
                continue

            wire = corr_data[wire_key]
            time = corr_data[f'corr.{plane}.time']
            gid = np.asarray(corr_data[f'corr.{plane}.group_id'])
            charge = corr_data[f'corr.{plane}.charge']
            n = len(wire)

            all_coord.append(np.stack([wire, time], axis=1).astype(np.float32))
            all_charge.append(charge[:, None].astype(np.float32))
            all_gid.append(gid.astype(np.int32))
            all_plane_id.append(np.full((n, 1), pi, dtype=np.int32))

            # group_id → g2t → track_id → labl → label
            vol_idx = plane.split('_')[1]  # 'volume_0_U' → '0'
            g2t = corr_data.get(f'g2t_v{vol_idx}')

            labels = np.full(n, -1, dtype=np.int32)
            if g2t is not None:
                g2t = np.asarray(g2t)
                valid_gid = (gid >= 0) & (gid < len(g2t))
                # Index g2t only with valid ids: np.where evaluates both
                # branches, so g2t[gid] would raise on out-of-range ids.
                track_ids = np.full(n, -1, dtype=np.int64)
                track_ids[valid_gid] = g2t[gid[valid_gid]]

                tids_key = f'labl_v{vol_idx}_track_ids'
                lbl_key = f'labl_v{vol_idx}_{self.label_key}'
                if tids_key in labl_data and lbl_key in labl_data:
                    labels = self._lookup_labels(
                        labl_data[tids_key], labl_data[lbl_key], track_ids)

            all_segment.append(labels)

        if not all_coord:
            return

        data_dict[f'{prefix}coord'] = np.concatenate(all_coord, axis=0)
        data_dict[f'{prefix}energy'] = np.concatenate(all_charge, axis=0)
        data_dict[f'{prefix}instance'] = np.concatenate(all_gid, axis=0)
        data_dict[f'{prefix}segment'] = np.concatenate(all_segment, axis=0)
        data_dict[f'{prefix}plane_id'] = np.concatenate(all_plane_id, axis=0)

    def get_data_name(self, idx):
        """Return '<h5 basename>_evt<NNN>' for global event index ``idx``.

        Uses the canonical reader's ``cumulative_lengths``, ``indices`` and
        ``h5_files`` attributes to locate the file and in-file event number.
        """
        reader = self._canonical_reader
        file_idx = int(np.searchsorted(reader.cumulative_lengths, idx, side='right'))
        local = idx - (int(reader.cumulative_lengths[file_idx - 1])
                       if file_idx > 0 else 0)
        event_num = reader.indices[file_idx][local]
        fname = os.path.basename(reader.h5_files[file_idx])
        return f"{fname}_evt{event_num:03d}"

    def prepare_train_data(self, idx):
        """Load event ``idx`` (wrapped to the un-looped length) and transform it."""
        return self.transform(self.get_data(idx % self._num_samples()))

    def prepare_test_data(self, idx):
        """Build the test-time sample: labels plus a list of augmented fragments.

        Raises
        ------
        RuntimeError
            If the dataset was created with ``test_mode=True`` but no ``test_cfg``.
        """
        if self.test_cfg is None:
            raise RuntimeError(
                "JAXTPCDataset: test_mode=True requires test_cfg "
                "(voxelize, crop, post_transform, aug_transform)")

        data_dict = self.get_data(idx % self._num_samples())
        if self.transform is not None:
            data_dict = self.transform(data_dict)

        result_dict = dict(name=data_dict.pop("name"))
        if "segment" in data_dict:
            result_dict["segment"] = data_dict.pop("segment")
        if "origin_segment" in data_dict:
            assert "inverse" in data_dict
            result_dict["origin_segment"] = data_dict.pop("origin_segment")
            result_dict["inverse"] = data_dict.pop("inverse")

        # Each augmentation works on its own copy of the transformed event.
        data_dict_list = [aug(deepcopy(data_dict)) for aug in self.aug_transform]

        fragment_list = []
        for data in data_dict_list:
            if self.test_voxelize is not None:
                data_part_list = self.test_voxelize(data)
            else:
                data_part_list = [data]
            for data_part in data_part_list:
                if self.test_crop is not None:
                    data_part = self.test_crop(data_part)
                else:
                    data_part = [data_part]
                fragment_list += data_part

        result_dict["fragment_list"] = [
            self.post_transform(fragment) for fragment in fragment_list]
        return result_dict

    def __getitem__(self, idx):
        # len(self) includes the `loop` factor, so wrap by the un-looped
        # length; otherwise indices >= the event count reach the readers.
        real_idx = idx % self._num_samples()
        if self.test_mode:
            return self.prepare_test_data(real_idx)
        return self.prepare_train_data(real_idx)

    def __len__(self):
        return self._num_samples() * self.loop

    def __del__(self):
        # getattr guards against a partially constructed instance.
        for attr in ('seg_reader', 'resp_reader', 'labl_reader', 'corr_reader'):
            reader = getattr(self, attr, None)
            if reader is not None:
                reader.close()
+ dataset_name : str + File prefix (e.g., 'wc' for 'wc_seg_0000.h5'). + output_mode : str + How to format sensor data for the model: + - 'response': PMT point cloud with total PE/T features + - 'labels': sparse per-particle entries with instance/semantic labels + - 'separate': keep raw reader keys (pmt_coord, pmt_pe, pp_* keys) + include_labels : bool + Whether sensor reader loads per-particle decomposition. + pe_threshold : float + Minimum PE for sparsifying PE_per_particle. + min_segments : int + Minimum segments per event (seg reader filter). + max_len : int + Cap on dataset length. + loop : int + Dataset repetition per epoch. + """ + + def __init__( + self, + data_root, + split='', + transform=None, + modalities=('sensor',), + dataset_name='wc', + output_mode='response', + include_labels=True, + pe_threshold=0.0, + min_segments=0, + max_len=-1, + loop=1, + test_mode=False, + test_cfg=None, + ): + super().__init__() + self.data_root = data_root + self.split = split + self.modalities = tuple(modalities) + self.dataset_name = dataset_name + self.output_mode = output_mode + self.max_len = max_len + self.loop = loop if not test_mode else 1 + self.test_mode = test_mode + self.test_cfg = test_cfg if test_mode else None + + self.transform = Compose(transform) + if test_mode and test_cfg is not None: + self.test_voxelize = TRANSFORMS.build(self.test_cfg.voxelize) + self.test_crop = ( + TRANSFORMS.build(self.test_cfg.crop) + if self.test_cfg.crop else None) + self.post_transform = Compose(self.test_cfg.post_transform) + self.aug_transform = [ + Compose(aug) for aug in self.test_cfg.aug_transform] + + # Build readers + self.seg_reader = None + self.sensor_reader = None + + if 'seg' in self.modalities: + seg_root = self._modality_root('seg') + self.seg_reader = LUCiDSegReader( + data_root=seg_root, split=split, + dataset_name=dataset_name, min_segments=min_segments) + + if 'sensor' in self.modalities: + sensor_root = self._modality_root('sensor') + self.sensor_reader = 
LUCiDSensorReader( + data_root=sensor_root, split=split, + dataset_name=dataset_name, + include_labels=include_labels, + pe_threshold=pe_threshold) + + # Canonical reader and length + active_readers = [r for r in (self.seg_reader, self.sensor_reader) + if r is not None] + if not active_readers: + raise ValueError(f"Need 'seg' or 'sensor' in modalities, got {self.modalities}") + self._canonical_reader = active_readers[0] + self._n_events = min(len(r) for r in active_readers) + + logger = get_root_logger() + logger.info(f"LUCiDDataset: {self._n_events} events, " + f"modalities={self.modalities}, output_mode={output_mode}") + + def _modality_root(self, modality): + mod_dir = os.path.join(self.data_root, modality) + if os.path.isdir(mod_dir): + return mod_dir + return self.data_root + + def get_data(self, idx): + data_dict = {} + + # --- Seg (3D track segments) --- + if self.seg_reader is not None: + seg_data = self.seg_reader.read_event(idx) + if self.sensor_reader is not None and self.output_mode == 'separate': + # Prefix 3D keys to avoid collision with sensor coord + for k, v in seg_data.items(): + data_dict[f'seg3d.{k}'] = v + else: + data_dict.update(seg_data) + + # --- Sensor (PMT response + optional per-particle labels) --- + if self.sensor_reader is not None: + sensor_data = self.sensor_reader.read_event(idx) + + if self.output_mode == 'response': + # PMT response — one entry per sensor + n = len(sensor_data['pmt_pe']) + if 'pmt_coord' in sensor_data: + data_dict['coord'] = sensor_data['pmt_coord'] + else: + # No 3D positions — use sensor index as 1D coord + data_dict['coord'] = np.arange(n, dtype=np.float32)[:, None] + data_dict['energy'] = sensor_data['pmt_pe'][:, None] + data_dict['time'] = sensor_data['pmt_t'][:, None] + + elif self.output_mode == 'labels': + # Sparse per-particle entries + if 'pp_sensor_idx' in sensor_data: + sidx = sensor_data['pp_sensor_idx'] + if 'pmt_coord' in sensor_data: + data_dict['coord'] = sensor_data['pmt_coord'][sidx] + else: + 
data_dict['coord'] = sidx.astype(np.float32)[:, None] + data_dict['energy'] = sensor_data['pp_pe'][:, None] + data_dict['segment'] = sensor_data['pp_category'] + data_dict['instance'] = sensor_data['pp_particle_idx'] + if 'pp_t' in sensor_data: + data_dict['time'] = sensor_data['pp_t'][:, None] + + elif self.output_mode == 'separate': + data_dict.update(sensor_data) + + # Metadata + data_dict['name'] = self.get_data_name(idx) + data_dict['split'] = self.split if isinstance(self.split, str) else 'custom' + return data_dict + + def get_data_name(self, idx): + reader = self._canonical_reader + file_idx = int(np.searchsorted(reader.cumulative_lengths, idx, side='right')) + local = idx - (int(reader.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0) + event_num = reader.indices[file_idx][local] + fname = os.path.basename(reader.h5_files[file_idx]) + return f"{fname}_evt{event_num:03d}" + + def prepare_train_data(self, idx): + return self.transform(self.get_data(idx % len(self))) + + def prepare_test_data(self, idx): + data_dict = self.get_data(idx % len(self)) + if self.transform is not None: + data_dict = self.transform(data_dict) + result_dict = dict(name=data_dict.pop("name")) + if "segment" in data_dict: + result_dict["segment"] = data_dict.pop("segment") + data_dict_list = [] + for aug in self.aug_transform: + data_dict_list.append(aug(deepcopy(data_dict))) + fragment_list = [] + for data in data_dict_list: + if self.test_voxelize is not None: + data_part_list = self.test_voxelize(data) + else: + data_part_list = [data] + for data_part in data_part_list: + if self.test_crop is not None: + data_part = self.test_crop(data_part) + else: + data_part = [data_part] + fragment_list += data_part + for i in range(len(fragment_list)): + fragment_list[i] = self.post_transform(fragment_list[i]) + result_dict["fragment_list"] = fragment_list + return result_dict + + def __getitem__(self, idx): + real_idx = idx % len(self) + if self.test_mode: + return 
self.prepare_test_data(real_idx) + return self.prepare_train_data(real_idx) + + def __len__(self): + n = self._n_events + if self.max_len > 0: + n = min(n, self.max_len) + return n * self.loop + + def __del__(self): + for reader in (self.seg_reader, self.sensor_reader): + if reader is not None: + reader.close() diff --git a/pimm/datasets/pilarnet.py b/pimm/datasets/pilarnet.py index 5f471f7..e125b4c 100644 --- a/pimm/datasets/pilarnet.py +++ b/pimm/datasets/pilarnet.py @@ -15,14 +15,13 @@ from pimm.utils.logger import get_root_logger from .builder import DATASETS from .transform import Compose, TRANSFORMS -from .hepdataset import HEPDataset # priority for voxel deduplication: track (1) > shower (0) > michel (2) > delta (3) > led (4) DEFAULT_LABEL_PRIORITY = {1: 0, 0: 1, 2: 2, 3: 3, 4: 4} @DATASETS.register_module() -class PILArNetH5Dataset(HEPDataset): +class PILArNetH5Dataset(Dataset): """ PILArNet-M Dataset that loads directly from h5 files, avoiding the need for preprocessing to individual files. diff --git a/pimm/datasets/readers/__init__.py b/pimm/datasets/readers/__init__.py new file mode 100644 index 0000000..67d35a6 --- /dev/null +++ b/pimm/datasets/readers/__init__.py @@ -0,0 +1,6 @@ +from .jaxtpc_seg_reader import JAXTPCSegReader +from .jaxtpc_resp_reader import JAXTPCRespReader +from .jaxtpc_labl_reader import JAXTPCLablReader +from .jaxtpc_corr_reader import JAXTPCCorrReader +from .lucid_seg_reader import LUCiDSegReader +from .lucid_sensor_reader import LUCiDSensorReader diff --git a/pimm/datasets/readers/jaxtpc_corr_reader.py b/pimm/datasets/readers/jaxtpc_corr_reader.py new file mode 100644 index 0000000..6a282c6 --- /dev/null +++ b/pimm/datasets/readers/jaxtpc_corr_reader.py @@ -0,0 +1,190 @@ +""" +JAXTPCCorrReader — reads 3D→2D correspondence from JAXTPC corr files. + +Decodes CSR-encoded per-plane correspondence into flat arrays: +per pixel entry: (wire, time, group_id, charge). + +Also loads per-volume group_to_track lookup tables. 
class JAXTPCCorrReader:
    """Reads 3D→2D correspondence from JAXTPC corr HDF5 files.

    Parameters
    ----------
    data_root : str
        Directory containing corr shard files.
    split : str
        Split name; tried first as a subdirectory of ``data_root``, then
        ignored (flat layout) if nothing matches.
    dataset_name : str
        File prefix (e.g., 'sim' matches 'sim_corr_0000.h5').
    planes : str or list
        Which planes to load: 'all' or list like ['volume_0_U'].
    **kwargs
        Accepted and ignored.
    """

    def __init__(self, data_root, split='train', dataset_name='sim',
                 planes='all', **kwargs):
        self.data_root = data_root
        self.split = split
        self.dataset_name = dataset_name
        self.planes = planes

        self.h5_files = self._find_files()
        assert len(self.h5_files) > 0, (
            f"No corr files found for '{dataset_name}' in {data_root}/{split}")

        # File handles are opened lazily on first read (see h5py_worker_init).
        self._initted = False
        self._h5data = []

        self._build_index()

    def _find_files(self):
        """Return the sorted shard paths, preferring the split subdirectory."""
        pattern = os.path.join(
            self.data_root, self.split,
            f'{self.dataset_name}_corr_*.h5')
        files = sorted(glob.glob(pattern))
        if not files:
            pattern = os.path.join(
                self.data_root, f'{self.dataset_name}_corr_*.h5')
            files = sorted(glob.glob(pattern))
        return files

    def _build_index(self):
        """Count events per shard and build the cumulative index.

        A shard that cannot be opened or lacks the ``n_events`` attribute is
        logged and contributes zero events.
        """
        log = get_root_logger()
        self.cumulative_lengths = []
        self.indices = []

        for h5_path in self.h5_files:
            try:
                with h5py.File(h5_path, 'r', libver='latest', swmr=True) as f:
                    n_events = int(f['config'].attrs['n_events'])
                index = np.arange(n_events, dtype=np.int64)
            except Exception as e:
                log.warning(f"Error processing {h5_path}: {e}")
                index = np.array([], dtype=np.int64)

            self.cumulative_lengths.append(len(index))
            self.indices.append(index)

        self.cumulative_lengths = np.cumsum(self.cumulative_lengths)
        log.info(f"JAXTPCCorrReader: {len(self)} events "
                 f"from {len(self.h5_files)} files")

    def h5py_worker_init(self):
        """Open one read-only handle per shard."""
        self._h5data = [
            h5py.File(p, 'r', libver='latest', swmr=True)
            for p in self.h5_files
        ]
        self._initted = True

    def _locate_event(self, idx):
        """Map a global event index to ``(file_handle, event_key)``.

        Raises
        ------
        IndexError
            If ``idx`` is negative or not smaller than ``len(self)``.  A
            negative index would otherwise silently select an event counted
            from the end of the first shard.
        """
        total = len(self)
        if not 0 <= idx < total:
            raise IndexError(
                f"event index {idx} out of range for {total} events")
        file_idx = int(np.searchsorted(self.cumulative_lengths, idx, side='right'))
        first = int(self.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0
        event_num = int(self.indices[file_idx][idx - first])
        return self._h5data[file_idx], f'event_{event_num:03d}'

    @staticmethod
    def _decode_plane_vectorized(g):
        """Decode one plane's CSR correspondence — fully vectorized.

        ``g`` maps dataset names to array-likes supporting ``[:]``.  Each
        group contributes ``group_sizes[i]`` entries; group-level values are
        repeated per entry and combined with the per-entry deltas.

        Returns (wire, time, group_id, charge) arrays, all shape (E,):
        int32, int32, int32 and float32 respectively, for empty and
        non-empty planes alike.
        """
        group_ids = g['group_ids'][:]
        if len(group_ids) == 0:
            # Nothing to decode; skip reading the remaining datasets.
            return (np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.array([], dtype=np.float32))

        group_sizes = g['group_sizes'][:].astype(np.int32)
        center_wires = g['center_wires'][:]
        center_times = g['center_times'][:]
        peak_charges = g['peak_charges'][:]
        delta_wires = g['delta_wires'][:]
        delta_times = g['delta_times'][:]
        charges_u16 = g['charges_u16'][:]

        # Broadcast group-level arrays to per-entry arrays
        wires = (np.repeat(center_wires, group_sizes).astype(np.int32)
                 + delta_wires.astype(np.int32))
        times = (np.repeat(center_times, group_sizes).astype(np.int32)
                 + delta_times.astype(np.int32))
        gids = np.repeat(group_ids, group_sizes).astype(np.int32)
        # charges_u16 is a fraction of the group's peak charge in 1/65535 steps.
        charges = (np.repeat(peak_charges, group_sizes)
                   * charges_u16.astype(np.float32) / 65535.0)

        return wires, times, gids, charges.astype(np.float32)

    def read_event(self, idx):
        """Read one event's correspondence data.

        Returns dict with:
            corr.{vol_plane}.wire:     (E,) int32 — wire index per entry
            corr.{vol_plane}.time:     (E,) int32 — time index per entry
            corr.{vol_plane}.group_id: (E,) int32 — group ID per entry
            corr.{vol_plane}.charge:   (E,) float32 — charge per entry
            g2t_v{N}:                  (G,) int32 — group_to_track per volume

        Only groups named ``volume_*`` are visited; inside them, only
        sub-groups holding a ``group_ids`` dataset count as planes.
        """
        if not self._initted:
            self.h5py_worker_init()

        f, event_key = self._locate_event(idx)
        evt = f[event_key]

        data_dict = {}

        for vol_key in evt:
            vol = evt[vol_key]
            if not isinstance(vol, h5py.Group):
                continue
            if not vol_key.startswith('volume_'):
                continue

            vol_idx = vol_key.replace('volume_', '')

            # group_to_track lookup for this volume
            if 'group_to_track' in vol:
                data_dict[f'g2t_v{vol_idx}'] = vol['group_to_track'][:].astype(np.int32)

            # Per-plane correspondence
            for plane_key in vol:
                pg = vol[plane_key]
                if not isinstance(pg, h5py.Group) or 'group_ids' not in pg:
                    continue

                plane_label = f'volume_{vol_idx}_{plane_key}'
                if self.planes != 'all' and plane_label not in self.planes:
                    continue

                wires, times, gids, charges = self._decode_plane_vectorized(pg)

                prefix = f'corr.{plane_label}'
                data_dict[f'{prefix}.wire'] = wires
                data_dict[f'{prefix}.time'] = times
                data_dict[f'{prefix}.group_id'] = gids
                data_dict[f'{prefix}.charge'] = charges

        return data_dict

    def __len__(self):
        return int(self.cumulative_lengths[-1]) if len(self.cumulative_lengths) > 0 else 0

    def close(self):
        """Close all open handles; errors while closing are ignored."""
        if self._initted:
            for f in self._h5data:
                try:
                    f.close()
                except Exception:
                    pass
            self._h5data = []
            self._initted = False
class JAXTPCLablReader:
    """Reads per-volume track_id → label lookup tables.

    Parameters
    ----------
    data_root : str
        Directory containing labl shard files.
    split : str
        Split name; tried first as a subdirectory of ``data_root``.
    dataset_name : str
        File prefix (e.g., 'sim' matches 'sim_labl_0000.h5').
    label_keys : list of str or None
        Which label datasets to load (default: all available).
    """

    def __init__(self, data_root, split='train', dataset_name='sim',
                 label_keys=None):
        self.data_root = data_root
        self.split = split
        self.dataset_name = dataset_name
        self.label_keys = label_keys  # None = load all

        self.h5_files = self._find_files()
        assert len(self.h5_files) > 0, (
            f"No labl files found for '{dataset_name}' in {data_root}/{split}")

        # File handles are opened lazily on first read.
        self._initted = False
        self._h5data = []

        self._build_index()

    def _find_files(self):
        """Return the sorted shard paths, preferring the split subdirectory."""
        pattern = os.path.join(
            self.data_root, self.split,
            f'{self.dataset_name}_labl_*.h5')
        files = sorted(glob.glob(pattern))
        if not files:
            pattern = os.path.join(
                self.data_root, f'{self.dataset_name}_labl_*.h5')
            files = sorted(glob.glob(pattern))
        return files

    def _build_index(self):
        """Count events per shard; unreadable shards contribute zero events."""
        log = get_root_logger()
        self.cumulative_lengths = []
        self.indices = []

        for h5_path in self.h5_files:
            try:
                with h5py.File(h5_path, 'r', libver='latest', swmr=True) as f:
                    n_events = int(f['config'].attrs['n_events'])
                index = np.arange(n_events, dtype=np.int64)
            except Exception as e:
                log.warning(f"Error processing {h5_path}: {e}")
                index = np.array([], dtype=np.int64)

            self.cumulative_lengths.append(len(index))
            self.indices.append(index)

        self.cumulative_lengths = np.cumsum(self.cumulative_lengths)
        log.info(f"JAXTPCLablReader: {len(self)} events "
                 f"from {len(self.h5_files)} files")

    def h5py_worker_init(self):
        """Open one read-only handle per shard."""
        self._h5data = [
            h5py.File(p, 'r', libver='latest', swmr=True)
            for p in self.h5_files
        ]
        self._initted = True

    def _locate_event(self, idx):
        """Map a global event index to ``(file_handle, event_key)``.

        Raises IndexError for indices outside ``[0, len(self))``; a negative
        index would otherwise silently select the wrong event.
        """
        total = len(self)
        if not 0 <= idx < total:
            raise IndexError(
                f"event index {idx} out of range for {total} events")
        file_idx = int(np.searchsorted(self.cumulative_lengths, idx, side='right'))
        first = int(self.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0
        event_num = int(self.indices[file_idx][idx - first])
        return self._h5data[file_idx], f'event_{event_num:03d}'

    def read_event(self, idx):
        """Read one event's per-volume label lookup tables.

        Returns dict with keys like:
            labl_v0_track_ids:    (T,) int32
            labl_v0_particle:     (T,) int32
            labl_v0_cluster:      (T,) int32
            labl_v0_interaction:  (T,) int32

        Every dataset found next to ``track_ids`` is exported (subject to
        ``label_keys``); nested groups are skipped.
        """
        if not self._initted:
            self.h5py_worker_init()

        f, event_key = self._locate_event(idx)
        evt = f[event_key]

        data_dict = {}
        for vol_key in evt:
            vol = evt[vol_key]
            if not isinstance(vol, h5py.Group):
                continue
            if 'track_ids' not in vol:
                continue

            # Volume index from key (e.g., 'volume_0' → 0)
            vol_idx = vol_key.replace('volume_', '')
            prefix = f'labl_v{vol_idx}'

            data_dict[f'{prefix}_track_ids'] = vol['track_ids'][:].astype(np.int32)

            for lk in vol:
                if lk == 'track_ids':
                    continue
                if self.label_keys is not None and lk not in self.label_keys:
                    continue
                item = vol[lk]
                if not isinstance(item, h5py.Dataset):
                    # A sub-group cannot be sliced into a label array.
                    continue
                data_dict[f'{prefix}_{lk}'] = item[:].astype(np.int32)

        return data_dict

    def __len__(self):
        return int(self.cumulative_lengths[-1]) if len(self.cumulative_lengths) > 0 else 0

    def close(self):
        """Close all open handles; errors while closing are ignored."""
        if self._initted:
            for f in self._h5data:
                try:
                    f.close()
                except Exception:
                    pass
            self._h5data = []
            self._initted = False


class JAXTPCRespReader:
    """Reads sparse wire signals from JAXTPC resp HDF5 files.

    Parameters
    ----------
    data_root : str
        Directory containing resp shard files.
    split : str
        Split name; tried first as a subdirectory of ``data_root``.
    dataset_name : str
        File prefix (e.g., 'sim' matches 'sim_resp_0000.h5').
    planes : str or list
        Which planes to load: 'all' or list like ['east_U', 'east_V'].
    decode_digitization : bool
        If True, subtract pedestal from uint16 values.
    """

    def __init__(self, data_root, split='train', dataset_name='sim',
                 planes='all', decode_digitization=True):
        self.data_root = data_root
        self.split = split
        self.dataset_name = dataset_name
        self.planes = planes
        self.decode_digitization = decode_digitization

        self.h5_files = self._find_files()
        assert len(self.h5_files) > 0, (
            f"No resp files found for '{dataset_name}' in {data_root}/{split}")

        # File handles are opened lazily on first read.
        self._initted = False
        self._h5data = []

        self._build_index()

    def _find_files(self):
        """Discover resp shard files."""
        pattern = os.path.join(
            self.data_root, self.split,
            f'{self.dataset_name}_resp_*.h5')
        files = sorted(glob.glob(pattern))
        if not files:
            pattern = os.path.join(
                self.data_root, f'{self.dataset_name}_resp_*.h5')
            files = sorted(glob.glob(pattern))
        return files

    def _build_index(self):
        """Scan files, count events, build cumulative index."""
        log = get_root_logger()
        self.cumulative_lengths = []
        self.indices = []

        for h5_path in self.h5_files:
            try:
                with h5py.File(h5_path, 'r', libver='latest', swmr=True) as f:
                    n_events = int(f['config'].attrs['n_events'])
                index = np.arange(n_events, dtype=np.int64)
            except Exception as e:
                log.warning(f"Error processing {h5_path}: {e}")
                index = np.array([], dtype=np.int64)

            self.cumulative_lengths.append(len(index))
            self.indices.append(index)

        self.cumulative_lengths = np.cumsum(self.cumulative_lengths)
        log.info(f"JAXTPCRespReader: {len(self)} events "
                 f"from {len(self.h5_files)} files")

    def h5py_worker_init(self):
        """Lazily open file handles (called after DataLoader fork)."""
        self._h5data = [
            h5py.File(p, 'r', libver='latest', swmr=True)
            for p in self.h5_files
        ]
        self._initted = True

    def _locate_event(self, idx):
        """Map global index -> (file_handle, event_key).

        Raises IndexError for indices outside ``[0, len(self))``.
        """
        total = len(self)
        if not 0 <= idx < total:
            raise IndexError(
                f"event index {idx} out of range for {total} events")
        file_idx = int(np.searchsorted(self.cumulative_lengths, idx, side='right'))
        first = int(self.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0
        event_num = int(self.indices[file_idx][idx - first])
        return self._h5data[file_idx], f'event_{event_num:03d}'

    def _iter_planes(self, evt):
        """Yield (plane_label, h5py.Group) for each plane in an event.

        Handles both formats:
          - Old: planes directly under event (east_U, east_V, ...)
          - New: planes under volume_N/ subgroups

        A group counts as a plane when it holds a ``delta_wire`` dataset.
        """
        for key in evt:
            obj = evt[key]
            if not isinstance(obj, h5py.Group):
                continue
            if key.startswith('volume_'):
                # New format: volume_0/U, volume_0/V, ...
                for plane_key in obj:
                    pg = obj[plane_key]
                    if isinstance(pg, h5py.Group) and 'delta_wire' in pg:
                        yield f'{key}_{plane_key}', pg
            elif 'delta_wire' in obj:
                # Old format: east_U, east_V, ...
                yield key, obj

    def _decode_plane(self, g):
        """Decode one plane's delta-encoded sparse data.

        Wire and time indices are the running sums of the stored deltas,
        offset by the ``wire_start`` / ``time_start`` attributes.  Returns
        ``(wire, time, values)`` with values as float32.
        """
        wire_start = int(g.attrs['wire_start'])
        time_start = int(g.attrs['time_start'])

        wire = wire_start + np.cumsum(g['delta_wire'][:]).astype(np.int32)
        time = time_start + np.cumsum(g['delta_time'][:]).astype(np.int32)

        raw_values = g['values'][:]
        if self.decode_digitization and raw_values.dtype == np.uint16:
            # Digitized ADC counts: remove the stored pedestal (0 if absent).
            ped = int(g.attrs.get('pedestal', 0))
            values = raw_values.astype(np.float32) - ped
        else:
            values = raw_values.astype(np.float32)

        return wire, time, values

    def read_event(self, idx):
        """Read one event, return dict with plane-namespaced sparse arrays.

        Returns keys like:
            plane.east_U.wire:   (M,) int32
            plane.east_U.time:   (M,) int32
            plane.east_U.value:  (M,) float32
        """
        if not self._initted:
            self.h5py_worker_init()

        f, event_key = self._locate_event(idx)
        evt = f[event_key]

        data_dict = {}
        for plane_label, pg in self._iter_planes(evt):
            if self.planes != 'all' and plane_label not in self.planes:
                continue

            wire, time, values = self._decode_plane(pg)
            prefix = f'plane.{plane_label}'
            data_dict[f'{prefix}.wire'] = wire
            data_dict[f'{prefix}.time'] = time
            data_dict[f'{prefix}.value'] = values

        return data_dict

    def __len__(self):
        return int(self.cumulative_lengths[-1]) if len(self.cumulative_lengths) > 0 else 0

    def close(self):
        """Close all open handles; errors while closing are ignored."""
        if self._initted:
            for f in self._h5data:
                try:
                    f.close()
                except Exception:
                    pass
            self._h5data = []
            self._initted = False


class JAXTPCSegReader:
    """Reads 3D truth deposits from JAXTPC seg HDF5 files.

    Concatenates volumes into a single point cloud with a volume_id feature.
    No label computation — just raw data.

    Parameters
    ----------
    data_root : str
        Directory containing seg shard files.
    split : str
        Split name; tried first as a subdirectory of ``data_root``.
    dataset_name : str
        File prefix (e.g., 'sim' matches 'sim_seg_0000.h5').
    min_deposits : int
        Minimum deposits per event to include in index.
    include_physics : bool
        Whether to load dx, theta, phi, charge, photons, etc.
    volume : int or None
        Restrict multi-volume events to this volume index; None loads all
        volumes.  Not applied to the legacy flat layout.
    """

    # Keys returned as (N, 1) columns by _concat_volumes.
    _COLUMN_KEYS = frozenset((
        'energy', 'dx', 'theta', 'phi', 't0_us',
        'charge', 'photons', 'qs_fractions', 'volume_id'))
    # Optional ID arrays, filled with -1 when the file does not provide them.
    _ID_KEYS = ('pdg', 'interaction_ids', 'ancestor_track_ids')
    # Optional float arrays loaded when include_physics is set.
    _PHYSICS_KEYS = ('dx', 'theta', 'phi', 't0_us',
                     'charge', 'photons', 'qs_fractions')

    def __init__(self, data_root, split='train', dataset_name='sim',
                 min_deposits=0, include_physics=True, volume=None):
        self.data_root = data_root
        self.split = split
        self.dataset_name = dataset_name
        self.min_deposits = min_deposits
        self.include_physics = include_physics
        self.volume = volume  # None = all volumes, int = single volume

        self.h5_files = self._find_files()
        assert len(self.h5_files) > 0, (
            f"No seg files found for '{dataset_name}' in {data_root}/{split}")

        # File handles are opened lazily on first read.
        self._initted = False
        self._h5data = []

        self._build_index()

    def _find_files(self):
        """Discover seg shard files."""
        pattern = os.path.join(
            self.data_root, self.split,
            f'{self.dataset_name}_seg_*.h5')
        files = sorted(glob.glob(pattern))
        if not files:
            pattern = os.path.join(
                self.data_root, f'{self.dataset_name}_seg_*.h5')
            files = sorted(glob.glob(pattern))
        return files

    @staticmethod
    def _count_deposits(evt, n_volumes):
        """Return the deposit count of one event group.

        Multi-volume events sum the ``n_actual`` attribute of every present
        volume; the flat layout uses the length of ``positions``.
        """
        if n_volumes > 1:
            return sum(
                int(evt[f'volume_{v}'].attrs.get('n_actual', 0))
                for v in range(n_volumes)
                if f'volume_{v}' in evt)
        return evt['positions'].shape[0] if 'positions' in evt else 0

    def _build_index(self):
        """Scan files, count events, build cumulative index.

        With ``min_deposits > 0`` only events with at least that many
        deposits are indexed.  Unreadable shards contribute zero events.
        """
        log = get_root_logger()
        self.cumulative_lengths = []
        self.indices = []

        for h5_path in self.h5_files:
            try:
                with h5py.File(h5_path, 'r', libver='latest', swmr=True) as f:
                    n_events = int(f['config'].attrs['n_events'])
                    n_volumes = int(f['config'].attrs.get('n_volumes', 1))

                    if self.min_deposits > 0:
                        valid = []
                        for i in range(n_events):
                            evt_key = f'event_{i:03d}'
                            if evt_key not in f:
                                continue
                            total = self._count_deposits(f[evt_key], n_volumes)
                            if total >= self.min_deposits:
                                valid.append(i)
                        index = np.array(valid, dtype=np.int64)
                    else:
                        index = np.arange(n_events, dtype=np.int64)

            except Exception as e:
                log.warning(f"Error processing {h5_path}: {e}")
                index = np.array([], dtype=np.int64)

            self.cumulative_lengths.append(len(index))
            self.indices.append(index)

        self.cumulative_lengths = np.cumsum(self.cumulative_lengths)
        log.info(f"JAXTPCSegReader: {len(self)} events "
                 f"from {len(self.h5_files)} files "
                 f"(min_deposits={self.min_deposits})")

    def h5py_worker_init(self):
        """Lazily open file handles (called after DataLoader fork)."""
        self._h5data = [
            h5py.File(p, 'r', libver='latest', swmr=True)
            for p in self.h5_files
        ]
        self._initted = True

    def _locate_event(self, idx):
        """Map global index → (file_handle, event_key, n_volumes).

        Raises IndexError for indices outside ``[0, len(self))``.
        """
        total = len(self)
        if not 0 <= idx < total:
            raise IndexError(
                f"event index {idx} out of range for {total} events")
        file_idx = int(np.searchsorted(self.cumulative_lengths, idx, side='right'))
        first = int(self.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0
        event_num = int(self.indices[file_idx][idx - first])
        f = self._h5data[file_idx]
        n_volumes = int(f['config'].attrs.get('n_volumes', 1))
        return f, f'event_{event_num:03d}', n_volumes

    def read_event(self, idx):
        """Read one event, return flat dict of numpy arrays.

        No label computation — just raw geometry, physics, and IDs.
        Events without any deposits yield the dict from ``_empty_dict``.
        """
        if not self._initted:
            self.h5py_worker_init()

        f, event_key, n_volumes = self._locate_event(idx)
        evt = f[event_key]

        vol_arrays = []

        if n_volumes > 1:
            for v in range(n_volumes):
                if self.volume is not None and v != self.volume:
                    continue
                vk = f'volume_{v}'
                if vk not in evt:
                    continue
                vg = evt[vk]
                n = int(vg.attrs.get('n_actual', 0))
                if n == 0:
                    continue
                vol_arrays.append(self._read_volume(vg, n, v))
        else:
            # Legacy flat format
            if 'positions' in evt:
                n = evt['positions'].shape[0]
                vol_arrays.append(self._read_volume_flat(evt, n, 0))

        if not vol_arrays:
            return self._empty_dict()

        return self._concat_volumes(vol_arrays)

    def _read_arrays(self, group, n, vol_idx):
        """Read the first ``n`` deposits of ``group`` into a dict of arrays.

        Coordinates are ``positions * pos_step_mm + pos_origin_{x,y,z}``.
        Every dataset is cut to ``n`` rows so that it agrees in length with
        ``volume_id`` and the -1 fill arrays, which are sized from ``n``.
        NOTE(review): this assumes rows beyond ``n_actual``, if any, are
        padding - confirm against the file writer.
        """
        step = float(group.attrs['pos_step_mm'])
        origin = np.array([group.attrs['pos_origin_x'],
                           group.attrs['pos_origin_y'],
                           group.attrs['pos_origin_z']], dtype=np.float32)

        d = {
            'coord': group['positions'][:n].astype(np.float32) * step + origin,
            'energy': group['de'][:n].astype(np.float32),
            'volume_id': np.full(n, vol_idx, dtype=np.int32),
            'track_ids': group['track_ids'][:n].astype(np.int32),
            'group_ids': group['group_ids'][:n].astype(np.int32),
        }

        # Optional ID fields
        for key in self._ID_KEYS:
            if key in group:
                d[key] = group[key][:n].astype(np.int32)
            else:
                d[key] = np.full(n, -1, dtype=np.int32)

        # Optional physics
        if self.include_physics:
            for key in self._PHYSICS_KEYS:
                if key in group:
                    d[key] = group[key][:n].astype(np.float32)

        return d

    def _read_volume(self, vg, n, vol_idx):
        """Read arrays from a volume group."""
        return self._read_arrays(vg, n, vol_idx)

    def _read_volume_flat(self, evt, n, vol_idx):
        """Read from legacy flat event format (no volume subgroups)."""
        return self._read_arrays(evt, n, vol_idx)

    def _concat_volumes(self, vol_arrays):
        """Concatenate per-volume dicts into a single flat dict.

        Only keys present in every volume are kept: an optional array
        missing from one volume could not be aligned with ``coord`` after
        concatenation.  Feature keys and ``volume_id`` come back as (N, 1)
        columns, ``coord`` as (N, 3) and ID arrays as (N,).
        """
        shared_keys = [k for k in vol_arrays[0]
                       if all(k in v for v in vol_arrays)]
        data_dict = {}
        for k in shared_keys:
            combined = np.concatenate([v[k] for v in vol_arrays], axis=0)
            if k in self._COLUMN_KEYS:
                combined = combined[:, None]
            data_dict[k] = combined
        return data_dict

    def _empty_dict(self):
        """Minimal valid dict for empty events."""
        return {
            'coord': np.zeros((0, 3), dtype=np.float32),
            'energy': np.zeros((0, 1), dtype=np.float32),
            'volume_id': np.zeros((0, 1), dtype=np.int32),
            'track_ids': np.zeros((0,), dtype=np.int32),
            'group_ids': np.zeros((0,), dtype=np.int32),
            'pdg': np.zeros((0,), dtype=np.int32),
            'interaction_ids': np.zeros((0,), dtype=np.int32),
            'ancestor_track_ids': np.zeros((0,), dtype=np.int32),
        }

    def __len__(self):
        return int(self.cumulative_lengths[-1]) if len(self.cumulative_lengths) > 0 else 0

    def close(self):
        """Close open file handles; errors while closing are ignored."""
        if self._initted:
            for f in self._h5data:
                try:
                    f.close()
                except Exception:
                    pass
            self._h5data = []
            self._initted = False
Tries multiple naming patterns.""" + for pattern in [ + os.path.join(self.data_root, self.split, f'{self.dataset_name}_seg_*.h5'), + os.path.join(self.data_root, f'{self.dataset_name}_seg_*.h5'), + os.path.join(self.data_root, self.split, 'segment_events_*.h5'), + os.path.join(self.data_root, 'segment_events_*.h5'), + os.path.join(self.data_root, self.split, '*segment*.h5'), + os.path.join(self.data_root, '*segment*.h5'), + ]: + files = sorted(glob.glob(pattern)) + if files: + return files + return [] + + def _build_index(self): + log = get_root_logger() + self.cumulative_lengths = [] + self.indices = [] + self._file_n_events = [] + + for h5_path in self.h5_files: + try: + with h5py.File(h5_path, 'r', libver='latest', swmr=True) as f: + n_events = int(f.attrs.get('n_events', 0)) + if n_events == 0 and 'event_number' in f: + n_events = f['event_number'].shape[0] + + if self.min_segments > 0 and 'segment_offset' in f and 'track_event_offset' in f: + track_offsets = f['track_event_offset'][:] + seg_offsets = f['segment_offset'][:] + valid = [] + for i in range(n_events): + t0 = track_offsets[i] + t1 = track_offsets[i + 1] + if t1 > t0: + n_seg = int(seg_offsets[t1] - seg_offsets[t0]) + else: + n_seg = 0 + if n_seg >= self.min_segments: + valid.append(i) + index = np.array(valid, dtype=np.int64) + else: + index = np.arange(n_events, dtype=np.int64) + self._file_n_events.append(n_events) + except Exception as e: + log.warning(f"Error processing {h5_path}: {e}") + index = np.array([], dtype=np.int64) + self._file_n_events.append(0) + + self.cumulative_lengths.append(len(index)) + self.indices.append(index) + + self.cumulative_lengths = np.cumsum(self.cumulative_lengths) + log.info(f"LUCiDSegReader: {self.cumulative_lengths[-1]} events " + f"from {len(self.h5_files)} files") + + def h5py_worker_init(self): + self._h5data = [ + h5py.File(p, 'r', libver='latest', swmr=True) + for p in self.h5_files + ] + self._initted = True + + def _locate_event(self, idx): + file_idx = 
int(np.searchsorted(self.cumulative_lengths, idx, side='right')) + local_idx = idx - (int(self.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0) + event_num = self.indices[file_idx][local_idx] + return self._h5data[file_idx], event_num + + def read_event(self, idx): + if not self._initted: + self.h5py_worker_init() + + f, event_num = self._locate_event(idx) + + # Track range for this event + track_offsets = f['track_event_offset'] + t0 = int(track_offsets[event_num]) + t1 = int(track_offsets[event_num + 1]) + n_tracks = t1 - t0 + + if n_tracks == 0: + return self._empty_dict() + + # Segment range for these tracks + seg_offsets = f['segment_offset'] + s0 = int(seg_offsets[t0]) + s1 = int(seg_offsets[t1]) + n_seg = s1 - s0 + + if n_seg == 0: + return self._empty_dict() + + # Segment data + sx = f['start_x'][s0:s1] + sy = f['start_y'][s0:s1] + sz = f['start_z'][s0:s1] + ex = f['end_x'][s0:s1] + ey = f['end_y'][s0:s1] + ez = f['end_z'][s0:s1] + + mid_x = (sx + ex) / 2 + mid_y = (sy + ey) / 2 + mid_z = (sz + ez) / 2 + + # Per-track data, expanded to per-segment + track_ids = f['track_id'][t0:t1].astype(np.int32) + pdg = f['pdg'][t0:t1].astype(np.int32) + parent_ids = f['parent_id'][t0:t1].astype(np.int32) + + # Number of segments per track (from segment_offset) + n_segs_per_track = np.diff(seg_offsets[t0:t1 + 1]).astype(np.int32) + + seg_track_ids = np.repeat(track_ids, n_segs_per_track) + seg_pdg = np.repeat(pdg, n_segs_per_track) + seg_parent_ids = np.repeat(parent_ids, n_segs_per_track) + + return { + 'coord': np.stack([mid_x, mid_y, mid_z], axis=1).astype(np.float32), + 'energy': f['edep'][s0:s1].astype(np.float32)[:, None], + 'time': f['time'][s0:s1].astype(np.float32)[:, None], + 'track_ids': seg_track_ids, + 'pdg': seg_pdg, + 'parent_ids': seg_parent_ids, + } + + def _empty_dict(self): + return { + 'coord': np.zeros((0, 3), dtype=np.float32), + 'energy': np.zeros((0, 1), dtype=np.float32), + 'time': np.zeros((0, 1), dtype=np.float32), + 'track_ids': 
np.zeros((0,), dtype=np.int32), + 'pdg': np.zeros((0,), dtype=np.int32), + 'parent_ids': np.zeros((0,), dtype=np.int32), + } + + def __len__(self): + return int(self.cumulative_lengths[-1]) if len(self.cumulative_lengths) > 0 else 0 + + def close(self): + if self._initted: + for fh in self._h5data: + try: + fh.close() + except Exception: + pass + self._h5data = [] + self._initted = False diff --git a/pimm/datasets/readers/lucid_sensor_reader.py b/pimm/datasets/readers/lucid_sensor_reader.py new file mode 100644 index 0000000..18d767b --- /dev/null +++ b/pimm/datasets/readers/lucid_sensor_reader.py @@ -0,0 +1,223 @@ +""" +LUCiDSensorReader — reads PMT sensor data from Water Cherenkov sensor files. + +Format: flat CSR arrays (no per-event groups). + - event_hit_offsets (n_events+1,) — CSR into hit arrays + - event_hit_sensor_idx, event_hit_PE, event_hit_T (total_hits,) — per-hit + - particle_event_offset (n_events+1,) — CSR into particle arrays + - particle_hit_offsets (n_particles+1,) — CSR into per-particle hits + - particle_hit_sensor_idx, particle_hit_PE, particle_hit_T (total_pp_hits,) + - particle_category (n_particles,) — semantic labels + +PMT 3D positions are NOT stored in the file — must be provided separately +(via pmt_positions_file or pmt_positions array). + +Output (sensor response): + pmt_pe (N_sensors,), pmt_t (N_sensors,) + If pmt_positions provided: pmt_coord (N_sensors, 3) + +Output (per-particle sparse, when include_labels=True): + pp_sensor_idx (E,), pp_particle_idx (E,), pp_pe (E,), + pp_t (E,), pp_category (E,) +""" + +import os +import glob +import numpy as np +import h5py +from pimm.utils.logger import get_root_logger + + +class LUCiDSensorReader: + """Reads PMT sensor data from WC sensor HDF5 files. + + Parameters + ---------- + data_root : str + Directory containing sensor shard files. + split : str + Split name. + dataset_name : str + File prefix. + include_labels : bool + Whether to load per-particle hit decomposition. 
+ pe_threshold : float + Minimum PE for per-particle hits (already sparse in file, this + applies additional filtering). + pmt_positions : ndarray or None + (N_sensors, 3) PMT positions. If None, coord won't be produced + (sensor_idx is still available). + pmt_positions_file : str or None + Path to .npy file with PMT positions. Alternative to pmt_positions. + """ + + def __init__(self, data_root, split='', dataset_name='wc', + include_labels=True, pe_threshold=0.0, + pmt_positions=None, pmt_positions_file=None, **kwargs): + self.data_root = data_root + self.split = split + self.dataset_name = dataset_name + self.include_labels = include_labels + self.pe_threshold = pe_threshold + + # PMT positions (optional) + if pmt_positions is not None: + self._pmt_positions = np.asarray(pmt_positions, dtype=np.float32) + elif pmt_positions_file is not None: + self._pmt_positions = np.load(pmt_positions_file).astype(np.float32) + else: + self._pmt_positions = None + + self.h5_files = self._find_files() + assert len(self.h5_files) > 0, ( + f"No WC sensor files found in {data_root}/{split}") + + self._initted = False + self._h5data = [] + self._n_sensors = None + self._build_index() + + def _find_files(self): + for pattern in [ + os.path.join(self.data_root, self.split, f'{self.dataset_name}_sensor_*.h5'), + os.path.join(self.data_root, f'{self.dataset_name}_sensor_*.h5'), + os.path.join(self.data_root, self.split, 'sensor_events_*.h5'), + os.path.join(self.data_root, 'sensor_events_*.h5'), + os.path.join(self.data_root, self.split, '*sensor*.h5'), + os.path.join(self.data_root, '*sensor*.h5'), + ]: + files = sorted(glob.glob(pattern)) + if files: + return files + return [] + + def _build_index(self): + log = get_root_logger() + self.cumulative_lengths = [] + self.indices = [] + + for h5_path in self.h5_files: + try: + with h5py.File(h5_path, 'r', libver='latest', swmr=True) as f: + n_events = int(f.attrs.get('n_events', 0)) + if n_events == 0 and 'event_number' in f: + n_events = 
f['event_number'].shape[0] + self._n_sensors = int(f.attrs.get('n_sensors', 0)) + index = np.arange(n_events, dtype=np.int64) + except Exception as e: + log.warning(f"Error processing {h5_path}: {e}") + index = np.array([], dtype=np.int64) + + self.cumulative_lengths.append(len(index)) + self.indices.append(index) + + self.cumulative_lengths = np.cumsum(self.cumulative_lengths) + log.info(f"LUCiDSensorReader: {self.cumulative_lengths[-1]} events, " + f"{self._n_sensors} sensors from {len(self.h5_files)} files") + + def h5py_worker_init(self): + self._h5data = [ + h5py.File(p, 'r', libver='latest', swmr=True) + for p in self.h5_files + ] + # Try to load PMT positions from file config if not provided + if self._pmt_positions is None: + f = self._h5data[0] + if 'config' in f and 'pmt_positions' in f['config']: + self._pmt_positions = f['config']['pmt_positions'][:].astype(np.float32) + self._initted = True + + def _locate_event(self, idx): + file_idx = int(np.searchsorted(self.cumulative_lengths, idx, side='right')) + local_idx = idx - (int(self.cumulative_lengths[file_idx - 1]) if file_idx > 0 else 0) + event_num = self.indices[file_idx][local_idx] + return self._h5data[file_idx], event_num + + def read_event(self, idx): + if not self._initted: + self.h5py_worker_init() + + f, event_num = self._locate_event(idx) + data_dict = {} + + # --- Event-level sensor hits --- + hit_offsets = f['event_hit_offsets'] + h0 = int(hit_offsets[event_num]) + h1 = int(hit_offsets[event_num + 1]) + + sensor_idx = f['event_hit_sensor_idx'][h0:h1].astype(np.int32) + pe = f['event_hit_PE'][h0:h1].astype(np.float32) + t = f['event_hit_T'][h0:h1].astype(np.float32) + + # Build dense per-sensor arrays (sum PE, min T for sensors with hits) + n_sensors = self._n_sensors + pmt_pe = np.zeros(n_sensors, dtype=np.float32) + pmt_t = np.full(n_sensors, -1.0, dtype=np.float32) + np.add.at(pmt_pe, sensor_idx, pe) + # First-hit time: use scatter-min + for i in range(len(sensor_idx)): + s = sensor_idx[i] 
+ if pmt_t[s] < 0 or t[i] < pmt_t[s]: + pmt_t[s] = t[i] + + data_dict['pmt_pe'] = pmt_pe + data_dict['pmt_t'] = pmt_t + + if self._pmt_positions is not None: + data_dict['pmt_coord'] = self._pmt_positions.copy() + + # --- Per-particle hits (sparse) --- + if self.include_labels and 'particle_hit_offsets' in f: + # Which particles belong to this event + p_offsets = f['particle_event_offset'] + p0 = int(p_offsets[event_num]) + p1 = int(p_offsets[event_num + 1]) + n_particles = p1 - p0 + + if n_particles > 0: + categories = f['particle_category'][p0:p1].astype(np.int32) + + # Per-particle hit ranges + pp_hit_offsets = f['particle_hit_offsets'] + pp_h0 = int(pp_hit_offsets[p0]) + pp_h1 = int(pp_hit_offsets[p1]) + + pp_sensor = f['particle_hit_sensor_idx'][pp_h0:pp_h1].astype(np.int32) + pp_pe = f['particle_hit_PE'][pp_h0:pp_h1].astype(np.float32) + pp_t = f['particle_hit_T'][pp_h0:pp_h1].astype(np.float32) + + # Build particle_idx for each hit + hits_per_particle = np.diff(pp_hit_offsets[p0:p1 + 1]).astype(np.int32) + pp_particle_idx = np.repeat(np.arange(n_particles, dtype=np.int32), + hits_per_particle) + pp_category = np.repeat(categories, hits_per_particle) + + # Optional threshold filter + if self.pe_threshold > 0: + mask = pp_pe > self.pe_threshold + pp_sensor = pp_sensor[mask] + pp_pe = pp_pe[mask] + pp_t = pp_t[mask] + pp_particle_idx = pp_particle_idx[mask] + pp_category = pp_category[mask] + + data_dict['pp_sensor_idx'] = pp_sensor + data_dict['pp_particle_idx'] = pp_particle_idx + data_dict['pp_pe'] = pp_pe + data_dict['pp_t'] = pp_t + data_dict['pp_category'] = pp_category + + return data_dict + + def __len__(self): + return int(self.cumulative_lengths[-1]) if len(self.cumulative_lengths) > 0 else 0 + + def close(self): + if self._initted: + for fh in self._h5data: + try: + fh.close() + except Exception: + pass + self._h5data = [] + self._initted = False diff --git a/pimm/datasets/transform.py b/pimm/datasets/transform.py index 7b5ebcf..08797d9 100644 --- 
a/pimm/datasets/transform.py +++ b/pimm/datasets/transform.py @@ -53,6 +53,21 @@ def index_operator(data_dict, index, duplicate=False): "instance_particle", "instance_interaction", "momentum", + # JAXTPCDataset keys (JAXTPC) + "track_ids", + "group_ids", + "pdg", + "volume_id", + "interaction_ids", + "ancestor_track_ids", + "charge", + "photons", + "qs_fractions", + "t0_us", + "dx", + "theta", + "phi", + "segment_interaction", ] if not duplicate: for key in data_dict["index_valid_keys"]: diff --git a/pimm/datasets/utils.py b/pimm/datasets/utils.py index b717eb4..3f7ea45 100644 --- a/pimm/datasets/utils.py +++ b/pimm/datasets/utils.py @@ -43,6 +43,7 @@ def collate_fn(batch, mix_prob=0): ) ) for key in batch[0] + if not key.startswith("_") # skip non-tensor metadata } return batch else: diff --git a/tests/test_jaxtpc_dataset.py b/tests/test_jaxtpc_dataset.py new file mode 100644 index 0000000..72e3637 --- /dev/null +++ b/tests/test_jaxtpc_dataset.py @@ -0,0 +1,231 @@ +""" +Verification script for JAXTPCDataset — all modality combinations. 
def check(condition, msg):
    """Print the outcome of one verification and update the tallies.

    A truthy ``condition`` increments the module-level ``PASSED``
    counter, anything else increments ``FAILED``.
    """
    global PASSED, FAILED
    if not condition:
        print(f" FAIL: {msg}")
        FAILED += 1
        return
    print(f" OK: {msg}")
    PASSED += 1
def test_seg_resp_corr_labl():
    """All modalities — seg owns coord, resp/corr as separate point clouds."""
    print("\n=== seg + resp + corr + labl ===")
    ds = make_ds(modalities=('seg', 'resp', 'corr', 'labl'), label_key='particle')
    d = ds.get_data(0)
    check(d['coord'].shape[1] == 3, f"3D coord: {d['coord'].shape}")
    check('segment' in d, "3D segment from labl")
    # Resp as separate point cloud. The shape check is guarded so that a
    # missing key is recorded as a FAIL instead of raising KeyError and
    # aborting the remaining tests; the message reports only the shape,
    # not the full array contents.
    has_resp = 'resp_coord' in d
    check(has_resp, "resp_coord present")
    if has_resp:
        check(d['resp_coord'].shape[1] == 2, f"resp_coord 2D: {d['resp_coord'].shape}")
    else:
        check(False, "resp_coord 2D: MISSING")
    # Corr as separate point cloud
    check('corr_coord' in d, "corr_coord present")
    check('corr_segment' in d, "corr_segment present")
    check('corr_instance' in d, "corr_instance present")
    # Raw plane keys also available
    plane_keys = [k for k in d if k.startswith('plane.')]
    check(len(plane_keys) > 0, f"raw plane keys: {len(plane_keys)}")
def test_different_label_keys():
    """Build a seg + labl dataset for each label_key and check that the
    first sample carries more than one distinct segment label."""
    print("\n=== different label_keys ===")
    for label_key in ('particle', 'cluster', 'interaction'):
        dataset = make_ds(modalities=('seg', 'labl'), label_key=label_key)
        sample = dataset.get_data(0)
        n = np.unique(sample['segment']).size
        check(n > 1, f"label_key={label_key}: {n} classes")
def test_toy_model():
    """Push a collated two-sample batch through a linear toy classifier
    and verify that a backward pass produces gradients."""
    print("\n=== toy model ===")
    pipeline = [
        dict(type='GridSample', grid_size=1.0, hash_type='fnv',
             mode='train', return_grid_coord=True),
        dict(type='ToTensor'),
        dict(type='Collect', keys=('coord', 'grid_coord', 'segment'),
             feat_keys=('coord', 'energy')),
    ]
    dataset = make_ds(modalities=('resp', 'corr', 'labl'),
                      label_key='particle', transform=pipeline)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch = collate_fn([dataset[0], dataset[1]])
    # Move every tensor entry to the selected device; other entries stay put.
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
    features = batch['feat']
    model = nn.Linear(features.shape[1], 5).to(device)
    logits = model(features)
    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    loss = criterion(logits, batch['segment'].long())
    loss.backward()
    check(logits.shape[1] == 5, f"logits: {logits.shape}")
    check(model.weight.grad is not None, "gradients computed")
# Location of the Water Cherenkov test dataset. Override it with the
# LUCID_DATA_ROOT environment variable (the JAXTPC verification script
# does the same with JAXTPC_DATA_ROOT); the fallback is the original
# development path.
DATA_ROOT = os.environ.get(
    'LUCID_DATA_ROOT',
    '/home/oalterka/desktop_linux/JAXTPC/dataset_wc')
# Running tallies updated by check().
PASSED = 0
FAILED = 0
def test_mixed_separate():
    """Request seg and sensor together in 'separate' mode and check that
    'seg3d.'-prefixed keys and the raw 'pmt_pe' key are both present."""
    print("\n=== Mixed separate ===")
    dataset = make_ds(modalities=('seg', 'sensor'), output_mode='separate')
    sample = dataset.get_data(0)
    n_seg_keys = sum(1 for key in sample if key.startswith('seg3d.'))
    check(n_seg_keys > 0, f"seg3d keys: {n_seg_keys}")
    check('pmt_pe' in sample, "pmt_pe present")
"""Pipeline: labels → transforms → collate → toy model.""" + print("\n=== Pipeline labels ===") + transform = [ + dict(type='ToTensor'), + dict(type='Collect', keys=('coord', 'segment', 'instance'), + feat_keys=('coord', 'energy')), + ] + ds = make_ds(modalities=('sensor',), output_mode='labels', + include_labels=True, transform=transform) + + batch = collate_fn([ds[0], ds[1]]) + check('segment' in batch, "segment in batch") + check(len(batch['offset']) == 2, "offset correct") + + device = torch.device('cpu') + for k in batch: + if isinstance(batch[k], torch.Tensor): + batch[k] = batch[k].to(device) + n_classes = len(torch.unique(batch['segment'])) + model = nn.Linear(batch['feat'].shape[1], max(4, n_classes)).to(device) + logits = model(batch['feat']) + loss = nn.CrossEntropyLoss(ignore_index=-1)(logits, batch['segment'].long()) + loss.backward() + check(model.weight.grad is not None, "gradients computed") + + +def test_pipeline_seg(): + """Pipeline: 3D segments → transforms → collate.""" + print("\n=== Pipeline seg ===") + transform = [ + dict(type='ToTensor'), + dict(type='Collect', keys=('coord',), feat_keys=('coord', 'energy')), + ] + ds = make_ds(modalities=('seg',), transform=transform) + batch = collate_fn([ds[0], ds[1]]) + check(batch['coord'].shape[1] == 3, f"3D: {batch['coord'].shape}") + check(len(batch['offset']) == 2, "offset correct") + + +def test_dataloader(): + """DataLoader with workers.""" + print("\n=== DataLoader ===") + transform = [ + dict(type='ToTensor'), + dict(type='Collect', keys=('coord',), feat_keys=('coord', 'energy')), + ] + ds = make_ds(modalities=('sensor',), output_mode='response', + include_labels=False, transform=transform) + loader = torch.utils.data.DataLoader( + ds, batch_size=2, shuffle=False, num_workers=2, + collate_fn=collate_fn, persistent_workers=False) + for i, batch in enumerate(loader): + if i >= 1: + break + check(batch['coord'].shape[0] > 0, f"DataLoader batch: {batch['coord'].shape}") + + +if __name__ == 
'__main__': + print(f"Testing LUCiDDataset\nData root: {DATA_ROOT}") + + test_sensor_response() + test_sensor_labels() + test_sensor_separate() + test_seg_only() + test_mixed_separate() + test_pipeline_response() + test_pipeline_labels() + test_pipeline_seg() + test_dataloader() + + print(f"\n{'='*50}") + print(f"PASSED: {PASSED}, FAILED: {FAILED}") + if FAILED > 0: + sys.exit(1) + print("ALL TESTS PASSED")