diff --git a/build_variables.bzl b/build_variables.bzl index 3579828f2cdd3..b3aa800a7ec9d 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -111,6 +111,7 @@ libtorch_profiler_sources = [ "torch/csrc/profiler/standalone/execution_trace_observer.cpp", "torch/csrc/profiler/standalone/itt_observer.cpp", "torch/csrc/profiler/standalone/nvtx_observer.cpp", + "torch/csrc/profiler/standalone/roctx_observer.cpp", "torch/csrc/profiler/standalone/privateuse1_observer.cpp", "torch/csrc/profiler/stubs/base.cpp", "torch/csrc/profiler/orchestration/vulkan.cpp", @@ -870,6 +871,7 @@ libtorch_python_cuda_core_sources = [ "torch/csrc/cuda/GreenContext.cpp", "torch/csrc/cuda/shared/cudart.cpp", "torch/csrc/cuda/shared/nvtx.cpp", + "torch/csrc/cuda/shared/roctx.cpp", "torch/csrc/cuda/utils.cpp", "torch/csrc/cuda/GdsFile.cpp", ] diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 0a419e46a5bce..71688fd798be0 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1438,6 +1438,10 @@ if(USE_ROCM) USE_ROCM __HIP_PLATFORM_AMD__ ) + # torch_cpu contains profiler (e.g. 
profiler_legacy.cpp) that calls roctx*; link libroctx + if(ROCM_ROCTX_LIB) + target_link_libraries(torch_cpu PRIVATE ${ROCM_ROCTX_LIB}) + endif() if(NOT ROCM_SOURCE_DIR) set(ROCM_SOURCE_DIR "$ENV{ROCM_SOURCE_DIR}") diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 3ff7b3d2c1b36..66d64ebdb927f 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -206,8 +206,12 @@ if(HIP_FOUND) list(REMOVE_DUPLICATES ROCM_INCLUDE_DIRS) if(UNIX) - # roctx is part of roctracer + # roctx is part of roctracer (header needed for profiler_legacy.cpp and kineto) find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib) + set(ROCTRACER_INCLUDE_DIR "${ROCM_PATH}/include/roctracer") + if(EXISTS "${ROCTRACER_INCLUDE_DIR}/roctx.h") + list(APPEND ROCM_INCLUDE_DIRS ${ROCTRACER_INCLUDE_DIR}) + endif() set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}") diff --git a/docs/source/autograd.md b/docs/source/autograd.md index e78b77e4eb45c..6e7f9db42d46e 100644 --- a/docs/source/autograd.md +++ b/docs/source/autograd.md @@ -289,7 +289,9 @@ Autograd includes a profiler that lets you inspect the cost of different operators inside your model - both on the CPU and GPU. There are four modes implemented at the moment - CPU-only using {class}`~torch.autograd.profiler.profile`. nvprof based (registers both CPU and GPU activity) using -{class}`~torch.autograd.profiler.emit_nvtx`. +{class}`~torch.autograd.profiler.emit_nvtx`, +ROCm tools (rocprof, rocprofv3) using +{class}`~torch.autograd.profiler.emit_roctx`, and vtune profiler based using {class}`~torch.autograd.profiler.emit_itt`. @@ -320,6 +322,10 @@ and vtune profiler based using .. autoclass:: torch.autograd.profiler.emit_nvtx ``` +```{eval-rst} +.. autoclass:: torch.autograd.profiler.emit_roctx +``` + ```{eval-rst} .. 
autoclass:: torch.autograd.profiler.emit_itt diff --git a/docs/source/cuda.md b/docs/source/cuda.md index 94894942b74e5..abb7b0ca62df5 100644 --- a/docs/source/cuda.md +++ b/docs/source/cuda.md @@ -189,6 +189,23 @@ nvtx.range ``` +## ROCm Tools Extension (ROCTX) + +ROCTX is the ROCm analogue of NVTX. Use it when profiling with ROCm tools +(e.g. rocprof, rocprofv3) so that markers and ranges appear in the trace. +Only available when PyTorch is built with ROCm (USE_ROCM). + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + roctx.mark + roctx.range_push + roctx.range_pop + roctx.range +``` + ## Jiterator (beta) ```{eval-rst} @@ -308,6 +325,10 @@ See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example .. py:module:: torch.cuda.nvtx ``` +```{eval-rst} +.. py:module:: torch.cuda.roctx +``` + ```{eval-rst} .. py:module:: torch.cuda.profiler ``` diff --git a/test/test_autograd.py b/test/test_autograd.py index eb2ad5cb91cb5..6b9ff53675712 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -40,7 +40,7 @@ ) from torch.autograd.function import InplaceFunction, once_differentiable from torch.autograd.graph import GradientEdge -from torch.autograd.profiler import emit_itt, emit_nvtx, profile, record_function +from torch.autograd.profiler import emit_itt, emit_nvtx, emit_roctx, profile, record_function from torch.autograd.profiler_util import ( _format_time, EventList, @@ -12346,6 +12346,15 @@ def test_profiler_emit_nvtx(self, device): with emit_nvtx(): a.add(1.0) + @onlyCUDA + @unittest.skipIf(not torch.version.hip, "emit_roctx requires ROCm build") + def test_profiler_emit_roctx(self, device): + # Mirror of test_profiler_emit_nvtx: catch if emit_roctx breaks on construction (ROCm only). 
#!/usr/bin/env python3
"""Standalone ROCTX smoke test.

Run as a script::

    python test/test_roctx_standalone.py

or let pytest collect it.

On a ROCm build this exercises ``torch.cuda.roctx`` and ``emit_roctx``.
On any other build those two checks are skipped; the NVTX check still runs
whenever CUDA is available.

The ``test_*`` functions follow the normal test contract: they return ``None``
on success, raise ``unittest.SkipTest`` when the environment does not apply,
and let any real failure propagate. Script mode (``main``) converts those
outcomes into PASS/SKIP/FAIL lines and an exit status.
"""
import sys
import unittest

