[ROCm] Add AMD GPU support via the Triton backend#17
Open
jeffdaily wants to merge 3 commits into
Open
Conversation
On ROCm, redirect warp.py to use the Triton scalar implementation
(scalar.py) instead of the C++ warp kernel. The Triton backend uses
tl.associative_scan which is wave-size agnostic and works correctly
on both wave64 (gfx90a/CDNA) and wave32 (gfx1100/RDNA) architectures.
The C++ warp kernel in warp.cuh has fundamental wave-size assumptions:
- kNThreadsPerWarp=32 hardcoded as the logical warp size
- Shuffle operations with implicit width=warpSize (64 on CDNA)
- Block configurations that assume 32-thread hardware warps
While partial fixes were attempted (64-bit shuffle masks, explicit
width parameters, using kNThreadsPerWarp instead of warpSize in
iteration bounds), the kernel still crashes on gfx90a due to the
mismatch between 32-thread "logical warps" and 64-lane physical
wavefronts. A complete fix would require rearchitecting the kernel's
warp-level scan to be wave-size parametric.
The Triton backend provides equivalent functionality and handles
arbitrary sequence lengths (not just powers of 2), making it the
better choice for ROCm until a wave-size-generic C++ kernel is
developed. This port was authored with Claude (AI assistant).
Test Plan:
cd projects/accelerated-scan/src
export HIP_VISIBLE_DEVICES=0
pytest tests/test_eq.py -v # 399/400 pass, 1 marginal tolerance
pytest tests/tests_eq_complex.py -v # 240/240 pass
The ROCm path routes warp.py to the wave-size-agnostic Triton backend
(the C++ warp kernel has wave64 hazards). Two gaps remained on that path:
scan_forward(reverse=True) was silently dropped, and warpscan_forward /
warpscan_backward were not exported, so `from accelerated_scan.warp import
warpscan_forward` and the warp benchmark provider failed to import.
Reverse is now implemented via the time-reversal identity. For the
recurrence x_t = a_t x_{t-1} + b_t, a reverse scan is the flip of a forward
scan on the flipped inputs: flip(forward(flip(a), flip(b))). The flips are
along the time axis (last dim of the (B, C, T) layout) and are made
contiguous because the Triton implementation requires contiguous inputs.
warpscan_forward is provided as a functional wrapper that writes the result
into the caller's out tensor in place (out.copy_(...)), honoring reverse, so
the direct binding and the benchmark work unchanged. warpscan_backward
raises NotImplementedError with a message pointing at the Triton autograd
path: the ROCm backward goes through Scan.backward, not this low-level
binding. Its argument names mirror the C++ binding so the import signature
matches.
Tests: test_eq_reverse parametrizes over the existing seqlens and dtypes and
checks scan_forward(reverse=True) against the reference reverse scan with the
same tolerances as the forward test (this is the correctness check for the
flip identity). test_warpscan_bindings asserts both symbols import, that
warpscan_forward fills out to match the reference forward, and that
warpscan_backward raises NotImplementedError.
Authored with the assistance of Claude (Anthropic).
Test Plan:
Built and validated on gfx90a (AMD Instinct MI250X, ROCm 7.2, wave64).
```
pip install -e .
HIP_VISIBLE_DEVICES=0 pytest tests/test_eq.py tests/tests_eq_complex.py -q
# 832 passed, 1 failed
# the single failure is the pre-existing test_eq_ref_reverse[65536-1]
# tolerance flake in the reference reverse scan (2.33e-05 vs 2e-05 allowed),
# unrelated to this change
HIP_VISIBLE_DEVICES=0 pytest tests/test_eq.py -q -k "test_eq_reverse or test_warpscan_bindings"
# 193 passed (192 reverse cases across seeds/seqlens/dtypes + 1 bindings test)
```
Authored with the assistance of Claude (Anthropic).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This adds AMD GPU support (ROCm/HIP) to accelerated-scan. The custom C++ warp-scan kernel hardcodes a 32-lane warp (shuffle widths, block geometry,
kNThreadsPerWarp), which does not hold on a 64-lane CDNA wavefront and crashed on gfx90a even after partial fixes. Rather than rewrite that kernel for both wave widths, on ROCmwarp.pyroutes to the existing Tritonscalarbackend, whosetl.associative_scanis wave-size agnostic and handles wave32 and wave64. The NVIDIA path (the C++ kernel) is unchanged -- routing is keyed ontorch.version.hip.What changed:
accelerated_scan/warp.py: on ROCm,scan/Scanuse the Triton scalar implementation.scan_forward(reverse=True)is implemented via the identityreverse(a, b) = flip(forward(flip(a), flip(b)))(exact for the recurrencex_t = a_t x_{t-1} + b_t), so reverse scans work on ROCm too.warpscan_forward/warpscan_backwardare exported on ROCm sofrom accelerated_scan.warp import ...andbench.py --provider=warpkeep working (warpscan_forwardis a functional Triton wrapper writing intoout;warpscan_backwardraises a clearNotImplementedError, since the ROCm autograd uses the TritonScan.backward).accelerated_scan/warp.cuh: HIP needs 64-bit shuffle masks, guarded by__HIP_PLATFORM_AMD__so the CUDA path is unchanged.tests/test_eq.py: addstest_eq_reverse(the reverse path now matchesscan_ref(reverse=True)-- previously untested on the warp/triton backends on any platform) andtest_warpscan_bindings.Test Plan:
Validated on AMD Instinct MI250X (gfx90a, CDNA2 wave64), Radeon Pro W7800 (gfx1100, RDNA3 wave32), and Radeon RX 9070 XT (gfx1201, RDNA4 wave32).
test_eq_ref_reverse[65536-1]reference tolerance flake (2.33e-5 vs 2e-5), unrelated to this change.test_eq_reverse(192 cases: seeds x seqlens x {float32, bfloat16, float16}) confirms the flip-identity reverse matches the reference exactly.Authored with the assistance of Claude (Anthropic).