From 148859712cda4ac46f27a4b0b0249e7ef2ead1f3 Mon Sep 17 00:00:00 2001
From: Alexander Lyashuk
Date: Mon, 6 Oct 2025 22:00:55 +0200
Subject: [PATCH] Add debug chunk dump to dataloader probe

---
Reviewer note: line breaks and the indentation of context lines were
reconstructed from a whitespace-collapsed copy of this patch, and the blob
ids on the "index" lines predate the edits made to the added lines. Run
`git apply --check` before applying.

 src/lczero_training/training/__main__.py      |  6 +
 .../training/dataloader_probe.py              | 93 ++++++++++++++++++-
 2 files changed, 96 insertions(+), 3 deletions(-)

diff --git a/src/lczero_training/training/__main__.py b/src/lczero_training/training/__main__.py
index b188bb57..cf305024 100644
--- a/src/lczero_training/training/__main__.py
+++ b/src/lczero_training/training/__main__.py
@@ -201,6 +201,11 @@ def configure_parser(parser: argparse.ArgumentParser) -> None:
         default=10,
         help="Number of batches to fetch from the data loader.",
     )
+    dataloader_parser.add_argument(
+        "--debug-chunk-file",
+        type=str,
+        help="Optional path to write debug chunk metadata for each batch.",
+    )
     dataloader_parser.set_defaults(func=run)
 
 
@@ -249,6 +254,7 @@ def run(args: argparse.Namespace) -> None:
         probe_dataloader(
             config_filename=args.config,
             num_batches=args.num_batches,
+            debug_chunk_file=getattr(args, "debug_chunk_file", None),
         )
 
 
diff --git a/src/lczero_training/training/dataloader_probe.py b/src/lczero_training/training/dataloader_probe.py
index 4f51160a..d253cde0 100644
--- a/src/lczero_training/training/dataloader_probe.py
+++ b/src/lczero_training/training/dataloader_probe.py
@@ -2,8 +2,11 @@
 
 import logging
 import time
+from collections.abc import Sequence
 from contextlib import suppress
+from typing import List, Optional, TextIO, Tuple
 
+import numpy as np
 from google.protobuf import text_format
 
 from lczero_training.dataloader import DataLoader, make_dataloader
@@ -12,17 +15,87 @@
 logger = logging.getLogger(__name__)
 
 
+# Shift amount for each flattened square when packing a plane's 64 bits
+# into one integer: square i maps to bit i ^ 7, which reverses the bit
+# order within every group of eight squares. The (1, 1, 64) shape
+# broadcasts over the (sample, plane) axes.
+_BIT_ORDER = (np.arange(64, dtype=np.uint64) ^ 7).reshape(1, 1, 64)
+
+
 def _stop_loader(loader: DataLoader) -> None:
     with suppress(Exception):
         loader.stop()
 
 
-def probe_dataloader(config_filename: str, num_batches: int) -> None:
+def _extract_debug_chunk_info(
+    batch: Sequence[np.ndarray],
+) -> List[Tuple[int, int, int]]:
+    """Decode debug chunk metadata from the first three planes of a batch.
+
+    Every one of the first three planes of each sample in ``batch[0]`` is
+    rounded to 64 single-bit values and packed into an unsigned 64-bit
+    integer using the ``_BIT_ORDER`` shifts.
+
+    Args:
+        batch: Arrays returned by the data loader; only ``batch[0]`` is read.
+
+    Returns:
+        One tuple of three decoded integers per sample, or an empty list when
+        the batch is empty or ``batch[0]`` has fewer than four dimensions,
+        fewer than three planes, or planes that do not hold 64 squares.
+    """
+
+    if not batch:
+        return []
+
+    inputs = batch[0]
+    if inputs.ndim < 4 or inputs.shape[1] < 3:
+        return []
+
+    planes = np.asarray(inputs[:, :3, :, :], dtype=np.float32)
+    # This is only a debugging aid: skip batches whose planes cannot be
+    # flattened to 64 squares rather than letting reshape() abort the probe.
+    if int(np.prod(planes.shape[2:])) != 64:
+        return []
+
+    plane_bits = planes.reshape(planes.shape[0], 3, 64)
+    bits = np.rint(plane_bits).astype(np.uint64) & 1
+    weighted = np.left_shift(bits, _BIT_ORDER)
+    decoded = np.asarray(np.sum(weighted, axis=-1, dtype=np.uint64))
+    return [
+        (int(sample[0]), int(sample[1]), int(sample[2])) for sample in decoded
+    ]
+
+
+def _maybe_write_debug_info(
+    batch: Sequence[np.ndarray], debug_file: Optional[TextIO]
+) -> None:
+    """Write one line of decoded chunk metadata for a batch, if enabled.
+
+    Does nothing when ``debug_file`` is ``None``; otherwise writes the list
+    returned by ``_extract_debug_chunk_info`` followed by a newline.
+    """
+
+    if debug_file is None:
+        return
+
+    debug_info = _extract_debug_chunk_info(batch)
+    debug_file.write(f"{debug_info}\n")
+
+
+def probe_dataloader(
+    config_filename: str,
+    num_batches: int,
+    debug_chunk_file: Optional[str] = None,
+) -> None:
     """Measure latency and throughput for the configured data loader.
 
     Args:
         config_filename: Path to the root configuration proto file.
         num_batches: Total number of batches to fetch from the loader.
+        debug_chunk_file: Optional path to write chunk metadata for each
+            batch. Decoding and writing happen inside the timed loop, so
+            the reported throughput includes that overhead when set.
     """
 
     if num_batches < 1:
@@ -36,14 +109,25 @@ def probe_dataloader(config_filename: str, num_batches: int) -> None:
     logger.info("Creating data loader")
     loader = make_dataloader(config.data_loader)
 
+    debug_handle: Optional[TextIO] = None
+    if debug_chunk_file:
+        try:
+            debug_handle = open(debug_chunk_file, "w", encoding="utf-8")
+        except Exception:
+            # The try/finally below is not entered yet, so stop the loader
+            # here before propagating the failure.
+            _stop_loader(loader)
+            raise
+
     first_batch_time = 0.0
     remaining_batches = num_batches - 1
     try:
         logger.info("Fetching first batch")
         start_time = time.perf_counter()
-        loader.get_next()
+        first_batch = loader.get_next()
         first_batch_time = time.perf_counter() - start_time
         logger.info("Time to first batch: %.3f seconds", first_batch_time)
+        _maybe_write_debug_info(first_batch, debug_handle)
 
         if remaining_batches <= 0:
             logger.info("Only fetched first batch; skipping throughput")
@@ -55,7 +139,8 @@ def probe_dataloader(config_filename: str, num_batches: int) -> None:
         )
         throughput_start = time.perf_counter()
         for _ in range(remaining_batches):
-            loader.get_next()
+            batch = loader.get_next()
+            _maybe_write_debug_info(batch, debug_handle)
         throughput_duration = time.perf_counter() - throughput_start
 
         if throughput_duration <= 0:
@@ -72,4 +157,6 @@ def probe_dataloader(config_filename: str, num_batches: int) -> None:
             throughput_duration,
         )
     finally:
+        if debug_handle is not None:
+            debug_handle.close()
         _stop_loader(loader)