From 94e47c459cacf615bdae7cafd501837903032e88 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Fri, 5 Jun 2026 07:54:29 +0000 Subject: [PATCH 1/3] [ROCm] Add AMD GPU support via Triton backend On ROCm, redirect warp.py to use the Triton scalar implementation (scalar.py) instead of the C++ warp kernel. The Triton backend uses tl.associative_scan which is wave-size agnostic and works correctly on both wave64 (gfx90a/CDNA) and wave32 (gfx1100/RDNA) architectures. The C++ warp kernel in warp.cuh has fundamental wave-size assumptions: - kNThreadsPerWarp=32 hardcoded as the logical warp size - Shuffle operations with implicit width=warpSize (64 on CDNA) - Block configurations that assume 32-thread hardware warps While partial fixes were attempted (64-bit shuffle masks, explicit width parameters, using kNThreadsPerWarp instead of warpSize in iteration bounds), the kernel still crashes on gfx90a due to the mismatch between 32-thread "logical warps" and 64-lane physical wavefronts. A complete fix would require rearchitecting the kernel's warp-level scan to be wave-size parametric. The Triton backend provides equivalent functionality and handles arbitrary sequence lengths (not just powers of 2), making it the better choice for ROCm until a wave-size-generic C++ kernel is developed. This port was authored with Claude (AI assistant). 
Test Plan: cd projects/accelerated-scan/src export HIP_VISIBLE_DEVICES=0 pytest tests/test_eq.py -v # 399/400 pass, 1 marginal tolerance pytest tests/tests_eq_complex.py -v # 240/240 pass --- accelerated_scan/warp.cuh | 17 ++-- accelerated_scan/warp.py | 170 +++++++++++++++++++++++--------------- 2 files changed, 117 insertions(+), 70 deletions(-) diff --git a/accelerated_scan/warp.cuh b/accelerated_scan/warp.cuh index 811df27..c6ae165 100644 --- a/accelerated_scan/warp.cuh +++ b/accelerated_scan/warp.cuh @@ -6,6 +6,13 @@ #define CHECK_STRIDE(x) TORCH_CHECK(x.stride(-1) == 1 || x.size(-1) == 1); +// ROCm/HIP compatibility: HIP requires 64-bit shuffle masks +#if defined(__HIP_PLATFORM_AMD__) || defined(USE_ROCM) +#define FULL_WARP_MASK 0xffffffffffffffffULL +#else +#define FULL_WARP_MASK 0xffffffff +#endif + template class UnalignedTuple { public: @@ -137,8 +144,8 @@ __global__ void scan( #pragma unroll for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) { - weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate.data[kThreadLast], delta); - weight_t prev_token = __shfl_up_sync(0xffffffff, accToken.data[kThreadLast], delta); + weight_t prev_gate = __shfl_up_sync(FULL_WARP_MASK, accGate.data[kThreadLast], delta, kNThreadsPerWarp); + weight_t prev_token = __shfl_up_sync(FULL_WARP_MASK, accToken.data[kThreadLast], delta, kNThreadsPerWarp); if (laneId >= delta) { #pragma unroll @@ -172,9 +179,9 @@ __global__ void scan( warpAccToken = (laneId < kNWarpsPerBlock) ? 
warpLastToken[laneId] : kEmptyToken; #pragma unroll - for (int delta = 1; delta < warpSize; delta *= 2) { - weight_t prev_gate = __shfl_up_sync(0xffffffff, warpAccGate, delta); - weight_t prev_token = __shfl_up_sync(0xffffffff, warpAccToken, delta); + for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) { + weight_t prev_gate = __shfl_up_sync(FULL_WARP_MASK, warpAccGate, delta, kNThreadsPerWarp); + weight_t prev_token = __shfl_up_sync(FULL_WARP_MASK, warpAccToken, delta, kNThreadsPerWarp); if (laneId >= delta) { warpAccToken = prev_token * warpAccGate + warpAccToken; diff --git a/accelerated_scan/warp.py b/accelerated_scan/warp.py index fd58933..c74c31c 100644 --- a/accelerated_scan/warp.py +++ b/accelerated_scan/warp.py @@ -1,22 +1,54 @@ from pathlib import Path import torch -from torch.utils.cpp_extension import load_inline -cuda_source = (Path(__file__).parent / 'warp.cuh').read_text() +is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None -cpp_source = """ +if is_rocm: + # On ROCm, use the Triton implementation which is wave-size agnostic + # The C++ warp kernel has hardcoded wave-size assumptions that require + # significant rework for wave64 (gfx90a) vs wave32 (gfx1100) compatibility + from accelerated_scan.scalar import scan as _scan_impl, Scan as _ScanClass + + def scan_forward(gates, tokens, reverse=False): + return _scan_impl(gates, tokens) + + def scan(gates, tokens): + """Solve a first-order recurrence relation: + + .. math:: + x_t = a_t x_{t-1} + b_t + + where :math:`a_t` ("gates") and :math:`b_t` ("tokens") are sequences of vectors. + + Arguments: + gates (torch.Tensor): shape (B, C, T), must be contiguous. + tokens (torch.Tensor): shape (B, C, T), must be contiguous. 
+ + Returns: + (torch.Tensor): shape (B, C, T) + """ + return _scan_impl(gates, tokens) + + class Scan(torch.autograd.Function): + @staticmethod + def forward(ctx, gates, tokens): + return _ScanClass.forward(ctx, gates, tokens) + + @staticmethod + def backward(ctx, grad_output): + return _ScanClass.backward(ctx, grad_output) +else: + from torch.utils.cpp_extension import load_inline + + cuda_source = (Path(__file__).parent / 'warp.cuh').read_text() + + cpp_source = """ at::Tensor warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse); void warpscan_backward(const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& valueGradOut); """ -module = load_inline( - name='warpscan', - cpp_sources=[cpp_source], - cuda_sources=[cuda_source], - functions=['warpscan_forward', 'warpscan_backward'], - verbose=True, - extra_cuda_cflags=[ + extra_flags = [ "-O3", "-std=c++17", "--ptxas-options=-v", @@ -25,59 +57,67 @@ "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_BFLOAT16_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", ] -) -warpscan_forward = module.warpscan_forward -warpscan_backward = module.warpscan_backward - -def scan_forward(gates, tokens, reverse=False): - output = torch.zeros_like(tokens) - warpscan_forward(gates, tokens, output, reverse) - return output - - -class Scan(torch.autograd.Function): - @staticmethod - def forward(ctx, gates, tokens): - B, C, T = gates.shape - assert tokens.shape == (B, C, T) - assert gates.is_contiguous() - assert tokens.is_contiguous() - - states = scan_forward(gates, tokens) - ctx.save_for_backward(states, gates) - return states - - # backward scan is a padded reverse scan - # See https://arxiv.org/abs/1709.04057 Section 2.2 - @staticmethod - def backward(ctx, grad_output): - states, gates = ctx.saved_tensors - B, C, T = gates.shape - - grad_output = grad_output.contiguous() - assert 
states.is_contiguous() - assert gates.is_contiguous() - - d_gates = torch.empty_like(gates) - d_tokens = torch.empty_like(gates) - warpscan_backward(gates, states, grad_output, d_gates, d_tokens) - - return d_gates, d_tokens - - -def scan(gates, tokens): - """Solve a first-order recurrence relation: - - .. math:: - x_t = a_t x_{t-1} + b_t - - where :math:`a_t` ("gates") and :math:`b_t` ("tokens") are sequences of vectors. - - Arguments: - gates (torch.Tensor): shape (B, C, T), must be contiguous. T must be a power of 2. - tokens (torch.Tensor): shape (B, C, T), must be contiguous. T must be a power of 2. - Returns: - (torch.Tensor): shape (B, C, T) - """ - return Scan.apply(gates, tokens) + module = load_inline( + name='warpscan', + cpp_sources=[cpp_source], + cuda_sources=[cuda_source], + functions=['warpscan_forward', 'warpscan_backward'], + verbose=True, + extra_cuda_cflags=extra_flags, + ) + warpscan_forward = module.warpscan_forward + warpscan_backward = module.warpscan_backward + + def scan_forward(gates, tokens, reverse=False): + output = torch.zeros_like(tokens) + warpscan_forward(gates, tokens, output, reverse) + return output + + + class Scan(torch.autograd.Function): + @staticmethod + def forward(ctx, gates, tokens): + B, C, T = gates.shape + assert tokens.shape == (B, C, T) + assert gates.is_contiguous() + assert tokens.is_contiguous() + + states = scan_forward(gates, tokens) + ctx.save_for_backward(states, gates) + return states + + # backward scan is a padded reverse scan + # See https://arxiv.org/abs/1709.04057 Section 2.2 + @staticmethod + def backward(ctx, grad_output): + states, gates = ctx.saved_tensors + B, C, T = gates.shape + + grad_output = grad_output.contiguous() + assert states.is_contiguous() + assert gates.is_contiguous() + + d_gates = torch.empty_like(gates) + d_tokens = torch.empty_like(gates) + warpscan_backward(gates, states, grad_output, d_gates, d_tokens) + + return d_gates, d_tokens + + + def scan(gates, tokens): + """Solve a 
first-order recurrence relation: + + .. math:: + x_t = a_t x_{t-1} + b_t + + where :math:`a_t` ("gates") and :math:`b_t` ("tokens") are sequences of vectors. + + Arguments: + gates (torch.Tensor): shape (B, C, T), must be contiguous. T must be a power of 2. + tokens (torch.Tensor): shape (B, C, T), must be contiguous. T must be a power of 2. + + Returns: + (torch.Tensor): shape (B, C, T) + """ + return Scan.apply(gates, tokens) From 64ac44f3f850a99f025ecbabcc0f3e8cb4d58e40 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Mon, 8 Jun 2026 16:54:50 +0000 Subject: [PATCH 2/3] [ROCm] Implement reverse scan and export bindings on ROCm The ROCm path routes warp.py to the wave-size-agnostic Triton backend (the C++ warp kernel has wave64 hazards). Two gaps remained on that path: scan_forward(reverse=True) was silently dropped, and warpscan_forward / warpscan_backward were not exported, so `from accelerated_scan.warp import warpscan_forward` and the warp benchmark provider failed to import. Reverse is now implemented via the time-reversal identity. For the recurrence x_t = a_t x_{t-1} + b_t, a reverse scan is the flip of a forward scan on the flipped inputs: flip(forward(flip(a), flip(b))). The flips are along the time axis (last dim of the (B, C, T) layout) and are made contiguous because the Triton implementation requires contiguous inputs. warpscan_forward is provided as a functional wrapper that writes the result into the caller's out tensor in place (out.copy_(...)), honoring reverse, so the direct binding and the benchmark work unchanged. warpscan_backward raises NotImplementedError with a message pointing at the Triton autograd path: the ROCm backward goes through Scan.backward, not this low-level binding. Its argument names mirror the C++ binding so the import signature matches. 
Tests: test_eq_reverse parametrizes over the existing seqlens and dtypes and checks scan_forward(reverse=True) against the reference reverse scan with the same tolerances as the forward test (this is the correctness check for the flip identity). test_warpscan_bindings asserts both symbols import, that warpscan_forward fills out to match the reference forward, and that warpscan_backward raises NotImplementedError. Authored with the assistance of Claude (Anthropic). Test Plan: Built and validated on gfx90a (AMD Instinct MI250X, ROCm 7.2, wave64). ``` pip install -e . HIP_VISIBLE_DEVICES=0 pytest tests/test_eq.py tests/tests_eq_complex.py -q # 832 passed, 1 failed # the single failure is the pre-existing test_eq_ref_reverse[65536-1] # tolerance flake in the reference reverse scan (2.33e-05 vs 2e-05 allowed), # unrelated to this change HIP_VISIBLE_DEVICES=0 pytest tests/test_eq.py -q -k "test_eq_reverse or test_warpscan_bindings" # 193 passed (192 reverse cases across seeds/seqlens/dtypes + 1 bindings test) ``` --- accelerated_scan/warp.py | 15 +++++++++++++++ tests/test_eq.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/accelerated_scan/warp.py b/accelerated_scan/warp.py index c74c31c..a3c68d1 100644 --- a/accelerated_scan/warp.py +++ b/accelerated_scan/warp.py @@ -11,8 +11,23 @@ from accelerated_scan.scalar import scan as _scan_impl, Scan as _ScanClass def scan_forward(gates, tokens, reverse=False): + if reverse: + # A reverse scan of x_t = a_t x_{t-1} + b_t is the time-reversal of a + # forward scan on the time-reversed inputs: flip(forward(flip(a), flip(b))). + # The Triton backend requires contiguous inputs, so materialize the flips. 
+ return _scan_impl(gates.flip(-1).contiguous(), tokens.flip(-1).contiguous()).flip(-1) return _scan_impl(gates, tokens) + def warpscan_forward(gates, tokens, out, reverse=False): + out.copy_(scan_forward(gates, tokens, reverse=reverse)) + return out + + def warpscan_backward(gates, output, outGrad, gateGradOut, valueGradOut): + raise NotImplementedError( + "warpscan_backward is the low-level C++ binding and is unavailable on " + "ROCm; the ROCm autograd path uses the Triton Scan.backward instead." + ) + def scan(gates, tokens): """Solve a first-order recurrence relation: diff --git a/tests/test_eq.py b/tests/test_eq.py index 4ce575f..4f6b690 100644 --- a/tests/test_eq.py +++ b/tests/test_eq.py @@ -62,6 +62,35 @@ def test_eq_backward(scan, seed, seqlen, dtype): torch.testing.assert_close(tokens_grad, tokens_ref.grad, atol=atol[dtype], rtol=rtol[dtype]) +@pytest.mark.parametrize("seed", seeds) +@pytest.mark.parametrize("seqlen", seqlens) +@pytest.mark.parametrize("dtype", dtypes) +@torch.inference_mode() +def test_eq_reverse(seed, seqlen, dtype): + from accelerated_scan.warp import scan_forward + + gates, tokens = init(seed, seqlen=seqlen, dtype=dtype) + out = scan_forward(gates, tokens, reverse=True) + out_ref = scan_ref(gates, tokens, reverse=True) + + print('max abs error', (out - out_ref).abs().max().item(), 'seqlen', seqlen, 'dtype', dtype) + + torch.testing.assert_close(out, out_ref, atol=atol[dtype], rtol=rtol[dtype]) + + +def test_warpscan_bindings(): + from accelerated_scan.warp import warpscan_forward, warpscan_backward + + gates, tokens = init(seeds[0], seqlen=128) + out = torch.empty_like(tokens) + warpscan_forward(gates, tokens, out, False) + out_ref = scan_ref(gates, tokens) + torch.testing.assert_close(out, out_ref, atol=atol[torch.float32], rtol=rtol[torch.float32]) + + with pytest.raises(NotImplementedError): + warpscan_backward(gates, out, torch.empty_like(out), torch.empty_like(gates), torch.empty_like(tokens)) + + @pytest.mark.parametrize("seed", 
[1]) @pytest.mark.parametrize("seqlen", seqlens) def test_eq_ref_reverse(seed, seqlen): From 84770b4d7d7c1ab7e8add748bd26213a59d4411c Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Mon, 8 Jun 2026 18:18:50 +0000 Subject: [PATCH 3/3] [ROCm] Document AMD GPU (ROCm) support in the README Authored with the assistance of Claude (Anthropic). --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 1e8c459..dcf7754 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,10 @@ out = scan(forget, inputs) To ensure numerical equivalence, a reference implementation for trees is provided in Torch. It can be sped up using `torch.compile`. +## AMD GPUs (ROCm) + +accelerated-scan runs on AMD GPUs with a ROCm build of PyTorch. On ROCm, `accelerated_scan.warp` automatically uses the wave-size-agnostic Triton backend (`accelerated_scan.scalar`), because the C++ warp kernel assumes a 32-lane warp, an assumption that does not hold on a 64-lane CDNA wavefront. Install as usual (`pip install accelerated-scan`, or `pip install -e .` from a checkout) with a ROCm PyTorch and Triton; the `scan` API and `device="cuda"` tensors are unchanged. + ## Benchmarks: ![bench.png](bench.png)