Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions csrc/loader/stages/debug_chunk_source_generator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#include "loader/stages/debug_chunk_source_generator.h"

#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/cleanup/cleanup.h"
#include "absl/log/log.h"
#include "absl/random/random.h"
#include "absl/random/seed_sequences.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "loader/data_loader_metrics.h"

namespace lczero {
namespace training {

namespace {

// Lower bound for the output queue capacity; the constructor takes the larger
// of this value and the configured number of initial chunk sources.
constexpr uint64_t kDefaultQueueCapacity{16};

// Fixed seed used when shuffling the ids of the initial chunk sources.
constexpr uint64_t kInitialShuffleSeed{0xC0FFEEull};

// Upper bound on a single sleep while waiting for the next emission deadline,
// so a stop request is noticed within roughly this interval.
constexpr absl::Duration kStopPollInterval{absl::Milliseconds(10)};

}  // namespace

// Builds the generator from its configuration.
//
// The output queue is sized to hold at least kDefaultQueueCapacity entries, or
// the whole initial batch when that is larger. The mean chunk count handed to
// every generated source is clamped to a minimum of 1. `existing_stages` is
// accepted so the constructor matches the other stages but is not used.
DebugChunkSourceGenerator::DebugChunkSourceGenerator(
    const DebugChunkSourceGeneratorConfig& config,
    [[maybe_unused]] const Stage::StageList& existing_stages)
    : config_(config),
      output_queue_(static_cast<size_t>(std::max<uint64_t>(
          config.initial_chunk_sources(), kDefaultQueueCapacity))),
      mean_chunk_count_(std::max(1.0, config.mean_chunks_per_chunk_source())) {
  // Only a non-positive configured mean is reported; the clamp above has
  // already replaced it with 1.
  const bool mean_not_positive = config.mean_chunks_per_chunk_source() <= 0.0;
  if (mean_not_positive) {
    LOG(WARNING) << "DebugChunkSourceGenerator mean chunk count not positive."
                 << " Using 1.";
  }
}

// Makes sure the worker thread has been stopped and joined before the members
// it uses are destroyed.
DebugChunkSourceGenerator::~DebugChunkSourceGenerator() {
  Stop();
}

// Launches the worker thread that executes Run(). Calling Start() while a
// worker is already joinable leaves the existing worker untouched.
void DebugChunkSourceGenerator::Start() {
  const bool already_running = worker_.joinable();
  if (!already_running) {
    // std::jthread supplies the stop token as the first callable argument.
    worker_ = std::jthread(
        [this](std::stop_token token) { Run(std::move(token)); });
  }
}

// Stops the worker thread: requests cooperative cancellation through the
// jthread's stop token, then blocks until Run() has returned. Does nothing
// when no worker is joinable (never started, or already joined).
// NOTE(review): join() cannot return while Run() is blocked inside a queue
// Put(); confirm how Put() behaves when the queue is full and nothing drains
// it.
void DebugChunkSourceGenerator::Stop() {
if (!worker_.joinable()) return;
worker_.request_stop();
worker_.join();
Comment on lines +50 to +53

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Prevent Stop() from hanging on full output queue

Stop() only calls request_stop() and joins the thread, while the worker loop writes via blocking producer.Put (line 92) and never closes the queue. If a downstream stage has stopped and the queue fills, the worker blocks inside Put and can’t observe the stop token, causing Stop() to block indefinitely and preventing the loader from shutting down. Closing the queue or the producer before joining would let the worker unblock even when no consumer is draining.

Useful? React with 👍 / 👎.

}

Queue<DebugChunkSourceGenerator::OutputType>*
DebugChunkSourceGenerator::output() {
return &output_queue_;
}

// Type-erased accessor for the stage output. The stage has a single output
// queue, so `name` is ignored and the same queue is returned for any value.
QueueBase* DebugChunkSourceGenerator::GetOutput(
    [[maybe_unused]] std::string_view name) {
  return &output_queue_;
}

// Assembles the metrics record for this stage: the stage type tag, the
// metrics of the "output" queue, and the number of chunk sources emitted so
// far. The emitted-source counter is read but not reset here.
StageMetricProto DebugChunkSourceGenerator::FlushMetrics() {
  StageMetricProto stage_metric;
  stage_metric.set_stage_type("debug_chunk_source_generator");

  *stage_metric.add_queue_metrics() =
      MetricsFromQueue("output", output_queue_);

  // Relaxed ordering is enough: the value is a standalone statistic.
  const auto emitted_so_far =
      generated_sources_.load(std::memory_order_relaxed);
  auto& emitted_metric = *stage_metric.add_count_metrics();
  emitted_metric.set_name("chunk_sources_generated");
  emitted_metric.set_count(emitted_so_far);

  return stage_metric;
}

// Worker-thread body.
//
// Phase 1: emits `initial_chunk_sources` debug sources with ids
//          0..initial_chunk_sources-1 in an order shuffled with a fixed seed,
//          each tagged kFile.
// Phase 2: emits a single kInitialScanComplete marker with a null source.
// Phase 3: if `chunk_sources_per_minute` is positive, keeps emitting new
//          sources (ids continuing after the initial batch) at that rate
//          until a stop is requested.
//
// The output queue is closed on every exit path. It is also closed as soon as
// a stop is requested, so that a Put() blocked on a full queue is released
// instead of keeping Stop() waiting in join() forever.
//
// `stop_token` is the cooperative-cancellation token supplied by the owning
// std::jthread.
void DebugChunkSourceGenerator::Run(std::stop_token stop_token) {
  try {
    auto producer = output_queue_.CreateProducer();
    // Closes the queue once the worker is done, whatever the exit path.
    absl::Cleanup close_queue = [&] { output_queue_.Close(); };
    // The stop token is only polled between queue operations, so a Put()
    // that is waiting for free space would never notice request_stop().
    // Closing the queue from the stop callback wakes such a Put(), which is
    // then expected to fail with QueueClosedException (handled below). The
    // callback runs on the thread that calls request_stop(), or right here if
    // stop was already requested. A second Close() from `close_queue` is the
    // same sequence that already occurs when a consumer closes the queue.
    std::stop_callback close_on_stop(stop_token,
                                     [this] { output_queue_.Close(); });

    // Ids of the initial batch, shuffled with a fixed seed.
    std::vector<uint64_t> initial_ids(config_.initial_chunk_sources());
    std::iota(initial_ids.begin(), initial_ids.end(), 0);
    if (!initial_ids.empty()) {
      absl::SeedSeq seed({static_cast<uint32_t>(kInitialShuffleSeed),
                          static_cast<uint32_t>(kInitialShuffleSeed >> 32)});
      absl::BitGen bitgen(seed);
      absl::c_shuffle(initial_ids, bitgen);
    }

    // Creates one debug source, queues it and bumps the metrics counter.
    auto emit_source = [&](uint64_t id) {
      auto source = std::make_unique<DebugChunkSource>(id, mean_chunk_count_);
      producer.Put({.source = std::move(source),
                    .message_type = FilePathProvider::MessageType::kFile});
      generated_sources_.fetch_add(1, std::memory_order_relaxed);
    };

    for (uint64_t id : initial_ids) {
      if (stop_token.stop_requested()) return;
      emit_source(id);
    }

    if (stop_token.stop_requested()) return;

    // Tell consumers that the initial batch is complete.
    producer.Put(
        {.source = nullptr,
         .message_type = FilePathProvider::MessageType::kInitialScanComplete});

    // A non-positive rate disables the steady-state phase.
    const double per_minute = config_.chunk_sources_per_minute();
    if (per_minute <= 0.0) return;

    const absl::Duration cadence = absl::Seconds(60.0 / per_minute);
    uint64_t next_id = config_.initial_chunk_sources();
    absl::Time next_deadline = absl::Now();

    while (!stop_token.stop_requested()) {
      emit_source(next_id++);
      next_deadline += cadence;
      // Sleep in short slices so a stop request is noticed promptly.
      while (!stop_token.stop_requested()) {
        const absl::Duration wait = next_deadline - absl::Now();
        if (wait <= absl::ZeroDuration()) break;
        absl::SleepFor(std::min(wait, kStopPollInterval));
      }
    }
  } catch (const QueueClosedException&) {
    LOG(INFO) << "DebugChunkSourceGenerator stopping due to closed queue.";
  }
}

} // namespace training
} // namespace lczero
51 changes: 51 additions & 0 deletions csrc/loader/stages/debug_chunk_source_generator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#pragma once

