diff --git a/README.md b/README.md index 1e8c459..dcf7754 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,10 @@ out = scan(forget, inputs) To ensure numerical equivalence, a reference implementation for trees is provided in Torch. It can be sped up using `torch.compile`. +## AMD GPUs (ROCm) + +accelerated-scan runs on AMD GPUs with a ROCm build of PyTorch. On ROCm, `accelerated_scan.warp` automatically uses the wave-size-agnostic Triton backend (`accelerated_scan.scalar`), because the C++ warp kernel assumes a 32-lane warp, an assumption that does not hold on 64-lane CDNA wavefronts. Install as usual (`pip install accelerated-scan`, or `pip install -e .` from a checkout) with a ROCm PyTorch and Triton; the `scan` API and `device="cuda"` tensors are unchanged. + ## Benchmarks: ![bench.png](bench.png) diff --git a/accelerated_scan/warp.cuh b/accelerated_scan/warp.cuh index 811df27..c6ae165 100644 --- a/accelerated_scan/warp.cuh +++ b/accelerated_scan/warp.cuh @@ -6,6 +6,13 @@ #define CHECK_STRIDE(x) TORCH_CHECK(x.stride(-1) == 1 || x.size(-1) == 1); +// ROCm/HIP compatibility: HIP requires 64-bit shuffle masks +#if defined(__HIP_PLATFORM_AMD__) || defined(USE_ROCM) +#define FULL_WARP_MASK 0xffffffffffffffffULL +#else +#define FULL_WARP_MASK 0xffffffff +#endif + template class UnalignedTuple { public: @@ -137,8 +144,8 @@ __global__ void scan( #pragma unroll for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) { - weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate.data[kThreadLast], delta); - weight_t prev_token = __shfl_up_sync(0xffffffff, accToken.data[kThreadLast], delta); + weight_t prev_gate = __shfl_up_sync(FULL_WARP_MASK, accGate.data[kThreadLast], delta, kNThreadsPerWarp); + weight_t prev_token = __shfl_up_sync(FULL_WARP_MASK, accToken.data[kThreadLast], delta, kNThreadsPerWarp); if (laneId >= delta) { #pragma unroll @@ -172,9 +179,9 @@ __global__ void scan( warpAccToken = (laneId < kNWarpsPerBlock) ? 
warpLastToken[laneId] : kEmptyToken; #pragma unroll - for (int delta = 1; delta < warpSize; delta *= 2) { - weight_t prev_gate = __shfl_up_sync(0xffffffff, warpAccGate, delta); - weight_t prev_token = __shfl_up_sync(0xffffffff, warpAccToken, delta); + for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) { + weight_t prev_gate = __shfl_up_sync(FULL_WARP_MASK, warpAccGate, delta, kNThreadsPerWarp); + weight_t prev_token = __shfl_up_sync(FULL_WARP_MASK, warpAccToken, delta, kNThreadsPerWarp); if (laneId >= delta) { warpAccToken = prev_token * warpAccGate + warpAccToken; diff --git a/accelerated_scan/warp.py b/accelerated_scan/warp.py index fd58933..a3c68d1 100644 --- a/accelerated_scan/warp.py +++ b/accelerated_scan/warp.py @@ -1,83 +1,138 @@ from pathlib import Path import torch -from torch.utils.cpp_extension import load_inline -cuda_source = (Path(__file__).parent / 'warp.cuh').read_text() +is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None -cpp_source = """ -at::Tensor warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse); -void warpscan_backward(const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& valueGradOut); -""" - -module = load_inline( - name='warpscan', - cpp_sources=[cpp_source], - cuda_sources=[cuda_source], - functions=['warpscan_forward', 'warpscan_backward'], - verbose=True, - extra_cuda_cflags=[ - "-O3", - "-std=c++17", - "--ptxas-options=-v", - "-lineinfo", - "--fmad", "false", - "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - ] -) -warpscan_forward = module.warpscan_forward -warpscan_backward = module.warpscan_backward +if is_rocm: + # On ROCm, use the Triton implementation which is wave-size agnostic + # The C++ warp kernel has hardcoded wave-size assumptions that require + # significant rework 
for wave64 (gfx90a) vs wave32 (gfx1100) compatibility + from accelerated_scan.scalar import scan as _scan_impl, Scan as _ScanClass -def scan_forward(gates, tokens, reverse=False): - output = torch.zeros_like(tokens) - warpscan_forward(gates, tokens, output, reverse) - return output + def scan_forward(gates, tokens, reverse=False): + if reverse: + # A reverse scan of x_t = a_t x_{t-1} + b_t is the time-reversal of a + # forward scan on the time-reversed inputs: flip(forward(flip(a), flip(b))). + # The Triton backend requires contiguous inputs, so materialize the flips. + return _scan_impl(gates.flip(-1).contiguous(), tokens.flip(-1).contiguous()).flip(-1) + return _scan_impl(gates, tokens) + def warpscan_forward(gates, tokens, out, reverse=False): + out.copy_(scan_forward(gates, tokens, reverse=reverse)) + return out -class Scan(torch.autograd.Function): - @staticmethod - def forward(ctx, gates, tokens): - B, C, T = gates.shape - assert tokens.shape == (B, C, T) - assert gates.is_contiguous() - assert tokens.is_contiguous() + def warpscan_backward(gates, output, outGrad, gateGradOut, valueGradOut): + raise NotImplementedError( + "warpscan_backward is the low-level C++ binding and is unavailable on " + "ROCm; the ROCm autograd path uses the Triton Scan.backward instead." + ) - states = scan_forward(gates, tokens) - ctx.save_for_backward(states, gates) - return states + def scan(gates, tokens): + """Solve a first-order recurrence relation: - # backward scan is a padded reverse scan - # See https://arxiv.org/abs/1709.04057 Section 2.2 - @staticmethod - def backward(ctx, grad_output): - states, gates = ctx.saved_tensors - B, C, T = gates.shape + .. math:: + x_t = a_t x_{t-1} + b_t - grad_output = grad_output.contiguous() - assert states.is_contiguous() - assert gates.is_contiguous() + where :math:`a_t` ("gates") and :math:`b_t` ("tokens") are sequences of vectors. 
- d_gates = torch.empty_like(gates) - d_tokens = torch.empty_like(gates) - warpscan_backward(gates, states, grad_output, d_gates, d_tokens) + Arguments: + gates (torch.Tensor): shape (B, C, T), must be contiguous. + tokens (torch.Tensor): shape (B, C, T), must be contiguous. - return d_gates, d_tokens + Returns: + (torch.Tensor): shape (B, C, T) + """ + return _scan_impl(gates, tokens) + class Scan(torch.autograd.Function): + @staticmethod + def forward(ctx, gates, tokens): + return _ScanClass.forward(ctx, gates, tokens) -def scan(gates, tokens): - """Solve a first-order recurrence relation: + @staticmethod + def backward(ctx, grad_output): + return _ScanClass.backward(ctx, grad_output) +else: + from torch.utils.cpp_extension import load_inline - .. math:: - x_t = a_t x_{t-1} + b_t + cuda_source = (Path(__file__).parent / 'warp.cuh').read_text() - where :math:`a_t` ("gates") and :math:`b_t` ("tokens") are sequences of vectors. + cpp_source = """ +at::Tensor warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse); +void warpscan_backward(const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& valueGradOut); +""" - Arguments: - gates (torch.Tensor): shape (B, C, T), must be contiguous. T must be a power of 2. - tokens (torch.Tensor): shape (B, C, T), must be contiguous. T must be a power of 2. 
+ extra_flags = [ + "-O3", + "-std=c++17", + "--ptxas-options=-v", + "-lineinfo", + "--fmad", "false", + "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + ] - Returns: - (torch.Tensor): shape (B, C, T) - """ - return Scan.apply(gates, tokens) + module = load_inline( + name='warpscan', + cpp_sources=[cpp_source], + cuda_sources=[cuda_source], + functions=['warpscan_forward', 'warpscan_backward'], + verbose=True, + extra_cuda_cflags=extra_flags, + ) + warpscan_forward = module.warpscan_forward + warpscan_backward = module.warpscan_backward + + def scan_forward(gates, tokens, reverse=False): + output = torch.zeros_like(tokens) + warpscan_forward(gates, tokens, output, reverse) + return output + + + class Scan(torch.autograd.Function): + @staticmethod + def forward(ctx, gates, tokens): + B, C, T = gates.shape + assert tokens.shape == (B, C, T) + assert gates.is_contiguous() + assert tokens.is_contiguous() + + states = scan_forward(gates, tokens) + ctx.save_for_backward(states, gates) + return states + + # backward scan is a padded reverse scan + # See https://arxiv.org/abs/1709.04057 Section 2.2 + @staticmethod + def backward(ctx, grad_output): + states, gates = ctx.saved_tensors + B, C, T = gates.shape + + grad_output = grad_output.contiguous() + assert states.is_contiguous() + assert gates.is_contiguous() + + d_gates = torch.empty_like(gates) + d_tokens = torch.empty_like(gates) + warpscan_backward(gates, states, grad_output, d_gates, d_tokens) + + return d_gates, d_tokens + + + def scan(gates, tokens): + """Solve a first-order recurrence relation: + + .. math:: + x_t = a_t x_{t-1} + b_t + + where :math:`a_t` ("gates") and :math:`b_t` ("tokens") are sequences of vectors. + + Arguments: + gates (torch.Tensor): shape (B, C, T), must be contiguous. T must be a power of 2. + tokens (torch.Tensor): shape (B, C, T), must be contiguous. T must be a power of 2. 
+
+        Returns:
+            (torch.Tensor): shape (B, C, T)
+        """
+        return Scan.apply(gates, tokens)
diff --git a/tests/test_eq.py b/tests/test_eq.py
index 4ce575f..4f6b690 100644
--- a/tests/test_eq.py
+++ b/tests/test_eq.py
@@ -62,6 +62,44 @@ def test_eq_backward(scan, seed, seqlen, dtype):
     torch.testing.assert_close(tokens_grad, tokens_ref.grad, atol=atol[dtype], rtol=rtol[dtype])
 
 
+@pytest.mark.parametrize("seed", seeds)
+@pytest.mark.parametrize("seqlen", seqlens)
+@pytest.mark.parametrize("dtype", dtypes)
+@torch.inference_mode()
+def test_eq_reverse(seed, seqlen, dtype):
+    """scan_forward(reverse=True) must match the reference reverse scan."""
+    from accelerated_scan.warp import scan_forward
+
+    gates, tokens = init(seed, seqlen=seqlen, dtype=dtype)
+    out = scan_forward(gates, tokens, reverse=True)
+    out_ref = scan_ref(gates, tokens, reverse=True)
+
+    print('max abs error', (out - out_ref).abs().max().item(), 'seqlen', seqlen, 'dtype', dtype)
+
+    torch.testing.assert_close(out, out_ref, atol=atol[dtype], rtol=rtol[dtype])
+
+
+def test_warpscan_bindings():
+    """Exercise the low-level warpscan_* names exported by accelerated_scan.warp.
+
+    warpscan_forward must write the scan result into ``out`` on every backend.
+    warpscan_backward is a NotImplementedError stub only in the ROCm branch of
+    accelerated_scan.warp; otherwise it is the compiled binding, so the stub
+    check is limited to ROCm.
+    """
+    from accelerated_scan.warp import is_rocm, warpscan_forward, warpscan_backward
+
+    gates, tokens = init(seeds[0], seqlen=128)
+    out = torch.empty_like(tokens)
+    warpscan_forward(gates, tokens, out, False)
+    out_ref = scan_ref(gates, tokens)
+    torch.testing.assert_close(out, out_ref, atol=atol[torch.float32], rtol=rtol[torch.float32])
+
+    if is_rocm:
+        with pytest.raises(NotImplementedError):
+            warpscan_backward(gates, out, torch.empty_like(out), torch.empty_like(gates), torch.empty_like(tokens))
+
+
 @pytest.mark.parametrize("seed", [1])
 @pytest.mark.parametrize("seqlen", seqlens)
 def test_eq_ref_reverse(seed, seqlen):