import torch


def _require_cuda():
    """Skip the calling check when ``torch.cuda.is_available()`` is false."""
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA not available")


def test_roctx_api():
    """Exercise the manual ROCTX marker/range API (ROCm build only)."""
    _require_cuda()
    if not getattr(torch.version, "hip", None):
        raise unittest.SkipTest(
            "Not a ROCm build (torch.version.hip missing); ROCTX is stub-only"
        )
    torch.cuda.roctx.range_push("roctx_foo")
    torch.cuda.roctx.mark("roctx_bar")
    torch.cuda.roctx.range_pop()
    rid = torch.cuda.roctx.range_start("roctx_range_start")
    torch.cuda.roctx.range_end(rid)
    with torch.cuda.roctx.range("roctx_context"):
        _ = torch.tensor([1.0], device="cuda")


def test_emit_roctx():
    """Exercise the ``emit_roctx`` context manager (ROCm build only)."""
    _require_cuda()
    if not getattr(torch.version, "hip", None):
        raise unittest.SkipTest("Not a ROCm build; emit_roctx not exercised")
    from torch.autograd.profiler import emit_roctx

    a = torch.tensor([1.0, 2.0, 3.0], device="cuda")
    with torch.cuda.profiler.profile():
        with emit_roctx():
            a.add_(1.0)


def test_nvtx_api():
    """Exercise the NVTX API for comparison; a RuntimeError counts as a skip."""
    _require_cuda()
    try:
        torch.cuda.nvtx.range_push("nvtx_foo")
        torch.cuda.nvtx.mark("nvtx_bar")
        torch.cuda.nvtx.range_pop()
        rid = torch.cuda.nvtx.range_start("nvtx_range_start")
        torch.cuda.nvtx.range_end(rid)
    except RuntimeError as exc:
        # The original script tolerated any NVTX failure ("skip is ok on
        # ROCm-only build"). Keep that tolerance, but only for RuntimeError --
        # presumably what a stubbed-out NVTX binding raises; TODO confirm.
        raise unittest.SkipTest(f"torch.cuda.nvtx: {exc}") from exc


