Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ libtorch_profiler_sources = [
"torch/csrc/profiler/standalone/execution_trace_observer.cpp",
"torch/csrc/profiler/standalone/itt_observer.cpp",
"torch/csrc/profiler/standalone/nvtx_observer.cpp",
"torch/csrc/profiler/standalone/roctx_observer.cpp",
"torch/csrc/profiler/standalone/privateuse1_observer.cpp",
"torch/csrc/profiler/stubs/base.cpp",
"torch/csrc/profiler/orchestration/vulkan.cpp",
Expand Down Expand Up @@ -870,6 +871,7 @@ libtorch_python_cuda_core_sources = [
"torch/csrc/cuda/GreenContext.cpp",
"torch/csrc/cuda/shared/cudart.cpp",
"torch/csrc/cuda/shared/nvtx.cpp",
"torch/csrc/cuda/shared/roctx.cpp",
"torch/csrc/cuda/utils.cpp",
"torch/csrc/cuda/GdsFile.cpp",
]
Expand Down
4 changes: 4 additions & 0 deletions caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1438,6 +1438,10 @@ if(USE_ROCM)
USE_ROCM
__HIP_PLATFORM_AMD__
)
# torch_cpu contains profiler code (e.g. profiler_legacy.cpp) that calls roctx*;
# link against the roctx64 library found as ROCM_ROCTX_LIB in LoadHIP.cmake
if(ROCM_ROCTX_LIB)
target_link_libraries(torch_cpu PRIVATE ${ROCM_ROCTX_LIB})
endif()

if(NOT ROCM_SOURCE_DIR)
set(ROCM_SOURCE_DIR "$ENV{ROCM_SOURCE_DIR}")
Expand Down
6 changes: 5 additions & 1 deletion cmake/public/LoadHIP.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,12 @@ if(HIP_FOUND)
list(REMOVE_DUPLICATES ROCM_INCLUDE_DIRS)

if(UNIX)
# roctx is part of roctracer
# roctx is part of roctracer (header needed for profiler_legacy.cpp and kineto)
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
set(ROCTRACER_INCLUDE_DIR "${ROCM_PATH}/include/roctracer")
if(EXISTS "${ROCTRACER_INCLUDE_DIR}/roctx.h")
list(APPEND ROCM_INCLUDE_DIRS ${ROCTRACER_INCLUDE_DIR})
endif()

set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")

Expand Down
8 changes: 7 additions & 1 deletion docs/source/autograd.md
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,9 @@ Autograd includes a profiler that lets you inspect the cost of different
operators inside your model - both on the CPU and GPU. There are four modes
implemented at the moment - CPU-only using {class}`~torch.autograd.profiler.profile`,
nvprof based (registers both CPU and GPU activity) using
{class}`~torch.autograd.profiler.emit_nvtx`,
{class}`~torch.autograd.profiler.emit_nvtx`,
ROCm tools (rocprof, rocprofv3) using
{class}`~torch.autograd.profiler.emit_roctx`,
and vtune profiler based using
{class}`~torch.autograd.profiler.emit_itt`.

Expand Down Expand Up @@ -320,6 +322,10 @@ and vtune profiler based using
.. autoclass:: torch.autograd.profiler.emit_nvtx
```

```{eval-rst}
.. autoclass:: torch.autograd.profiler.emit_roctx
```

```{eval-rst}
.. autoclass:: torch.autograd.profiler.emit_itt

Expand Down
21 changes: 21 additions & 0 deletions docs/source/cuda.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,23 @@
nvtx.range
```

## ROCm Tools Extension (ROCTX)

ROCTX is the ROCm analogue of NVTX. Use it when profiling with ROCm tools
(e.g. rocprof, rocprofv3) so that markers and ranges appear in the trace.
Only available when PyTorch is built with ROCm (USE_ROCM).

```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:

