Skip to content

[ROCm] Add AMD GPU support via the Triton backend#17

Open
jeffdaily wants to merge 3 commits into
proger:mainfrom
jeffdaily:moat-port
Open

[ROCm] Add AMD GPU support via the Triton backend#17
jeffdaily wants to merge 3 commits into
proger:mainfrom
jeffdaily:moat-port

Conversation

@jeffdaily

Copy link
Copy Markdown

This adds AMD GPU support (ROCm/HIP) to accelerated-scan. The custom C++ warp-scan kernel hardcodes a 32-lane warp (shuffle widths, block geometry, kNThreadsPerWarp), which does not hold on a 64-lane CDNA wavefront and crashed on gfx90a even after partial fixes. Rather than rewrite that kernel for both wave widths, on ROCm warp.py routes to the existing Triton scalar backend, whose tl.associative_scan is wave-size agnostic and handles wave32 and wave64. The NVIDIA path (the C++ kernel) is unchanged -- routing is keyed on torch.version.hip.

What changed:

  • accelerated_scan/warp.py: on ROCm, scan/Scan use the Triton scalar implementation. scan_forward(reverse=True) is implemented via the identity reverse(a, b) = flip(forward(flip(a), flip(b))) (exact for the recurrence x_t = a_t x_{t-1} + b_t), so reverse scans work on ROCm too. warpscan_forward/warpscan_backward are exported on ROCm so from accelerated_scan.warp import ... and bench.py --provider=warp keep working (warpscan_forward is a functional Triton wrapper writing into out; warpscan_backward raises a clear NotImplementedError, since the ROCm autograd uses the Triton Scan.backward).
  • accelerated_scan/warp.cuh: HIP needs 64-bit shuffle masks, guarded by __HIP_PLATFORM_AMD__ so the CUDA path is unchanged.
  • tests/test_eq.py: adds test_eq_reverse (the reverse path now matches scan_ref(reverse=True) -- previously untested on the warp/triton backends on any platform) and test_warpscan_bindings.

Test Plan:

Validated on AMD Instinct MI250X (gfx90a, CDNA2 wave64), Radeon Pro W7800 (gfx1100, RDNA3 wave32), and Radeon RX 9070 XT (gfx1201, RDNA4 wave32).

pip install -e .
pytest tests/test_eq.py tests/tests_eq_complex.py -q
  • 832/833 pass on gfx90a; the one failure is the pre-existing test_eq_ref_reverse[65536-1] reference tolerance flake (2.33e-5 vs 2e-5), unrelated to this change.
  • test_eq_reverse (192 cases: seeds x seqlens x {float32, bfloat16, float16}) confirms the flip-identity reverse matches the reference exactly.
  • Forward and gradient equality vs the reference hold across sequence lengths 32..131072 and all three dtypes.

Authored with the assistance of Claude (Anthropic).

jeffdaily added 3 commits June 5, 2026 07:54
On ROCm, redirect warp.py to use the Triton scalar implementation
(scalar.py) instead of the C++ warp kernel. The Triton backend uses
tl.associative_scan which is wave-size agnostic and works correctly
on both wave64 (gfx90a/CDNA) and wave32 (gfx1100/RDNA) architectures.

The C++ warp kernel in warp.cuh has fundamental wave-size assumptions:
- kNThreadsPerWarp=32 hardcoded as the logical warp size
- Shuffle operations with implicit width=warpSize (64 on CDNA)
- Block configurations that assume 32-thread hardware warps

While partial fixes were attempted (64-bit shuffle masks, explicit
width parameters, using kNThreadsPerWarp instead of warpSize in
iteration bounds), the kernel still crashes on gfx90a due to the
mismatch between 32-thread "logical warps" and 64-lane physical
wavefronts. A complete fix would require rearchitecting the kernel's
warp-level scan to be wave-size parametric.

The Triton backend provides equivalent functionality and handles
arbitrary sequence lengths (not just powers of 2), making it the
better choice for ROCm until a wave-size-generic C++ kernel is
developed. This port was authored with Claude (AI assistant).

Test Plan:

    cd projects/accelerated-scan/src
    export HIP_VISIBLE_DEVICES=0
    pytest tests/test_eq.py -v  # 399/400 pass, 1 marginal tolerance
    pytest tests/tests_eq_complex.py -v  # 240/240 pass
The ROCm path routes warp.py to the wave-size-agnostic Triton backend
(the C++ warp kernel has wave64 hazards). Two gaps remained on that path:
scan_forward(reverse=True) was silently dropped, and warpscan_forward /
warpscan_backward were not exported, so `from accelerated_scan.warp import
warpscan_forward` and the warp benchmark provider failed to import.

Reverse is now implemented via the time-reversal identity. For the
recurrence x_t = a_t x_{t-1} + b_t, a reverse scan is the flip of a forward
scan on the flipped inputs: flip(forward(flip(a), flip(b))). The flips are
along the time axis (last dim of the (B, C, T) layout) and are made
contiguous because the Triton implementation requires contiguous inputs.

warpscan_forward is provided as a functional wrapper that writes the result
into the caller's out tensor in place (out.copy_(...)), honoring reverse, so
the direct binding and the benchmark work unchanged. warpscan_backward
raises NotImplementedError with a message pointing at the Triton autograd
path: the ROCm backward goes through Scan.backward, not this low-level
binding. Its argument names mirror the C++ binding so the import signature
matches.

Tests: test_eq_reverse parametrizes over the existing seqlens and dtypes and
checks scan_forward(reverse=True) against the reference reverse scan with the
same tolerances as the forward test (this is the correctness check for the
flip identity). test_warpscan_bindings asserts both symbols import, that
warpscan_forward fills out to match the reference forward, and that
warpscan_backward raises NotImplementedError.

Authored with the assistance of Claude (Anthropic).

Test Plan:

Built and validated on gfx90a (AMD Instinct MI250X, ROCm 7.2, wave64).

```
pip install -e .
HIP_VISIBLE_DEVICES=0 pytest tests/test_eq.py tests/tests_eq_complex.py -q
# 832 passed, 1 failed
# the single failure is the pre-existing test_eq_ref_reverse[65536-1]
# tolerance flake in the reference reverse scan (2.33e-05 vs 2e-05 allowed),
# unrelated to this change

HIP_VISIBLE_DEVICES=0 pytest tests/test_eq.py -q -k "test_eq_reverse or test_warpscan_bindings"
# 193 passed (192 reverse cases across seeds/seqlens/dtypes + 1 bindings test)
```
Authored with the assistance of Claude (Anthropic).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant