Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions tsl/profiler/lib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ cc_library(
]),
deps = [
"//tsl/profiler/protobuf:xplane_proto_cc",
"@xla//xla/tsl/platform:status",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)

Expand Down Expand Up @@ -183,6 +184,20 @@ tsl_cc_test(
],
)

tsl_cc_test(
name = "continuous_profiler_orchestrator_test",
srcs = ["continuous_profiler_orchestrator_test.cc"],
deps = [
":continuous_profiler_orchestrator",
":profiler_interface",
"//tsl/profiler/protobuf:xplane_proto_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "profiler_session",
hdrs = ["profiler_session.h"],
Expand Down Expand Up @@ -235,6 +250,7 @@ cc_library(
"@xla//xla/tsl/profiler/utils:xplane_schema",
"@xla//xla/tsl/profiler/utils:xplane_utils",
] + if_not_android([
":continuous_profiler_orchestrator",
":profiler_collection",
":profiler_factory",
":profiler_interface",
Expand Down Expand Up @@ -389,7 +405,7 @@ cc_library(
":profiler_interface",
"//tsl/profiler/protobuf:xplane_proto_cc",
"@com_google_absl//absl/status",
"@xla//xla/tsl/platform:status",
"@com_google_absl//absl/status:statusor",
],
)

Expand All @@ -402,3 +418,19 @@ cc_library(
"@com_google_absl//absl/strings:string_view",
],
)

cc_library(
name = "continuous_profiler_orchestrator",
hdrs = ["continuous_profiler_orchestrator.h"],
visibility = ["//visibility:public"],
deps = [
":profiler_interface",
"//tsl/profiler/protobuf:xplane_proto_cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:any",
"@xla//xla/tsl/platform:env",
],
)
168 changes: 168 additions & 0 deletions tsl/profiler/lib/continuous_profiler_orchestrator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/* Copyright 2026 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_TSL_PROFILER_LIB_CONTINUOUS_PROFILER_ORCHESTRATOR_H_
#define TENSORFLOW_TSL_PROFILER_LIB_CONTINUOUS_PROFILER_ORCHESTRATOR_H_

#include <algorithm>
#include <any>
#include <atomic>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/any.h"
#include "xla/tsl/platform/env.h"
#include "tsl/profiler/lib/profiler_interface.h"
#include "tsl/profiler/protobuf/xplane.pb.h"

namespace tsl {
namespace profiler {

template <typename ProfilerType>
class ContinuousProfilerOrchestrator : public ProfilerInterface {
public:
explicit ContinuousProfilerOrchestrator(
std::unique_ptr<ProfilerType> profiler)
: profiler_(std::move(profiler)),
is_running_(false),
polling_interval_(absl::Seconds(1)) {}

~ContinuousProfilerOrchestrator() override { StopIngestionThread(); }

// Starts profiling and spawns background thread.
absl::Status Start() override {
absl::Status status = profiler_->Start();
if (!status.ok()) return status;

{
absl::MutexLock lock(&mutex_);
is_running_ = true;
}
ingestion_thread_ =
std::unique_ptr<tsl::Thread>(tsl::Env::Default()->StartThread(
tsl::ThreadOptions{}, "ContinuousProfilerIngestion",
[this]() { IngestionLoop(); }));
return absl::OkStatus();
}

// Stops background thread and profiling.
absl::Status Stop() override {
StopIngestionThread();
return profiler_->Stop();
}

absl::Status CollectData(tensorflow::profiler::XSpace* space) override {
absl::Status status = Serialize({}, space);
status.Update(profiler_->CollectData(space));
return status;
}

absl::Status Serialize(std::any data,
tensorflow::profiler::XSpace* space) override {
if (data.has_value()) {
return profiler_->Serialize(std::move(data), space);
}
auto chunks = PopBuffer();
absl::Status status;
for (auto& chunk : chunks) {
status.Update(profiler_->Serialize(std::move(chunk), space));
}
return status;
}

// Returns the current polling interval (primarily for testing).
absl::Duration polling_interval() const {
absl::MutexLock lock(&mutex_);
return polling_interval_;
}

std::vector<std::any> PopBuffer() {
absl::MutexLock lock(&mutex_);
std::vector<std::any> chunks = std::move(circular_buffer_);
circular_buffer_.clear();
return chunks;
}

ProfilerType* profiler() { return profiler_.get(); }
const ProfilerType* profiler() const { return profiler_.get(); }

private:
void IngestionLoop() {
while (true) {
auto result = profiler_->Consume();

absl::MutexLock lock(&mutex_);
if (!is_running_) break;

if (result.ok()) {
circular_buffer_.push_back(std::move(result->data));

// Cap circular buffer to prevent infinite memory growth.
if (circular_buffer_.size() > 100) {
circular_buffer_.erase(circular_buffer_.begin());
}

AdjustIntervalLocked(result->estimated_size_bytes);
}

// Wait using absl::CondVar on absl::Mutex
cv_.WaitWithTimeout(&mutex_, polling_interval_);
if (!is_running_) break;
}
}

void StopIngestionThread() {
{
absl::MutexLock lock(&mutex_);
if (!is_running_) return;
is_running_ = false;
cv_.SignalAll();
}
ingestion_thread_.reset();
}

void AdjustIntervalLocked(size_t chunk_size_bytes)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
constexpr size_t kHighWatermark = 512 * 1024 * 1024; // 512MB
constexpr size_t kLowWatermark = 5 * 1024 * 1024; // 5MB

if (chunk_size_bytes > kHighWatermark) {
polling_interval_ =
std::max(polling_interval_ / 2, absl::Milliseconds(100));
} else if (chunk_size_bytes < kLowWatermark) {
polling_interval_ = std::min(polling_interval_ * 2, absl::Seconds(5));
}
}

std::unique_ptr<ProfilerType> profiler_;

mutable absl::Mutex mutex_;
absl::CondVar cv_;
std::unique_ptr<tsl::Thread> ingestion_thread_;
std::atomic<bool> is_running_;

absl::Duration polling_interval_ ABSL_GUARDED_BY(mutex_);
std::vector<std::any> circular_buffer_ ABSL_GUARDED_BY(mutex_);
};

} // namespace profiler
} // namespace tsl

#endif // TENSORFLOW_TSL_PROFILER_LIB_CONTINUOUS_PROFILER_ORCHESTRATOR_H_
157 changes: 157 additions & 0 deletions tsl/profiler/lib/continuous_profiler_orchestrator_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/* Copyright 2026 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tsl/profiler/lib/continuous_profiler_orchestrator.h"

#include <any>
#include <atomic>
#include <memory>
#include <utility>

#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tsl/profiler/lib/profiler_interface.h"
#include "tsl/profiler/protobuf/xplane.pb.h"

namespace tsl {
namespace profiler {
namespace {

using ::testing::Invoke;
using ::testing::Return;

class MockProfiler : public ProfilerInterface {
public:
MOCK_METHOD(absl::Status, Start, (), (override));
MOCK_METHOD(absl::Status, Stop, (), (override));
MOCK_METHOD(absl::Status, CollectData, (tensorflow::profiler::XSpace * space),
(override));
MOCK_METHOD(absl::StatusOr<ConsumeResult>, Consume, (), (override));
MOCK_METHOD(absl::Status, Serialize,
(std::any data, tensorflow::profiler::XSpace* space), (override));
};

// Custom Google Mock matcher to compare integers wrapped in std::any.
MATCHER_P(AnyEqInt, expected_value, "") {
auto* val = std::any_cast<int>(&arg);
return val != nullptr && *val == expected_value;
}

TEST(ContinuousProfilerOrchestratorTest,
CircularBufferingAndCorrectSerialization) {
auto mock_profiler = std::make_unique<MockProfiler>();
MockProfiler* mock_ptr = mock_profiler.get();

EXPECT_CALL(*mock_ptr, Start()).WillOnce(Return(absl::OkStatus()));
EXPECT_CALL(*mock_ptr, Stop()).WillOnce(Return(absl::OkStatus()));

// Setup Consume mock to return sequential chunks with high sizes to shrink
// the interval
std::atomic<int> consume_count(0);
EXPECT_CALL(*mock_ptr, Consume())
.WillRepeatedly(
Invoke([&]() -> absl::StatusOr<ProfilerInterface::ConsumeResult> {
int count = ++consume_count;
return ProfilerInterface::ConsumeResult{
.data = std::any(count),
.estimated_size_bytes = 1000 * 1024 * 1024 // 1000MB (>512MB)
};
}));

ContinuousProfilerOrchestrator<ProfilerInterface> orchestrator(
std::move(mock_profiler));

// Start orchestrator (spawns background loop)
ASSERT_OK(orchestrator.Start());

// Wait until we have consumed at least 4 chunks.
// Due to high watermark scaling, interval shrinks from 1s -> 500ms -> 250ms
// -> 125ms -> 100ms. This will happen in less than 1.0 second of real-time
// sleep.
while (consume_count < 4) {
absl::SleepFor(absl::Milliseconds(50));
}

// Stop orchestrator (terminates background loop)
ASSERT_OK(orchestrator.Stop());

// Verify polling interval shrank from 1s due to high watermark
EXPECT_LE(orchestrator.polling_interval(), absl::Milliseconds(500));

// Verify that CollectData serializes the chunks in correct chronological
// order!
tensorflow::profiler::XSpace space;

// We expect Serialize to be called for each chunk in order: 1, 2, 3...
::testing::InSequence seq;
for (int i = 1; i <= consume_count; ++i) {
EXPECT_CALL(*mock_ptr, Serialize(AnyEqInt(i), &space))
.WillOnce(Return(absl::OkStatus()));
}
// And finally mock CollectData is called to collect remainder.
EXPECT_CALL(*mock_ptr, CollectData(&space))
.WillOnce(Return(absl::OkStatus()));

// Call CollectData manually by popping buffer and serializing.
auto chunks = orchestrator.PopBuffer();
absl::Status status;
for (auto& chunk : chunks) {
status.Update(orchestrator.profiler()->Serialize(std::move(chunk), &space));
}
status.Update(orchestrator.profiler()->CollectData(&space));
EXPECT_OK(status);
}

TEST(ContinuousProfilerOrchestratorTest, DynamicIntervalLowWatermarkScaling) {
auto mock_profiler = std::make_unique<MockProfiler>();
MockProfiler* mock_ptr = mock_profiler.get();

EXPECT_CALL(*mock_ptr, Start()).WillOnce(Return(absl::OkStatus()));
EXPECT_CALL(*mock_ptr, Stop()).WillOnce(Return(absl::OkStatus()));

// Consume returns a very small chunk (1MB < 5MB low watermark)
EXPECT_CALL(*mock_ptr, Consume())
.WillRepeatedly(
Invoke([]() -> absl::StatusOr<ProfilerInterface::ConsumeResult> {
return ProfilerInterface::ConsumeResult{
.data = std::any(1),
.estimated_size_bytes = 1 * 1024 * 1024 // 1MB (<5MB)
};
}));

ContinuousProfilerOrchestrator<ProfilerInterface> orchestrator(
std::move(mock_profiler));
EXPECT_EQ(orchestrator.polling_interval(), absl::Seconds(1)); // initial

// Start
ASSERT_OK(orchestrator.Start());

// Wait a very short time for the first immediate consume to run and adjust
// the interval
absl::SleepFor(absl::Milliseconds(50));

// Stop immediately
ASSERT_OK(orchestrator.Stop());

// Verify interval scaled up from 1s to 2s due to low watermark!
EXPECT_EQ(orchestrator.polling_interval(), absl::Seconds(2));
}

} // namespace
} // namespace profiler
} // namespace tsl
Loading
Loading