roctx.mark
roctx.range_push
roctx.range_pop
roctx.range
```

## Jiterator (beta)

```{eval-rst}
Expand Down Expand Up @@ -308,6 +325,10 @@ See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example
.. py:module:: torch.cuda.nvtx
```

```{eval-rst}
.. py:module:: torch.cuda.roctx
```

```{eval-rst}
.. py:module:: torch.cuda.profiler
```
Expand Down
11 changes: 10 additions & 1 deletion test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
)
from torch.autograd.function import InplaceFunction, once_differentiable
from torch.autograd.graph import GradientEdge
from torch.autograd.profiler import emit_itt, emit_nvtx, profile, record_function
from torch.autograd.profiler import emit_itt, emit_nvtx, emit_roctx, profile, record_function
from torch.autograd.profiler_util import (
_format_time,
EventList,
Expand Down Expand Up @@ -12346,6 +12346,15 @@ def test_profiler_emit_nvtx(self, device):
with emit_nvtx():
a.add(1.0)

@onlyCUDA
@unittest.skipIf(not torch.version.hip, "emit_roctx requires ROCm build")
def test_profiler_emit_roctx(self, device):
    # Mirror of test_profiler_emit_nvtx: catch if emit_roctx breaks on construction (ROCm only).
    # Smoke test: only verifies that entering/exiting emit_roctx under the
    # CUDA profiler does not raise; the emitted ROCTX ranges themselves are
    # not inspected here.
    a = torch.tensor([1, 2, 3], dtype=torch.float32, device=device)
    with torch.cuda.profiler.profile():
        with emit_roctx():
            a.add(1.0)

@onlyCUDA
def test_rnn_backward_to_input_but_not_parameters(self, device):
# this checks whether it is possible to not require
Expand Down
14 changes: 14 additions & 0 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1607,6 +1607,20 @@ def test_nvtx(self):
range_handle = torch.cuda.nvtx.range_start("range_start")
torch.cuda.nvtx.range_end(range_handle)

@unittest.skipIf(
    not torch.cuda.is_available() or not torch.version.hip,
    "ROCTX tests require CUDA (HIP/ROCm) build",
)
def test_roctx(self):
    # Mirror of test_nvtx: ensure roctx symbols work (ROCm build only).
    # Smoke test: each marker call must simply not raise.
    torch.cuda.roctx.range_push("foo")
    torch.cuda.roctx.mark("bar")
    torch.cuda.roctx.range_pop()
    # range_start/range_end use an explicit handle rather than the push/pop stack.
    range_handle = torch.cuda.roctx.range_start("range_start")
    torch.cuda.roctx.range_end(range_handle)
    # Context-manager form wrapping an (empty) region.
    with torch.cuda.roctx.range("context_region"):
        pass

def test_bincount_ext(self):
# ensure CUDA code coverage
input_size = (100000,)
Expand Down
96 changes: 96 additions & 0 deletions test/test_roctx_standalone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
"""
Standalone ROCTX smoke test. Run with:
python test/test_roctx_standalone.py

On a ROCm build this exercises torch.cuda.roctx and emit_roctx.
On a non-ROCm build, ROCTX API is skipped (or stub raises); NVTX test still runs if CUDA.
"""
import sys

import torch


def test_roctx_api():
    """Exercise the manual torch.cuda.roctx marker API (ROCm build only).

    Returns True on pass or skip, False on failure.
    """
    # Skip paths: no CUDA device, or a non-HIP build where roctx is stub-only.
    if not torch.cuda.is_available():
        print("SKIP: CUDA not available")
        return True
    if not getattr(torch.version, "hip", None):
        print("SKIP: Not a ROCm build (torch.version.hip missing); ROCTX is stub-only")
        return True
    try:
        roctx = torch.cuda.roctx
        # Stack-style range plus an instantaneous marker.
        roctx.range_push("roctx_foo")
        roctx.mark("roctx_bar")
        roctx.range_pop()
        # Handle-style range.
        handle = roctx.range_start("roctx_range_start")
        roctx.range_end(handle)
        # Context-manager range around a tiny GPU allocation.
        with roctx.range("roctx_context"):
            _ = torch.tensor([1.0], device="cuda")
    except Exception as e:
        print(f"FAIL: torch.cuda.roctx: {e}")
        return False
    print("PASS: torch.cuda.roctx API")
    return True


def test_emit_roctx():
    """Exercise the emit_roctx autograd context manager (ROCm build only).

    Returns True on pass or skip, False on failure.
    """
    if not torch.cuda.is_available():
        print("SKIP: CUDA not available")
        return True
    if not getattr(torch.version, "hip", None):
        print("SKIP: Not a ROCm build; emit_roctx not exercised")
        return True
    try:
        # Imported lazily so a build without emit_roctx only fails here,
        # inside the guarded/reported path.
        from torch.autograd.profiler import emit_roctx

        t = torch.tensor([1.0, 2.0, 3.0], device="cuda")
        with torch.cuda.profiler.profile(), emit_roctx():
            t.add_(1.0)
    except Exception as e:
        print(f"FAIL: emit_roctx: {e}")
        return False
    print("PASS: emit_roctx")
    return True


def test_nvtx_api():
    """Exercise the NVTX API (CUDA build) for comparison with ROCTX.

    Always returns True: an exception is treated as a skip, since NVTX may be
    absent on a ROCm-only build.
    """
    if not torch.cuda.is_available():
        print("SKIP: CUDA not available")
        return True
    try:
        nvtx = torch.cuda.nvtx
        nvtx.range_push("nvtx_foo")
        nvtx.mark("nvtx_bar")
        nvtx.range_pop()
        handle = nvtx.range_start("nvtx_range_start")
        nvtx.range_end(handle)
    except Exception as e:
        print(f"SKIP/FAIL: torch.cuda.nvtx: {e}")
        return True  # skip is ok on ROCm-only build
    print("PASS: torch.cuda.nvtx API")
    return True


def main():
    """Run the NVTX/ROCTX smoke tests and exit 0 iff none FAILed.

    Skipped tests count as passing; each test prints its own PASS/SKIP/FAIL
    line, and all tests run regardless of earlier failures.
    """
    # Fix: the original inserted the repo root into sys.path here ("allow
    # importing torch from repo root"), but that was dead code — `import torch`
    # at module scope has already bound torch before main() runs, so the
    # insertion could not affect which torch is tested. Removed.
    print("ROCTX standalone smoke test")
    print(f" PyTorch: {torch.__version__}")
    print(f" HIP: {getattr(torch.version, 'hip', 'N/A')}")
    print(f" CUDA available: {torch.cuda.is_available()}")
    ok = True
    # `&=` (not `and`) so every test runs and reports even after a failure.
    ok &= test_nvtx_api()
    ok &= test_roctx_api()
    ok &= test_emit_roctx()
    print("Done.")
    sys.exit(0 if ok else 1)


if __name__ == "__main__":
    main()
9 changes: 9 additions & 0 deletions torch/_C/_roctx.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# mypy: allow-untyped-defs
# Defined in torch/csrc/cuda/shared/roctx.cpp (ROCm builds only)
#
# Low-level bindings exposed on torch._C. Names mirror the roctracer C API
# (camelCase with a trailing "A" for ASCII-string variants); the snake_case
# user-facing wrappers live in torch.cuda.roctx.
def rangePushA(message: str) -> int: ...  # push a nested range onto the per-thread stack
def rangePop() -> int: ...  # pop the innermost pushed range
def rangeStartA(message: str) -> int: ...  # start a range; returns a handle for rangeEnd
def rangeEnd(range_id: int) -> None: ...  # end the range identified by `range_id`
def markA(message: str) -> None: ...  # emit an instantaneous marker
# Device-side range helpers; `stream` is presumably a HIP stream handle
# (0 = default stream) — TODO confirm against roctx.cpp.
def deviceRangeStart(message: str, stream: int = 0) -> object: ...
def deviceRangeEnd(range_handle: object, stream: int = 0) -> None: ...
6 changes: 6 additions & 0 deletions torch/_dynamo/trace_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2679,6 +2679,12 @@
"torch.cuda.nvtx.range_push",
"torch.cuda.nvtx.range_start",
"torch.cuda.nvtx.range",
"torch.cuda.roctx.mark",
"torch.cuda.roctx.range_end",
"torch.cuda.roctx.range_pop",
"torch.cuda.roctx.range_push",
"torch.cuda.roctx.range_start",
"torch.cuda.roctx.range",
"torch.cuda.power_draw",
"torch.cuda.profiler.init",
"torch.cuda.profiler.profile",
Expand Down
62 changes: 62 additions & 0 deletions torch/autograd/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"record_function",
"emit_itt",
"emit_nvtx",
"emit_roctx",
"load_nvprof",
"EnforceUnique",
"parse_nvprof_trace",
Expand Down Expand Up @@ -1112,6 +1113,67 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return False


class emit_roctx:
    """Context manager that makes every autograd operation emit a ROCTX range.

    This is the ROCm analogue of :class:`emit_nvtx`. Use it when profiling
    with ROCm tools (e.g. rocprof, rocprofv3) so that every autograd op is
    wrapped in a ROCTX range visible in the trace.

    Only available when PyTorch is built with ROCm (USE_ROCM). Otherwise
    enabling this context manager will raise an error.

    Args:
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
            Default: ``True``.
        record_shapes (bool, optional): If ``record_shapes=True``, the roctx range wrapping
            each autograd op will append information about the sizes of Tensor arguments.
            Default: ``False``.

    Example:
        >>> # On a ROCm build:
        >>> with torch.cuda.profiler.profile():
        ...     model(x)  # Warmup
        ...     with torch.autograd.profiler.emit_roctx():
        ...         model(x)
    """

    def __init__(self, enabled=True, record_shapes=False):
        self.enabled = enabled
        # Guards against re-entrant use; never reset, so each instance is
        # single-use — consistent with emit_nvtx.
        self.entered = False
        self.record_shapes = record_shapes

    def __enter__(self):
        # No-op mode: return None (not self), so `with ... as p` yields None
        # when disabled — same behavior as the enabled=False path of emit_nvtx.
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("ROCTX annotation context manager is not reentrant")
        self.entered = True
        # Drain in-flight GPU work so emitted ranges correspond to ops issued
        # inside this context rather than earlier asynchronous work.
        torch.cuda.synchronize()
        _run_on_profiler_start()
        _enable_profiler(
            ProfilerConfig(
                ProfilerState.ROCTX,
                self.record_shapes,
                # Remaining positional flags disabled, mirroring emit_nvtx
                # (presumably profile_memory / with_stack / with_flops /
                # with_modules — TODO confirm against ProfilerConfig).
                False,
                False,
                False,
                False,
                _ExperimentalConfig(),
            ),
            set(),  # no extra activities
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        # Synchronize before disabling so the final ops are captured in full.
        torch.cuda.synchronize()
        _disable_profiler()
        _run_on_profiler_stop()
        return False  # never suppress exceptions from the with-body


def load_nvprof(path):
"""Open an nvprof trace file and parses autograd annotations.

Expand Down
Loading