# (check, label printed on PASS, label printed on FAIL), in execution order.
_CHECKS = (
    (test_nvtx_api, "torch.cuda.nvtx API", "torch.cuda.nvtx"),
    (test_roctx_api, "torch.cuda.roctx API", "torch.cuda.roctx"),
    (test_emit_roctx, "emit_roctx", "emit_roctx"),
)


def _run_check(check, pass_label, fail_label):
    """Run one check in script mode; return False only on a real failure."""
    try:
        check()
    except unittest.SkipTest as skipped:
        print(f"SKIP: {skipped}")
        return True
    except Exception as exc:  # script boundary: report it and fail the run
        print(f"FAIL: {fail_label}: {exc}")
        return False
    print(f"PASS: {pass_label}")
    return True


def main():
    """Run every check, print a summary, and exit 0 on success or 1 on failure.

    ``torch`` is imported at module import time, so it must already be
    importable; editing ``sys.path`` here (as the script used to) could not
    influence which ``torch`` gets loaded.
    """
    print("ROCTX standalone smoke test")
    print(f" PyTorch: {torch.__version__}")
    print(f" HIP: {getattr(torch.version, 'hip', 'N/A')}")
    print(f" CUDA available: {torch.cuda.is_available()}")
    # Build the full list first so every check runs even after a failure.
    outcomes = [_run_check(*entry) for entry in _CHECKS]
    print("Done.")
    sys.exit(0 if all(outcomes) else 1)


if __name__ == "__main__":
    main()
class emit_roctx:
    """Context manager that makes every autograd operation emit a ROCTX range.

    This is the ROCm analogue of :class:`emit_nvtx`. Use it when profiling
    with ROCm tools (e.g. rocprof, rocprofv3) so that every autograd op is
    wrapped in a ROCTX range visible in the trace.

    Only available when PyTorch is built with ROCm (USE_ROCM). Otherwise
    enabling this context manager will raise an error.

    An instance may not be nested inside itself, but it can be used for
    several ``with`` blocks one after another.

    Args:
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
            Default: ``True``.
        record_shapes (bool, optional): If ``record_shapes=True``, the roctx range wrapping
            each autograd op will append information about the sizes of Tensor arguments.
            Default: ``False``.

    Raises:
        RuntimeError: if the same instance is entered while it is already active.

    Example:
        >>> # On a ROCm build:
        >>> with torch.cuda.profiler.profile():
        ...     model(x)  # Warmup
        ...     with torch.autograd.profiler.emit_roctx():
        ...         model(x)
    """

    def __init__(self, enabled=True, record_shapes=False):
        self.enabled = enabled
        self.entered = False
        self.record_shapes = record_shapes

    def __enter__(self):
        """Enable the ROCTX profiler; returns ``self`` (``None`` when disabled)."""
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("ROCTX annotation context manager is not reentrant")
        self.entered = True
        try:
            torch.cuda.synchronize()
            _run_on_profiler_start()
            _enable_profiler(
                ProfilerConfig(
                    ProfilerState.ROCTX,
                    self.record_shapes,
                    False,
                    False,
                    False,
                    False,
                    _ExperimentalConfig(),
                ),
                set(),
            )
        except BaseException:
            # Setup did not complete: drop the "active" mark so a later attempt
            # is not rejected as a re-entrant use.
            self.entered = False
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Disable the ROCTX profiler; never suppresses the body's exception."""
        if not self.enabled:
            return
        try:
            torch.cuda.synchronize()
        finally:
            # Runs even when the synchronize above raises, so the profiler is
            # not left enabled and the instance can be entered again.
            self.entered = False
            _disable_profiler()
            _run_on_profiler_stop()
        return False
