diff --git a/tsl/profiler/lib/BUILD b/tsl/profiler/lib/BUILD
index 9293a2988..d2c981d49 100644
--- a/tsl/profiler/lib/BUILD
+++ b/tsl/profiler/lib/BUILD
@@ -148,7 +148,8 @@ cc_library(
     ]),
     deps = [
         "//tsl/profiler/protobuf:xplane_proto_cc",
-        "@xla//xla/tsl/platform:status",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
     ],
 )
 
@@ -183,6 +184,20 @@ tsl_cc_test(
     ],
 )
 
+tsl_cc_test(
+    name = "continuous_profiler_orchestrator_test",
+    srcs = ["continuous_profiler_orchestrator_test.cc"],
+    deps = [
+        ":continuous_profiler_orchestrator",
+        ":profiler_interface",
+        "//tsl/profiler/protobuf:xplane_proto_cc",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/time",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 cc_library(
     name = "profiler_session",
     hdrs = ["profiler_session.h"],
@@ -235,6 +250,7 @@ cc_library(
         "@xla//xla/tsl/profiler/utils:xplane_schema",
         "@xla//xla/tsl/profiler/utils:xplane_utils",
     ] + if_not_android([
+        ":continuous_profiler_orchestrator",
         ":profiler_collection",
         ":profiler_factory",
         ":profiler_interface",
@@ -389,7 +405,7 @@ cc_library(
         ":profiler_interface",
         "//tsl/profiler/protobuf:xplane_proto_cc",
         "@com_google_absl//absl/status",
-        "@xla//xla/tsl/platform:status",
+        "@com_google_absl//absl/status:statusor",
     ],
 )
 
@@ -402,3 +418,19 @@ cc_library(
         "@com_google_absl//absl/strings:string_view",
     ],
 )
+
+cc_library(
+    name = "continuous_profiler_orchestrator",
+    hdrs = ["continuous_profiler_orchestrator.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":profiler_interface",
+        "//tsl/profiler/protobuf:xplane_proto_cc",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/time",
+        "@com_google_absl//absl/types:any",
+        "@xla//xla/tsl/platform:env",
+    ],
+)
diff --git a/tsl/profiler/lib/continuous_profiler_orchestrator.h 
b/tsl/profiler/lib/continuous_profiler_orchestrator.h
new file mode 100644
index 000000000..269564298
--- /dev/null
+++ b/tsl/profiler/lib/continuous_profiler_orchestrator.h
@@ -0,0 +1,199 @@
+/* Copyright 2026 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_TSL_PROFILER_LIB_CONTINUOUS_PROFILER_ORCHESTRATOR_H_
+#define TENSORFLOW_TSL_PROFILER_LIB_CONTINUOUS_PROFILER_ORCHESTRATOR_H_
+
+#include <algorithm>
+#include <any>
+#include <atomic>
+#include <cstddef>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/status/status.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "absl/types/any.h"
+#include "xla/tsl/platform/env.h"
+#include "tsl/profiler/lib/profiler_interface.h"
+#include "tsl/profiler/protobuf/xplane.pb.h"
+
+namespace tsl {
+namespace profiler {
+
+// Wraps a profiler and, while it is running, drains it from a background
+// thread by calling Consume() at an adaptive polling interval. Drained chunks
+// are held in a bounded in-memory buffer (oldest chunk dropped first) until
+// they are serialized by CollectData()/Serialize() or taken via PopBuffer().
+//
+// ProfilerType must provide the ProfilerInterface methods used below (Start,
+// Stop, CollectData, Consume, Serialize).
+template <typename ProfilerType>
+class ContinuousProfilerOrchestrator : public ProfilerInterface {
+ public:
+  explicit ContinuousProfilerOrchestrator(
+      std::unique_ptr<ProfilerType> profiler)
+      : profiler_(std::move(profiler)),
+        is_running_(false),
+        polling_interval_(absl::Seconds(1)) {}
+
+  // Stops the ingestion thread if it is still running. Does not call Stop() on
+  // the wrapped profiler.
+  ~ContinuousProfilerOrchestrator() override { StopIngestionThread(); }
+
+  // Starts the wrapped profiler and, if that succeeds, spawns the background
+  // ingestion thread. Returns the wrapped profiler's error otherwise.
+  absl::Status Start() override {
+    absl::Status status = profiler_->Start();
+    if (!status.ok()) return status;
+
+    {
+      absl::MutexLock lock(&mutex_);
+      is_running_ = true;
+    }
+    ingestion_thread_ =
+        std::unique_ptr<tsl::Thread>(tsl::Env::Default()->StartThread(
+            tsl::ThreadOptions{}, "ContinuousProfilerIngestion",
+            [this]() { IngestionLoop(); }));
+    return absl::OkStatus();
+  }
+
+  // Stops the background thread first, then the wrapped profiler.
+  absl::Status Stop() override {
+    StopIngestionThread();
+    return profiler_->Stop();
+  }
+
+  // Serializes every buffered chunk into `space`, then lets the wrapped
+  // profiler add whatever it still holds. Returns the first error seen.
+  absl::Status CollectData(tensorflow::profiler::XSpace* space) override {
+    absl::Status status = Serialize({}, space);
+    status.Update(profiler_->CollectData(space));
+    return status;
+  }
+
+  // Forwards a non-empty `data` to the wrapped profiler. An empty `data`
+  // instead drains the chunk buffer and serializes each chunk in buffer order.
+  absl::Status Serialize(std::any data,
+                         tensorflow::profiler::XSpace* space) override {
+    if (data.has_value()) {
+      return profiler_->Serialize(std::move(data), space);
+    }
+    auto chunks = PopBuffer();
+    absl::Status status;
+    for (auto& chunk : chunks) {
+      status.Update(profiler_->Serialize(std::move(chunk), space));
+    }
+    return status;
+  }
+
+  // Returns the current polling interval (primarily for testing).
+  absl::Duration polling_interval() const {
+    absl::MutexLock lock(&mutex_);
+    return polling_interval_;
+  }
+
+  // Removes and returns all buffered chunks, oldest first.
+  std::vector<std::any> PopBuffer() {
+    absl::MutexLock lock(&mutex_);
+    std::vector<std::any> chunks = std::move(circular_buffer_);
+    circular_buffer_.clear();
+    return chunks;
+  }
+
+  ProfilerType* profiler() { return profiler_.get(); }
+  const ProfilerType* profiler() const { return profiler_.get(); }
+
+ private:
+  // Maximum number of buffered chunks; the oldest is evicted beyond this.
+  static constexpr size_t kMaxBufferedChunks = 100;
+
+  // Body of the ingestion thread: drains the wrapped profiler into the chunk
+  // buffer, then sleeps for the polling interval, until is_running_ is cleared.
+  void IngestionLoop() {
+    // is_running_ is atomic, so it can be read here without holding mutex_.
+    // Checking it before Consume(), not after, ensures that a chunk already
+    // drained from the profiler is buffered rather than discarded when a stop
+    // request races with the call.
+    while (is_running_) {
+      auto result = profiler_->Consume();
+
+      absl::MutexLock lock(&mutex_);
+      if (result.ok()) {
+        circular_buffer_.push_back(std::move(result->data));
+
+        // Cap circular buffer to prevent infinite memory growth.
+        if (circular_buffer_.size() > kMaxBufferedChunks) {
+          circular_buffer_.erase(circular_buffer_.begin());
+        }
+
+        AdjustIntervalLocked(result->estimated_size_bytes);
+      }
+
+      // Re-check under the lock so a stop requested during Consume() is not
+      // slept through.
+      if (!is_running_) break;
+
+      // Sleep for the polling interval; StopIngestionThread() signals cv_ to
+      // cut the wait short.
+      cv_.WaitWithTimeout(&mutex_, polling_interval_);
+    }
+  }
+
+  // Clears is_running_, wakes the ingestion thread, and destroys the thread
+  // object. No-op if the thread is not running.
+  void StopIngestionThread() {
+    {
+      absl::MutexLock lock(&mutex_);
+      if (!is_running_) return;
+      is_running_ = false;
+      cv_.SignalAll();
+    }
+    ingestion_thread_.reset();
+  }
+
+  // Halves the polling interval (floor 100ms) after a chunk above the high
+  // watermark and doubles it (ceiling 5s) after a chunk below the low
+  // watermark; sizes in between leave it unchanged.
+  void AdjustIntervalLocked(size_t chunk_size_bytes)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+    constexpr size_t kHighWatermark = 512 * 1024 * 1024;  // 512MB
+    constexpr size_t kLowWatermark = 5 * 1024 * 1024;      // 5MB
+
+    if (chunk_size_bytes > kHighWatermark) {
+      polling_interval_ =
+          std::max(polling_interval_ / 2, absl::Milliseconds(100));
+    } else if (chunk_size_bytes < kLowWatermark) {
+      polling_interval_ = std::min(polling_interval_ * 2, absl::Seconds(5));
+    }
+  }
+
+  std::unique_ptr<ProfilerType> profiler_;
+
+  mutable absl::Mutex mutex_;
+  absl::CondVar cv_;
+  std::unique_ptr<tsl::Thread> ingestion_thread_;
+  std::atomic<bool> is_running_;
+
+  absl::Duration polling_interval_ ABSL_GUARDED_BY(mutex_);
+  std::vector<std::any> circular_buffer_ ABSL_GUARDED_BY(mutex_);
+};
+
+}  // namespace profiler
+}  // namespace tsl
+
+#endif  // TENSORFLOW_TSL_PROFILER_LIB_CONTINUOUS_PROFILER_ORCHESTRATOR_H_
diff --git a/tsl/profiler/lib/continuous_profiler_orchestrator_test.cc b/tsl/profiler/lib/continuous_profiler_orchestrator_test.cc
new file mode 100644
index 000000000..bb3a6c5d9
--- /dev/null
+++ b/tsl/profiler/lib/continuous_profiler_orchestrator_test.cc
@@ -0,0 +1,157 @@
+/* Copyright 2026 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tsl/profiler/lib/continuous_profiler_orchestrator.h"
+
+#include <any>
+#include <atomic>
+#include <memory>
+#include <utility>
+
+#include "testing/base/public/gmock.h"
+#include "testing/base/public/gunit.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "tsl/profiler/lib/profiler_interface.h"
+#include "tsl/profiler/protobuf/xplane.pb.h"
+
+namespace tsl {
+namespace profiler {
+namespace {
+
+using ::testing::Invoke;
+using ::testing::Return;
+
+// gMock stand-in for the profiler wrapped by the orchestrator under test.
+class MockProfiler : public ProfilerInterface {
+ public:
+  MOCK_METHOD(absl::Status, Start, (), (override));
+  MOCK_METHOD(absl::Status, Stop, (), (override));
+  MOCK_METHOD(absl::Status, CollectData, (tensorflow::profiler::XSpace * space),
+              (override));
+  MOCK_METHOD(absl::StatusOr<ConsumeResult>, Consume, (), (override));
+  MOCK_METHOD(absl::Status, Serialize,
+              (std::any data, tensorflow::profiler::XSpace* space), (override));
+};
+
+// Custom Google Mock matcher to compare integers wrapped in std::any.
+MATCHER_P(AnyEqInt, expected_value, "") {
+  auto* val = std::any_cast<int>(&arg);
+  return val != nullptr && *val == expected_value;
+}
+
+// Chunks buffered by the background loop are serialized in consume order.
+TEST(ContinuousProfilerOrchestratorTest,
+     CircularBufferingAndCorrectSerialization) {
+  auto mock_profiler = std::make_unique<MockProfiler>();
+  MockProfiler* mock_ptr = mock_profiler.get();
+
+  EXPECT_CALL(*mock_ptr, Start()).WillOnce(Return(absl::OkStatus()));
+  EXPECT_CALL(*mock_ptr, Stop()).WillOnce(Return(absl::OkStatus()));
+
+  // Setup Consume mock to return sequential chunks with high sizes to shrink
+  // the interval
+  std::atomic<int> consume_count(0);
+  EXPECT_CALL(*mock_ptr, Consume())
+      .WillRepeatedly(
+          Invoke([&]() -> absl::StatusOr<ProfilerInterface::ConsumeResult> {
+            int count = ++consume_count;
+            return ProfilerInterface::ConsumeResult{
+                .data = std::any(count),
+                .estimated_size_bytes = 1000 * 1024 * 1024  // 1000MB (>512MB)
+            };
+          }));
+
+  ContinuousProfilerOrchestrator<MockProfiler> orchestrator(
+      std::move(mock_profiler));
+
+  // Start orchestrator (spawns background loop)
+  ASSERT_OK(orchestrator.Start());
+
+  // Wait until at least 4 chunks have been consumed. Each >512MB chunk halves
+  // the interval (1s -> 500ms -> 250ms -> 125ms -> 100ms floor), so this takes
+  // less than one second.
+  while (consume_count < 4) {
+    absl::SleepFor(absl::Milliseconds(50));
+  }
+
+  // Stop orchestrator (terminates background loop)
+  ASSERT_OK(orchestrator.Stop());
+
+  // Verify polling interval shrank from 1s due to high watermark
+  EXPECT_LE(orchestrator.polling_interval(), absl::Milliseconds(500));
+
+  // Verify that CollectData serializes the chunks in chronological order.
+  tensorflow::profiler::XSpace space;
+
+  // We expect Serialize to be called for each chunk in order: 1, 2, 3...
+  ::testing::InSequence seq;
+  for (int i = 1; i <= consume_count; ++i) {
+    EXPECT_CALL(*mock_ptr, Serialize(AnyEqInt(i), &space))
+        .WillOnce(Return(absl::OkStatus()));
+  }
+  // And finally mock CollectData is called to collect remainder.
+  EXPECT_CALL(*mock_ptr, CollectData(&space))
+      .WillOnce(Return(absl::OkStatus()));
+
+  // Exercise the orchestrator's own CollectData(): it drains the buffered
+  // chunks through Serialize() and then calls the wrapped profiler's
+  // CollectData(), which is exactly the sequence expected above.
+  EXPECT_OK(orchestrator.CollectData(&space));
+}
+
+// A chunk below the low watermark doubles the polling interval.
+TEST(ContinuousProfilerOrchestratorTest, DynamicIntervalLowWatermarkScaling) {
+  auto mock_profiler = std::make_unique<MockProfiler>();
+  MockProfiler* mock_ptr = mock_profiler.get();
+
+  EXPECT_CALL(*mock_ptr, Start()).WillOnce(Return(absl::OkStatus()));
+  EXPECT_CALL(*mock_ptr, Stop()).WillOnce(Return(absl::OkStatus()));
+
+  // Consume returns a very small chunk (1MB < 5MB low watermark)
+  EXPECT_CALL(*mock_ptr, Consume())
+      .WillRepeatedly(
+          Invoke([]() -> absl::StatusOr<ProfilerInterface::ConsumeResult> {
+            return ProfilerInterface::ConsumeResult{
+                .data = std::any(1),
+                .estimated_size_bytes = 1 * 1024 * 1024  // 1MB (<5MB)
+            };
+          }));
+
+  ContinuousProfilerOrchestrator<MockProfiler> orchestrator(
+      std::move(mock_profiler));
+  EXPECT_EQ(orchestrator.polling_interval(), absl::Seconds(1));  // initial
+
+  ASSERT_OK(orchestrator.Start());
+
+  // Poll (bounded) until the first consume has adjusted the interval rather
+  // than relying on a fixed sleep; the next consume is a full interval away.
+  const absl::Time deadline = absl::Now() + absl::Seconds(1);
+  while (orchestrator.polling_interval() == absl::Seconds(1) &&
+         absl::Now() < deadline) {
+    absl::SleepFor(absl::Milliseconds(10));
+  }
+
+  // Stop immediately
+  ASSERT_OK(orchestrator.Stop());
+
+  // Verify interval scaled up from 1s to 2s due to low watermark!
+  EXPECT_EQ(orchestrator.polling_interval(), absl::Seconds(2));
+}
+
+}  // namespace
+}  // namespace profiler
+}  // namespace tsl
diff --git a/tsl/profiler/lib/profiler_collection.cc b/tsl/profiler/lib/profiler_collection.cc
index f3ffec62b..851535ffc 100644
--- a/tsl/profiler/lib/profiler_collection.cc
+++ b/tsl/profiler/lib/profiler_collection.cc
@@ -14,11 +14,14 @@ limitations under the License.
==============================================================================*/ #include "tsl/profiler/lib/profiler_collection.h" +#include +#include #include #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h"
@@ -55,5 +58,56 @@ absl::Status ProfilerCollection::CollectData(
   return status;
 }
 
+// Consumes from every child profiler. The result's `data` holds a
+// std::vector<std::any> with exactly one entry per child, in profilers_ order;
+// a child whose Consume() fails contributes an empty std::any and its error is
+// not propagated. estimated_size_bytes sums the successful children.
+absl::StatusOr<ProfilerInterface::ConsumeResult>
+ProfilerCollection::Consume() {
+  std::vector<std::any> data_vector;
+  data_vector.reserve(profilers_.size());
+  size_t total_estimated_size_bytes = 0;
+
+  for (auto& profiler : profilers_) {
+    auto result = profiler->Consume();
+    if (result.ok()) {
+      data_vector.push_back(std::move(result->data));
+      total_estimated_size_bytes += result->estimated_size_bytes;
+    } else {
+      data_vector.push_back(std::any());
+    }
+  }
+
+  return ConsumeResult{std::any(std::move(data_vector)),
+                       total_estimated_size_bytes};
+}
+
+// Inverse of Consume(): `data` must hold a std::vector<std::any> with one entry
+// per child profiler. Each non-empty entry is handed to the child at the same
+// index. Returns InvalidArgument for any other payload type, Internal on a
+// size mismatch, otherwise the first child error (if any).
+absl::Status ProfilerCollection::Serialize(
+    std::any data, tensorflow::profiler::XSpace* space) {
+  auto* data_vector_ptr = std::any_cast<std::vector<std::any>>(&data);
+  if (data_vector_ptr == nullptr) {
+    return absl::InvalidArgumentError(
+        "Invalid data type for ProfilerCollection::Serialize");
+  }
+
+  if (data_vector_ptr->size() != profilers_.size()) {
+    return absl::InternalError(
+        "Data vector size mismatch in ProfilerCollection::Serialize");
+  }
+
+  absl::Status status;
+  for (size_t i = 0; i < profilers_.size(); ++i) {
+    if ((*data_vector_ptr)[i].has_value()) {
+      status.Update(
+          profilers_[i]->Serialize(std::move((*data_vector_ptr)[i]), space));
+    }
+  }
+  return status;
+}
+
 }  // namespace profiler
 }  // namespace tsl