#include <atomic>
#include <cstdint>
#include <memory>
#include <stop_token>
#include <string_view>
#include <thread>

#include "loader/chunk_source/chunk_source.h"
#include "loader/chunk_source/debug_chunk_source.h"
#include "loader/stages/chunk_source_loader.h"
#include "loader/stages/stage.h"
#include "proto/data_loader_config.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/queue.h"

namespace lczero {
namespace training {

// DebugChunkSourceGenerator is a loader Stage that emits deterministic
// DebugChunkSource instances from a background worker thread onto its output
// queue. It is intended for loader bring-up and testing without filesystem
// input.
class DebugChunkSourceGenerator : public Stage {
public:
// Element type carried by the output queue.
using OutputType = ChunkSourceWithPhase;

// Creates the stage from `config`. `existing_stages` is accepted so the
// constructor has the same shape as other stages; the implementation ignores
// it.
explicit DebugChunkSourceGenerator(
const DebugChunkSourceGeneratorConfig& config,
const Stage::StageList& existing_stages = {});
// Stops the worker (via Stop()) before the object is destroyed.
~DebugChunkSourceGenerator() override;

// Starts the worker thread; does nothing if a worker is already running.
void Start() override;
// Requests the worker to stop and joins it; does nothing if no worker is
// joinable.
void Stop() override;

// Reports the output queue metrics and the number of sources generated.
StageMetricProto FlushMetrics() override;

// Returns the single output queue; `name` is ignored by the implementation.
QueueBase* GetOutput(std::string_view name = "") override;
// Typed, non-owning access to the output queue.
Queue<OutputType>* output();

private:
// Worker-thread body: emits the initial batch, the scan-complete marker and
// then rate-limited sources until `stop_token` is triggered.
void Run(std::stop_token stop_token);

// Copy of the configuration taken at construction time.
const DebugChunkSourceGeneratorConfig config_;
// Queue on which generated sources and the scan-complete marker are placed.
Queue<OutputType> output_queue_;
// Background thread that executes Run().
std::jthread worker_;
// Number of chunk sources emitted so far; reported by FlushMetrics().
std::atomic<uint64_t> generated_sources_{0};
// Mean chunk count passed to every DebugChunkSource (at least 1).
const double mean_chunk_count_;
};

} // namespace training
} // namespace lczero
51 changes: 51 additions & 0 deletions csrc/loader/stages/debug_chunk_source_generator_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#include "loader/stages/debug_chunk_source_generator.h"