== ProfilerState::ROCTX || + state == ProfilerState::ITT || state == ProfilerState::PRIVATEUSE1; } struct OpArgData { @@ -642,6 +644,7 @@ void prepareProfiler( const torch::profiler::impl::ProfilerConfig& config, const std::set& activities) { if (config.state == ProfilerState::NVTX || + config.state == ProfilerState::ROCTX || config.state == ProfilerState::ITT) { return; } @@ -766,6 +769,9 @@ void enableProfilerWithEventPostProcess( TORCH_CHECK( config.state != ProfilerState::NVTX, "NVTX does not support post processing callback."); + TORCH_CHECK( + config.state != ProfilerState::ROCTX, + "ROCTX does not support post processing callback."); TORCH_CHECK( config.state != ProfilerState::ITT, "ITT does not support post processing callback."); @@ -794,6 +800,9 @@ void enableProfiler( case ProfilerState::NVTX: torch::profiler::impl::pushNVTXCallbacks(config, scopes); break; + case ProfilerState::ROCTX: + torch::profiler::impl::pushROCTXCallbacks(config, scopes); + break; case ProfilerState::ITT: torch::profiler::impl::pushITTCallbacks(config, scopes); break; @@ -875,10 +884,11 @@ std::unique_ptr disableProfiler() { return std::make_unique(); } - // Shared among NVTX, PRIVATEUSE1, KINETO, KINETO_GPU_FALLBACK, + // Shared among NVTX, ROCTX, PRIVATEUSE1, KINETO, KINETO_GPU_FALLBACK, // KINETO_PRIVATEUSE1_FALLBACK std::unique_ptr result; if (config.state == ProfilerState::NVTX || + config.state == ProfilerState::ROCTX || config.state == ProfilerState::PRIVATEUSE1) { result = std::make_unique(); } diff --git a/torch/csrc/autograd/profiler_legacy.cpp b/torch/csrc/autograd/profiler_legacy.cpp index 9a4816d5e212b..89751d7ea09b1 100644 --- a/torch/csrc/autograd/profiler_legacy.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -22,6 +22,10 @@ #include +#ifdef USE_ROCM +#include +#endif + namespace torch::autograd::profiler { // We decompose the profiler logic into the following components: @@ -195,6 +199,10 @@ void ProfilerLegacyThreadLocalState::mark(std::string name, bool 
include_cuda) { } if (config_.state == torch::profiler::impl::ProfilerState::NVTX) { torch::profiler::impl::cudaStubs()->mark(name.c_str()); + } else if (config_.state == torch::profiler::impl::ProfilerState::ROCTX) { +#ifdef USE_ROCM + roctxMarkA(name.c_str()); +#endif } else { LegacyEvent evt( EventKind::Mark, @@ -229,6 +237,12 @@ void ProfilerLegacyThreadLocalState::pushRange( torch::profiler::impl::cudaStubs()->rangePush( torch::profiler::impl::getNvtxStr(fn.name(), fn.seqNr(), shapes) .c_str()); + } else if (config_.state == torch::profiler::impl::ProfilerState::ROCTX) { +#ifdef USE_ROCM + roctxRangePushA( + torch::profiler::impl::getNvtxStr(fn.name(), fn.seqNr(), shapes) + .c_str()); +#endif } else { LegacyEvent evt( EventKind::PushRange, @@ -275,6 +289,10 @@ void ProfilerLegacyThreadLocalState::popRange( } if (config_.state == torch::profiler::impl::ProfilerState::NVTX) { torch::profiler::impl::cudaStubs()->rangePop(); + } else if (config_.state == torch::profiler::impl::ProfilerState::ROCTX) { +#ifdef USE_ROCM + roctxRangePop(); +#endif } else { // In some cases RecordFunction (and popRange) may be // called on a different thread than pushRange @@ -415,6 +433,11 @@ void enableProfilerLegacy( new_config.state != torch::profiler::impl::ProfilerState::NVTX || torch::profiler::impl::cudaStubs()->enabled(), "Can't use NVTX profiler - PyTorch was compiled without CUDA"); +#ifndef USE_ROCM + TORCH_CHECK( + new_config.state != torch::profiler::impl::ProfilerState::ROCTX, + "Can't use ROCTX profiler - PyTorch was not compiled with ROCm (USE_ROCM)"); +#endif TORCH_CHECK(new_config.state != torch::profiler::impl::ProfilerState::KINETO); @@ -451,7 +474,8 @@ thread_event_lists disableProfilerLegacy( cleanupTLSState ? 
state_ptr->removeCallback() : state_ptr->leakHandle(); if (!consolidate || - state_ptr->config().state == torch::profiler::impl::ProfilerState::NVTX) { + state_ptr->config().state == torch::profiler::impl::ProfilerState::NVTX || + state_ptr->config().state == torch::profiler::impl::ProfilerState::ROCTX) { return thread_event_lists(); } diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index ac982dc5d33d0..6b97c73fe33cf 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -2249,6 +2249,7 @@ namespace shared { void initCudartBindings(PyObject* module); void initNvtxBindings(PyObject* module); +void initRoctxBindings(PyObject* module); #if defined(USE_CUDNN) || defined(USE_ROCM) void initCudnnBindings(PyObject* module); #endif @@ -2264,6 +2265,9 @@ void initModule(PyObject* module) { // so this condition might not always be true... shared::initCudartBindings(module); shared::initNvtxBindings(module); +#if defined(USE_ROCM) + shared::initRoctxBindings(module); +#endif #if defined(USE_CUDNN) || defined(USE_ROCM) shared::initCudnnBindings(module); #endif diff --git a/torch/csrc/cuda/shared/roctx.cpp b/torch/csrc/cuda/shared/roctx.cpp new file mode 100644 index 0000000000000..807f08460919d --- /dev/null +++ b/torch/csrc/cuda/shared/roctx.cpp @@ -0,0 +1,63 @@ +#include + +#include + +#ifdef USE_ROCM +#include +#endif + +namespace torch::cuda::shared { + +#ifdef USE_ROCM + +void initRoctxBindings(PyObject* module) { + auto m = py::handle(module).cast(); + auto roctx = m.def_submodule("_roctx", "ROCTX bindings for ROCm profiling"); + + roctx.def( + "rangePushA", + [](const std::string& msg) { + return roctxRangePushA(msg.c_str()); + }, + py::arg("msg")); + roctx.def("rangePop", []() { return roctxRangePop(); }); + roctx.def( + "rangeStartA", + [](const std::string& msg) { + return static_cast(roctxRangeStartA(msg.c_str())); + }, + py::arg("msg")); + roctx.def( + "rangeEnd", + [](int64_t range_id) { + 
roctxRangeStop(static_cast(range_id)); + }, + py::arg("range_id")); + roctx.def("markA", [](const std::string& msg) { roctxMarkA(msg.c_str()); }, py::arg("msg")); + + // ROCTX has no stream-callback API; stub to match NVTX API surface + roctx.def( + "deviceRangeStart", + [](const std::string& /* msg */, std::intptr_t /* stream */) { + return py::none(); + }, + py::arg("msg"), + py::arg("stream") = 0); + roctx.def( + "deviceRangeEnd", + [](py::object /* handle */, std::intptr_t /* stream */) {}, + py::arg("range_handle"), + py::arg("stream") = 0); +} + +#else + +void initRoctxBindings(PyObject* module) { + (void)module; + // No-op when not ROCm: _roctx submodule is not registered, + // so torch.cuda.roctx will get ImportError and use a stub. +} + +#endif + +} // namespace torch::cuda::shared diff --git a/torch/csrc/profiler/orchestration/observer.h b/torch/csrc/profiler/orchestration/observer.h index 3b59466e6060f..c58971d8bf345 100644 --- a/torch/csrc/profiler/orchestration/observer.h +++ b/torch/csrc/profiler/orchestration/observer.h @@ -33,6 +33,7 @@ enum class C10_API_ENUM ProfilerState { CPU, // CPU-only profiling CUDA, // CPU + CUDA events NVTX, // only emit NVTX markers + ROCTX, // only emit ROCTX markers (ROCm) ITT, // only emit ITT markers PRIVATEUSE1, // only emit PRIVATEUSE1 markers KINETO, // use libkineto @@ -47,6 +48,7 @@ enum class C10_API_ENUM ActiveProfilerType { LEGACY, KINETO, NVTX, + ROCTX, ITT, PRIVATEUSE1 }; diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index d42af26c6b6a0..055d073eb4169 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -341,6 +341,7 @@ void initPythonBindings(PyObject* module) { .value("CPU", ProfilerState::CPU) .value("CUDA", ProfilerState::CUDA) .value("NVTX", ProfilerState::NVTX) + .value("ROCTX", ProfilerState::ROCTX) .value("ITT", ProfilerState::ITT) .value("PRIVATEUSE1", ProfilerState::PRIVATEUSE1) .value("KINETO", ProfilerState::KINETO) 
@@ -354,6 +355,7 @@ void initPythonBindings(PyObject* module) { .value("LEGACY", ActiveProfilerType::LEGACY) .value("KINETO", ActiveProfilerType::KINETO) .value("NVTX", ActiveProfilerType::NVTX) + .value("ROCTX", ActiveProfilerType::ROCTX) .value("ITT", ActiveProfilerType::ITT) .value("PRIVATEUSE1", ActiveProfilerType::PRIVATEUSE1); diff --git a/torch/csrc/profiler/standalone/roctx_observer.cpp b/torch/csrc/profiler/standalone/roctx_observer.cpp new file mode 100644 index 0000000000000..4c17610612c32 --- /dev/null +++ b/torch/csrc/profiler/standalone/roctx_observer.cpp @@ -0,0 +1,190 @@ +#include + +#include + +#ifdef USE_ROCM +#include +#endif + +namespace torch::profiler::impl { + +#ifdef USE_ROCM + +struct ROCTXThreadLocalState : ProfilerStateBase { + explicit ROCTXThreadLocalState(const ProfilerConfig& config) + : ProfilerStateBase(config) { + TORCH_CHECK(!config.profile_memory); + TORCH_CHECK(!config.with_stack); + TORCH_CHECK(!config.with_flops); + TORCH_CHECK(!config.with_modules); + } + ~ROCTXThreadLocalState() override = default; + + ActiveProfilerType profilerType() override { + return ActiveProfilerType::ROCTX; + } + + void reportMemoryUsage( + void* /*ptr*/, + int64_t /*alloc_size*/, + size_t /*total_allocated*/, + size_t /*total_reserved*/, + c10::Device /*device*/) override {} + + static ROCTXThreadLocalState* getTLS() { + auto tls = ProfilerStateBase::get(/*global=*/false); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + tls == nullptr || tls->profilerType() == ActiveProfilerType::ROCTX); + return static_cast(tls); + } + std::pair getOpIdFromInput( + const at::Tensor& tensor); + + void setProducerTensorMap( + at::TensorImpl* tensor, + at::RecordFunctionHandle op_id, + int output_nr) { + producer_tensor_map_[(void*)tensor] = + std::pair{op_id, output_nr}; + } + + protected: + std::unordered_map> + producer_tensor_map_; +}; + +std::pair ROCTXThreadLocalState::getOpIdFromInput( + const at::Tensor& tensor) { + std::pair producer_op_pair(0, -1); + if 
(tensor.defined()) { + at::TensorImpl* ten_addr = tensor.unsafeGetTensorImpl(); + if (producer_tensor_map_.count((void*)ten_addr) > 0) { + producer_op_pair = producer_tensor_map_[(void*)ten_addr]; + } + } + return producer_op_pair; +} + +static std::list> flattenOpIdListROCTX( + const c10::List& list) { + std::list> input_op_id_list; + auto state_ptr = ROCTXThreadLocalState::getTLS(); + TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); + for (const c10::IValue& input : list) { + if (input.isTensor()) { + const at::Tensor& tensor = input.toTensor(); + auto producer_op_pair = state_ptr->getOpIdFromInput(tensor); + input_op_id_list.push_back(producer_op_pair); + } + } + return input_op_id_list; +} + +static std::list> +getInputTensorOpIdsROCTX(const at::RecordFunction& fn) { + std::pair undefined_op_pair(0, -1); + std::list> input_producer_ops_; + auto state_ptr = ROCTXThreadLocalState::getTLS(); + TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); + for (const c10::IValue& input_item : fn.inputs()) { + if (input_item.isTensor()) { + const at::Tensor& tensor = input_item.toTensor(); + auto producer_pair = state_ptr->getOpIdFromInput(tensor); + input_producer_ops_.push_back(producer_pair); + } else { + if (input_item.isList()) { + std::list> tmp_op_ids = + flattenOpIdListROCTX(input_item.toList()); + if (!tmp_op_ids.empty()) { + input_producer_ops_.splice(input_producer_ops_.end(), tmp_op_ids); + } else { + input_producer_ops_.emplace_back(undefined_op_pair); + } + } else { + input_producer_ops_.emplace_back(undefined_op_pair); + } + } + } + return input_producer_ops_; +} + +static void updateOutputTensorTrackerROCTX(const at::RecordFunction& fn) { + int output_nr = 0; + auto state_ptr = ROCTXThreadLocalState::getTLS(); + TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); + for (const c10::IValue& s_tensor : fn.outputs()) { + if (s_tensor.isTensor()) { + const at::Tensor& tensor = s_tensor.toTensor(); + if (tensor.defined()) { 
+ auto ten_addr = tensor.unsafeGetTensorImpl(); + state_ptr->setProducerTensorMap(ten_addr, fn.handle(), output_nr); + } + } + output_nr++; + } +} + +template +static std::unique_ptr enterROCTX( + const at::RecordFunction& fn) { + if (ROCTXThreadLocalState::getTLS() != nullptr) { + auto input_op_ids = getInputTensorOpIdsROCTX(fn); + std::string name = torch::profiler::impl::getNvtxStr( + fn.name(), + fn.seqNr(), + report_input_shapes ? torch::profiler::impl::inputSizes(fn, true) + : std::vector>(), + fn.handle(), + report_input_shapes + ? input_op_ids + : std::list>()); + roctxRangePushA(name.c_str()); + } + return nullptr; +} + +void pushROCTXCallbacks( + const ProfilerConfig& config, + const std::unordered_set& scopes) { + // Marker visible in rocprof/rocprofv3 --marker-trace: confirms new (standalone) observer is active + roctxMarkA("PyTorch_ROCTX_observer_v2_active"); + c10::ThreadLocalDebugInfo::_push( + c10::DebugInfoKind::PROFILER_STATE, + std::make_shared(config)); + + auto state_ptr = ROCTXThreadLocalState::getTLS(); + TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); + + auto handle = at::addThreadLocalCallback( + at::RecordFunctionCallback( + state_ptr->config().report_input_shapes + ? &enterROCTX + : &enterROCTX, + [](const at::RecordFunction& fn, at::ObserverContext* ctx) { + (void)ctx; + roctxRangePop(); + updateOutputTensorTrackerROCTX(fn); + }) + .needsInputs(config.report_input_shapes) + .needsOutputs(config.report_input_shapes) + .needsIds(true) + .scopes(scopes)); + state_ptr->setCallbackHandle(handle); +} + +#else // !USE_ROCM + +void pushROCTXCallbacks( + const ProfilerConfig& config, + const std::unordered_set& scopes) { + (void)config; + (void)scopes; + TORCH_CHECK( + false, + "ROCTX profiler is only available in ROCm builds. 
# mypy: allow-untyped-defs
r"""AMD ROCTX (ROCm Tools Extension) helpers for annotating profiler traces.

The functions follow the layout of :mod:`torch.cuda.nvtx`, so code written
against that module can emit ROCTX markers visible to rocprof, rocprofv3, etc.
"""

from contextlib import contextmanager


try:
    from torch._C import _roctx
except ImportError:

    class _ROCTXStub:
        """Stand-in used when ``torch._C._roctx`` cannot be imported.

        Every binding name resolves to ``_fail``, so any call raises.
        """

        @staticmethod
        def _fail(*args, **kwargs):
            raise RuntimeError(
                "ROCTX functions not installed. Are you sure you have a ROCm build?"
            )

        rangePushA = rangePop = markA = _fail
        rangeStartA = rangeEnd = _fail
        deviceRangeStart = deviceRangeEnd = _fail

    _roctx = _ROCTXStub()  # type: ignore[assignment]


__all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"]


def range_push(msg):
    """
    Open a nested range labelled ``msg`` and return the backend's result.

    Args:
        msg (str): ASCII message to associate with the range.
    """
    pushed = _roctx.rangePushA(msg)
    return pushed


def range_pop():
    """Close the most recently pushed nested range and return the backend's result."""
    popped = _roctx.rangePop()
    return popped


def range_start(msg) -> int:
    """
    Start a range labelled ``msg`` and return a handle for :func:`range_end`.

    Unlike range_push/range_pop, a range opened here is closed through its
    handle rather than by stack order.

    Args:
        msg (str): ASCII message to associate with the range.

    Returns:
        The range handle produced by the backend.
    """
    handle = _roctx.rangeStartA(msg)
    return handle


def range_end(range_id) -> None:
    """
    Close the range identified by ``range_id``.

    Args:
        range_id (int): a handle previously returned by :func:`range_start`.
    """
    _roctx.rangeEnd(range_id)


def _device_range_start(msg: str, stream: int = 0) -> object:
    """
    Forward to the backend's ``deviceRangeStart`` and return its result.

    The ROCm binding is a placeholder that keeps the NVTX-shaped API: it
    ignores both arguments and returns ``None``.

    Args:
        msg (str): ASCII message to associate with the range.
        stream (int): HIP stream id.
    """
    handle = _roctx.deviceRangeStart(msg, stream)
    return handle


def _device_range_end(range_handle: object, stream: int = 0) -> None:
    """
    Forward to the backend's ``deviceRangeEnd``.

    The ROCm binding is a placeholder that does nothing.

    Args:
        range_handle: the value returned by :func:`_device_range_start`.
        stream (int): HIP stream id.
    """
    _roctx.deviceRangeEnd(range_handle, stream)


def mark(msg):
    """
    Record an instantaneous event labelled ``msg``.

    Args:
        msg (str): ASCII message to associate with the event.
    """
    marked = _roctx.markA(msg)
    return marked


@contextmanager
def range(msg, *args, **kwargs):
    """
    Context manager / decorator that pushes a ROCTX range on entry and pops it on exit.

    The label is always ``msg.format(*args, **kwargs)``, so literal braces in
    ``msg`` must be doubled even when no extra arguments are given.

    Args:
        msg (str): message (format template) to associate with the range
    """
    label = msg.format(*args, **kwargs)
    range_push(label)
    try:
        yield
    finally:
        range_pop()