diff --git a/tsl/profiler/lib/profiler_collection.h b/tsl/profiler/lib/profiler_collection.h
index e2b9fd3ef..f18ec237c 100644
--- a/tsl/profiler/lib/profiler_collection.h
+++ b/tsl/profiler/lib/profiler_collection.h
@@ -15,11 +15,12 @@ limitations under the License.
#ifndef TENSORFLOW_TSL_PROFILER_LIB_PROFILER_COLLECTION_H_ #define TENSORFLOW_TSL_PROFILER_LIB_PROFILER_COLLECTION_H_ +#include #include #include #include "absl/status/status.h" -#include "xla/tsl/platform/status.h" +#include "absl/status/statusor.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" @@ -39,6 +40,10 @@ class ProfilerCollection : public ProfilerInterface { absl::Status CollectData(tensorflow::profiler::XSpace* space) override; + absl::StatusOr Consume() override; + absl::Status Serialize(std::any data, + tensorflow::profiler::XSpace* space) override; + private: std::vector> profilers_; }; diff --git a/tsl/profiler/lib/profiler_interface.h b/tsl/profiler/lib/profiler_interface.h index 2b0b71242..d43c6add1 100644 --- a/tsl/profiler/lib/profiler_interface.h +++ b/tsl/profiler/lib/profiler_interface.h @@ -15,12 +15,21 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PROFILER_LIB_PROFILER_INTERFACE_H_ #define TENSORFLOW_TSL_PROFILER_LIB_PROFILER_INTERFACE_H_ -#include "xla/tsl/platform/status.h" +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { namespace profiler { +struct ConsumeResult { + std::any data; + size_t estimated_size_bytes = 0; +}; + // Interface for tensorflow profiler plugins. // // ProfileSession calls each of these methods at most once per instance, and @@ -41,6 +50,22 @@ class ProfilerInterface { // Saves collected profile data into XSpace. virtual absl::Status CollectData(tensorflow::profiler::XSpace* space) = 0; + + struct ConsumeResult { + std::any data; + size_t estimated_size_bytes = 0; + }; + + // Consumes collected profile data without stopping the profiler. + virtual absl::StatusOr Consume() { + return absl::UnimplementedError("Consume not implemented"); + } + + // Serializes consumed profile data into XSpace. 
+ virtual absl::Status Serialize(std::any data, + tensorflow::profiler::XSpace* space) { + return absl::UnimplementedError("Serialize not implemented"); + } }; } // namespace profiler diff --git a/tsl/profiler/lib/profiler_session.cc b/tsl/profiler/lib/profiler_session.cc index ce2333f87..024436ec2 100644 --- a/tsl/profiler/lib/profiler_session.cc +++ b/tsl/profiler/lib/profiler_session.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tsl/profiler/lib/profiler_session.h" +#include #include #include @@ -34,6 +35,7 @@ limitations under the License. #include "xla/tsl/profiler/convert/post_process_single_host_xplane.h" #include "xla/tsl/profiler/utils/time_utils.h" #include "tsl/platform/host_info.h" +#include "tsl/profiler/lib/continuous_profiler_orchestrator.h" #include "tsl/profiler/lib/profiler_collection.h" #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" @@ -135,8 +137,23 @@ ProfilerSession::ProfilerSession(const ProfileOptions& options) start_time_ns_ = profiler::GetCurrentTimeNanos(); DCHECK(profiler_lock_.Active()); - profilers_ = std::make_unique( - profiler::CreateProfilers(options_)); + std::unique_ptr collection = + std::make_unique( + profiler::CreateProfilers(options_)); + + bool enable_circular_buffer_tracing = false; + const auto& advanced_config = options_.advanced_configuration(); + if (auto it = advanced_config.find("tpu_circular_buffer_tracing"); + it != advanced_config.end()) { + enable_circular_buffer_tracing = it->second.bool_value(); + } + + if (enable_circular_buffer_tracing) { + profilers_ = std::make_unique >(std::move(collection)); + } else { + profilers_ = std::move(collection); + } absl::Status status = profilers_->Start(); if (options_.raise_error_on_start_failure()) {