#include <gtest/gtest.h>

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

namespace lczero {
namespace training {

// Verifies the emission order of the generator: the three initial sources
// (ids 0..2 in any order), then the initial-scan-complete marker, then the
// first rate-limited source, which must carry id 3.
TEST(DebugChunkSourceGeneratorTest, EmitsInitialSourcesAndMarker) {
  DebugChunkSourceGeneratorConfig config;
  config.set_initial_chunk_sources(3);
  config.set_mean_chunks_per_chunk_source(10.0);
  // 6000 per minute is one source every 10 ms, so the test does not wait long
  // for the post-scan source.
  config.set_chunk_sources_per_minute(6000.0);

  DebugChunkSourceGenerator generator(config);
  generator.Start();
  auto* output = generator.output();

  // Initial batch: the sort key of each source is expected to parse as its id.
  std::vector<uint64_t> seen_ids;
  for (int remaining = 3; remaining > 0; --remaining) {
    auto entry = output->Get();
    ASSERT_NE(entry.source, nullptr);
    EXPECT_EQ(entry.message_type, FilePathProvider::MessageType::kFile);
    uint64_t parsed_id = 0;
    ASSERT_NO_THROW(parsed_id = std::stoull(entry.source->GetChunkSortKey()));
    seen_ids.push_back(parsed_id);
  }
  // The initial batch is shuffled, so compare as a sorted sequence.
  std::sort(seen_ids.begin(), seen_ids.end());
  EXPECT_EQ(seen_ids, (std::vector<uint64_t>{0, 1, 2}));

  // Marker separating the initial batch from the steady-state stream.
  auto scan_complete = output->Get();
  EXPECT_EQ(scan_complete.source, nullptr);
  EXPECT_EQ(scan_complete.message_type,
            FilePathProvider::MessageType::kInitialScanComplete);

  // First steady-state source continues the id sequence.
  auto follow_up = output->Get();
  ASSERT_NE(follow_up.source, nullptr);
  EXPECT_EQ(follow_up.message_type, FilePathProvider::MessageType::kFile);
  uint64_t follow_up_id = 0;
  ASSERT_NO_THROW(follow_up_id =
                      std::stoull(follow_up.source->GetChunkSortKey()));
  EXPECT_EQ(follow_up_id, 3);

  generator.Stop();
}

} // namespace training
} // namespace lczero
8 changes: 7 additions & 1 deletion csrc/loader/stages/stage_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "loader/stages/chunk_rescorer.h"
#include "loader/stages/chunk_source_loader.h"
#include "loader/stages/chunk_unpacker.h"
#include "loader/stages/debug_chunk_source_generator.h"
#include "loader/stages/file_path_provider.h"
#include "loader/stages/shuffling_chunk_pool.h"
#include "loader/stages/shuffling_frame_sampler.h"
Expand All @@ -23,7 +24,8 @@ int CountStageConfigs(const StageConfig& config) {
static_cast<int>(config.has_chunk_rescorer()) +
static_cast<int>(config.has_chunk_unpacker()) +
static_cast<int>(config.has_shuffling_frame_sampler()) +
static_cast<int>(config.has_tensor_generator());
static_cast<int>(config.has_tensor_generator()) +
static_cast<int>(config.has_debug_chunk_source_generator());
}

} // namespace
Expand Down Expand Up @@ -63,6 +65,10 @@ std::unique_ptr<Stage> CreateStage(const StageConfig& config,
return std::make_unique<TensorGenerator>(config.tensor_generator(),
existing_stages);
}
if (config.has_debug_chunk_source_generator()) {
return std::make_unique<DebugChunkSourceGenerator>(
config.debug_chunk_source_generator(), existing_stages);
}

