Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/lczero_training/training/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ def configure_parser(parser: argparse.ArgumentParser) -> None:
default=10,
help="Number of batches to fetch from the data loader.",
)
dataloader_parser.add_argument(
"--debug-chunk-file",
type=str,
help="Optional path to write debug chunk metadata for each batch.",
)
dataloader_parser.set_defaults(func=run)


Expand Down Expand Up @@ -249,6 +254,7 @@ def run(args: argparse.Namespace) -> None:
probe_dataloader(
config_filename=args.config,
num_batches=args.num_batches,
debug_chunk_file=getattr(args, "debug_chunk_file", None),
)


Expand Down
59 changes: 56 additions & 3 deletions src/lczero_training/training/dataloader_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import logging
import time
from collections.abc import Sequence
from contextlib import suppress
from typing import List, Optional, TextIO, Tuple

import numpy as np
from google.protobuf import text_format

from lczero_training.dataloader import DataLoader, make_dataloader
Expand All @@ -12,17 +15,55 @@
logger = logging.getLogger(__name__)


# Per-position shift amounts used by _extract_debug_chunk_info: the value at
# flattened plane index i is shifted left by (i ^ 7), i.e. the bit order is
# reversed within each group of eight positions. Shaped (1, 1, 64) so it
# broadcasts over the (batch, 3, 64) bit array built there.
_BIT_ORDER = (np.arange(64, dtype=np.uint64) ^ 7).reshape(1, 1, 64)


def _stop_loader(loader: DataLoader) -> None:
    """Stop ``loader`` on a best-effort basis.

    Any exception raised by ``loader.stop()`` is caught so that this cleanup
    step never raises into its caller. The failure is logged with its
    traceback instead of being discarded silently.

    Args:
        loader: The data loader to stop.
    """
    try:
        loader.stop()
    except Exception:
        # Deliberately not re-raised: cleanup must not replace the error
        # (if any) that led here, but a failed stop should leave a trace.
        logger.warning("Failed to stop data loader cleanly", exc_info=True)


def probe_dataloader(config_filename: str, num_batches: int) -> None:
def _extract_debug_chunk_info(
    batch: Sequence[np.ndarray],
) -> List[Tuple[int, int, int]]:
    """Decode debug chunk metadata from the first three planes of a batch.

    Each of the first three planes of ``batch[0]`` is read as 64 bits: every
    value is rounded to the nearest integer and its lowest bit is kept. The
    bits are then packed into one unsigned 64-bit integer per plane using
    the shift amounts in ``_BIT_ORDER``.

    Args:
        batch: Arrays making up one batch; only ``batch[0]`` is inspected.

    Returns:
        One ``(plane0, plane1, plane2)`` tuple of Python ints per sample.
        An empty list is returned when ``batch`` is empty or ``batch[0]``
        has fewer than four dimensions, fewer than three planes, or planes
        that do not hold exactly 64 values.
    """
    if not batch:
        return []

    # NOTE(review): presumably shaped (batch, planes, 8, 8) — confirm
    # against the data loader's output format.
    inputs = batch[0]
    if inputs.ndim < 4 or inputs.shape[1] < 3:
        return []

    planes = np.asarray(inputs[:, :3, :, :], dtype=np.float32)
    # A plane must hold exactly 64 values to map onto a 64-bit word. Bail
    # out like the other shape guards instead of failing inside reshape.
    if int(np.prod(planes.shape[2:])) != 64:
        return []

    plane_bits = planes.reshape(planes.shape[0], 3, 64)
    bits = np.rint(plane_bits).astype(np.uint64) & 1
    # Every position gets a distinct shift, so summing the shifted bits is
    # the same as OR-ing them together.
    weighted = np.left_shift(bits, _BIT_ORDER)
    decoded = np.asarray(np.sum(weighted, axis=-1, dtype=np.uint64))
    return [
        (first, second, third) for first, second, third in decoded.tolist()
    ]


def _maybe_write_debug_info(
    batch: Sequence[np.ndarray], debug_file: Optional[TextIO]
) -> None:
    """Write one line of decoded chunk metadata for ``batch``.

    Does nothing when ``debug_file`` is ``None``. Otherwise the list
    returned by ``_extract_debug_chunk_info`` is written to ``debug_file``
    followed by a newline.
    """
    if debug_file is not None:
        debug_info = _extract_debug_chunk_info(batch)
        debug_file.write(f"{debug_info}\n")


def probe_dataloader(
config_filename: str, num_batches: int, debug_chunk_file: str | None = None
) -> None:
"""Measure latency and throughput for the configured data loader.

Args:
config_filename: Path to the root configuration proto file.
num_batches: Total number of batches to fetch from the loader.
debug_chunk_file: Optional path to write chunk metadata for each batch.
"""

if num_batches < 1:
Expand All @@ -36,14 +77,23 @@ def probe_dataloader(config_filename: str, num_batches: int) -> None:
logger.info("Creating data loader")
loader = make_dataloader(config.data_loader)

debug_handle: Optional[TextIO] = None
if debug_chunk_file:
try:
debug_handle = open(debug_chunk_file, "w")
except Exception:
_stop_loader(loader)
raise

first_batch_time = 0.0
remaining_batches = num_batches - 1
try:
logger.info("Fetching first batch")
start_time = time.perf_counter()
loader.get_next()
first_batch = loader.get_next()
first_batch_time = time.perf_counter() - start_time
logger.info("Time to first batch: %.3f seconds", first_batch_time)
_maybe_write_debug_info(first_batch, debug_handle)

if remaining_batches <= 0:
logger.info("Only fetched first batch; skipping throughput")
Expand All @@ -55,7 +105,8 @@ def probe_dataloader(config_filename: str, num_batches: int) -> None:
)
throughput_start = time.perf_counter()
for _ in range(remaining_batches):
loader.get_next()
batch = loader.get_next()
_maybe_write_debug_info(batch, debug_handle)
throughput_duration = time.perf_counter() - throughput_start

if throughput_duration <= 0:
Expand All @@ -72,4 +123,6 @@ def probe_dataloader(config_filename: str, num_batches: int) -> None:
throughput_duration,
)
finally:
if debug_handle is not None:
debug_handle.close()
_stop_loader(loader)