throw std::runtime_error(
"StageConfig did not contain a recognized stage configuration.");
Expand Down
13 changes: 13 additions & 0 deletions csrc/loader/stages/stage_factory_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,18 @@ TEST(StageFactoryTest, ThrowsWhenMultipleStageConfigsSet) {
EXPECT_THROW(CreateStage(config, {}), std::runtime_error);
}

// The factory must build a stage when only the debug chunk source generator
// config is set, and that stage must expose an output queue.
TEST(StageFactoryTest, CreatesDebugChunkSourceGeneratorStage) {
  StageConfig stage_config;
  auto& generator_config = *stage_config.mutable_debug_chunk_source_generator();
  generator_config.set_initial_chunk_sources(2);
  generator_config.set_mean_chunks_per_chunk_source(10.0);
  generator_config.set_chunk_sources_per_minute(60.0);

  const auto created = CreateStage(stage_config, {});

  ASSERT_NE(created, nullptr);
  EXPECT_NE(created->GetOutput(), nullptr);
}

} // namespace training
} // namespace lczero
10 changes: 10 additions & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ files = [
'csrc/loader/stages/chunk_source_loader.cc',
'csrc/loader/stages/chunk_rescorer.cc',
'csrc/loader/stages/chunk_unpacker.cc',
'csrc/loader/stages/debug_chunk_source_generator.cc',
'csrc/loader/stages/file_path_provider.cc',
'csrc/loader/stages/stage_factory.cc',
'csrc/loader/chunk_source/debug_chunk_source.cc',
Expand Down Expand Up @@ -204,6 +205,14 @@ chunk_source_loader_test = executable(
link_with : loader_lib,
)

# Test binary for the DebugChunkSourceGenerator stage. It is built from the
# stage's test source and linked against the loader library.
debug_chunk_source_generator_test = executable(
'debug_chunk_source_generator_test',
'csrc/loader/stages/debug_chunk_source_generator_test.cc',
include_directories : includes,
dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']],
link_with : loader_lib,
)

shuffling_chunk_pool_test = executable(
'shuffling_chunk_pool_test',
'csrc/loader/stages/shuffling_chunk_pool_test.cc',
Expand Down Expand Up @@ -286,6 +295,7 @@ test('stream_shuffler_test', stream_shuffler_test)
test('queue_test', queue_test)
test('file_path_provider_test', file_path_provider_test)
test('chunk_source_loader_test', chunk_source_loader_test)
test('debug_chunk_source_generator_test', debug_chunk_source_generator_test)
test('shuffling_chunk_pool_test', shuffling_chunk_pool_test)
test('chunk_rescorer_test', chunk_rescorer_test)
test('chunk_unpacker_test', chunk_unpacker_test)
Expand Down
13 changes: 13 additions & 0 deletions proto/data_loader_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ message ChunkSourceLoaderConfig {
optional uint64 queue_capacity = 3 [default = 16];
}

// Configuration for the debug chunk source generator, a stage that produces
// synthetic ChunkSource instances instead of reading files. Maps to
// DebugChunkSourceGenerator in
// csrc/loader/stages/debug_chunk_source_generator.h
message DebugChunkSourceGeneratorConfig {
// Mean number of chunks per generated source. The generator uses at least 1
// and logs a warning when the configured value is not positive.
optional double mean_chunks_per_chunk_source = 1 [default = 100.0];
// Number of chunk sources generated before the initial-scan-complete marker
// is emitted.
optional uint64 initial_chunk_sources = 2 [default = 0];
// Rate of new chunk sources produced after the initial scan (per minute).
// A value of zero or less means no sources are produced after the marker.
optional double chunk_sources_per_minute = 3 [default = 0.0];
}

// Configuration for shuffling chunk pool that manages chunk shuffling and
// loading. Maps to ShufflingChunkPoolOptions in
// csrc/loader/chunk_feed/shuffling_chunk_pool.h
Expand Down Expand Up @@ -118,6 +130,7 @@ message StageConfig {
optional ShufflingFrameSamplerConfig shuffling_frame_sampler = 6;
optional TensorGeneratorConfig tensor_generator = 7;
optional ChunkRescorerConfig chunk_rescorer = 8;
optional DebugChunkSourceGeneratorConfig debug_chunk_source_generator = 9;
}

// Main configuration class for the DataLoader containing all component
Expand Down