Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ out = scan(forget, inputs)

To ensure numerical equivalence, a reference implementation for trees is provided in Torch. It can be sped up using `torch.compile`.

## AMD GPUs (ROCm)

accelerated-scan runs on AMD GPUs with a ROCm build of PyTorch. On ROCm, `accelerated_scan.warp` automatically uses the wave-size-agnostic Triton backend (`accelerated_scan.scalar`), because the C++ warp kernel assumes 32-lane warps, an assumption that does not hold on 64-lane CDNA wavefronts. Install as usual (`pip install accelerated-scan`, or `pip install -e .` from a checkout) alongside ROCm builds of PyTorch and Triton; the `scan` API is unchanged, and tensors still use `device="cuda"`.

## Benchmarks:

![bench.png](bench.png)
Expand Down
17 changes: 12 additions & 5 deletions accelerated_scan/warp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@

// Layout check: the last dimension of `x` must have unit stride, unless that
// dimension holds a single element (in which case its stride is irrelevant).
#define CHECK_STRIDE(x) TORCH_CHECK(x.stride(-1) == 1 || x.size(-1) == 1);

// ROCm/HIP compatibility: HIP requires 64-bit shuffle masks
// FULL_WARP_MASK is the "all lanes participate" mask handed to __shfl_up_sync:
// 64 set bits when compiling for AMD/HIP, 32 set bits otherwise.
#if defined(__HIP_PLATFORM_AMD__) || defined(USE_ROCM)
#define FULL_WARP_MASK 0xffffffffffffffffULL
#else
#define FULL_WARP_MASK 0xffffffff
#endif

template<typename weight_t, int N>
class UnalignedTuple {
public:
Expand Down Expand Up @@ -137,8 +144,8 @@ __global__ void scan(

#pragma unroll
for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) {
weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate.data[kThreadLast], delta);
weight_t prev_token = __shfl_up_sync(0xffffffff, accToken.data[kThreadLast], delta);
weight_t prev_gate = __shfl_up_sync(FULL_WARP_MASK, accGate.data[kThreadLast], delta, kNThreadsPerWarp);
weight_t prev_token = __shfl_up_sync(FULL_WARP_MASK, accToken.data[kThreadLast], delta, kNThreadsPerWarp);

if (laneId >= delta) {
#pragma unroll
Expand Down Expand Up @@ -172,9 +179,9 @@ __global__ void scan(
warpAccToken = (laneId < kNWarpsPerBlock) ? warpLastToken[laneId] : kEmptyToken;

#pragma unroll
for (int delta = 1; delta < warpSize; delta *= 2) {
weight_t prev_gate = __shfl_up_sync(0xffffffff, warpAccGate, delta);
weight_t prev_token = __shfl_up_sync(0xffffffff, warpAccToken, delta);
for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) {
weight_t prev_gate = __shfl_up_sync(FULL_WARP_MASK, warpAccGate, delta, kNThreadsPerWarp);
weight_t prev_token = __shfl_up_sync(FULL_WARP_MASK, warpAccToken, delta, kNThreadsPerWarp);

if (laneId >= delta) {
warpAccToken = prev_token * warpAccGate + warpAccToken;
Expand Down
183 changes: 119 additions & 64 deletions accelerated_scan/warp.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,138 @@
from pathlib import Path

import torch
from torch.utils.cpp_extension import load_inline

cuda_source = (Path(__file__).parent / 'warp.cuh').read_text()
# True when PyTorch was built for ROCm, i.e. `torch.version.hip` exists and is
# not None; selects between the Triton fallback and the compiled CUDA kernel.
is_rocm = getattr(torch.version, 'hip', None) is not None

cpp_source = """
at::Tensor warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse);
void warpscan_backward(const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& valueGradOut);
"""

module = load_inline(
name='warpscan',
cpp_sources=[cpp_source],
cuda_sources=[cuda_source],
functions=['warpscan_forward', 'warpscan_backward'],
verbose=True,
extra_cuda_cflags=[
"-O3",
"-std=c++17",
"--ptxas-options=-v",
"-lineinfo",
"--fmad", "false",
"-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
]
)
warpscan_forward = module.warpscan_forward
warpscan_backward = module.warpscan_backward
if is_rocm:
# On ROCm, use the Triton implementation, which is wave-size agnostic.
# The C++ warp kernel hardcodes wave-size assumptions and would need
# significant rework to support both wave64 (gfx90a) and wave32 (gfx1100).
from accelerated_scan.scalar import scan as _scan_impl, Scan as _ScanClass

def scan_forward(gates, tokens, reverse=False):
output = torch.zeros_like(tokens)
warpscan_forward(gates, tokens, output, reverse)
return output
def scan_forward(gates, tokens, reverse=False):
    """Run the scan through the Triton backend, optionally right-to-left.

    Arguments:
        gates (torch.Tensor): gate sequence, scanned along the last dimension.
        tokens (torch.Tensor): token sequence, same layout as ``gates``.
        reverse (bool): when true, scan from the end of the sequence.

    Returns:
        (torch.Tensor): the scan result produced by the Triton backend.
    """
    if not reverse:
        return _scan_impl(gates, tokens)

    # A reverse scan of x_t = a_t x_{t-1} + b_t equals the time-reversal of a
    # forward scan over time-reversed inputs. The Triton backend requires
    # contiguous inputs, so each flipped view is materialized before the call.
    flipped_gates = gates.flip(-1).contiguous()
    flipped_tokens = tokens.flip(-1).contiguous()
    flipped_result = _scan_impl(flipped_gates, flipped_tokens)
    return flipped_result.flip(-1)

def warpscan_forward(gates, tokens, out, reverse=False):
    """Out-parameter variant of ``scan_forward``.

    Computes the scan, copies the result into ``out`` in place, and returns
    ``out`` itself.
    """
    result = scan_forward(gates, tokens, reverse=reverse)
    out.copy_(result)
    return out

class Scan(torch.autograd.Function):
@staticmethod
def forward(ctx, gates, tokens):
B, C, T = gates.shape
assert tokens.shape == (B, C, T)
assert gates.is_contiguous()
assert tokens.is_contiguous()
def warpscan_backward(gates, output, outGrad, gateGradOut, valueGradOut):
    """Placeholder for the compiled backward binding.

    Raises:
        NotImplementedError: always; no argument is inspected.
    """
    message = (
        "warpscan_backward is the low-level C++ binding and is unavailable on ROCm; "
        "the ROCm autograd path uses the Triton Scan.backward instead."
    )
    raise NotImplementedError(message)

states = scan_forward(gates, tokens)
ctx.save_for_backward(states, gates)
return states
def scan(gates, tokens):
"""Solve a first-order recurrence relation:

# backward scan is a padded reverse scan
# See https://arxiv.org/abs/1709.04057 Section 2.2
@staticmethod
def backward(ctx, grad_output):
states, gates = ctx.saved_tensors
B, C, T = gates.shape
.. math::
x_t = a_t x_{t-1} + b_t

grad_output = grad_output.contiguous()
assert states.is_contiguous()
assert gates.is_contiguous()
where :math:`a_t` ("gates") and :math:`b_t` ("tokens") are sequences of vectors.

d_gates = torch.empty_like(gates)
d_tokens = torch.empty_like(gates)
warpscan_backward(gates, states, grad_output, d_gates, d_tokens)
Arguments:
gates (torch.Tensor): shape (B, C, T), must be contiguous.
tokens (torch.Tensor): shape (B, C, T), must be contiguous.

return d_gates, d_tokens
Returns:
(torch.Tensor): shape (B, C, T)
"""
return _scan_impl(gates, tokens)

class Scan(torch.autograd.Function):
@staticmethod
def forward(ctx, gates, tokens):
return _ScanClass.forward(ctx, gates, tokens)

def scan(gates, tokens):
"""Solve a first-order recurrence relation:
@staticmethod
def backward(ctx, grad_output):
return _ScanClass.backward(ctx, grad_output)
else:
from torch.utils.cpp_extension import load_inline

.. math::
x_t = a_t x_{t-1} + b_t
# Read the kernel source from `warp.cuh`, which sits in the same directory as this module.
cuda_source = (Path(__file__).parent / 'warp.cuh').read_text()

where :math:`a_t` ("gates") and :math:`b_t` ("tokens") are sequences of vectors.
# C++ prototypes of the two entry points; handed to load_inline as `cpp_sources`.
cpp_source = """
at::Tensor warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse);
void warpscan_backward(const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& valueGradOut);
"""

Arguments:
gates (torch.Tensor): shape (B, C, T), must be contiguous. T must be a power of 2.
tokens (torch.Tensor): shape (B, C, T), must be contiguous. T must be a power of 2.
# Extra flags forwarded to the CUDA compiler through `extra_cuda_cflags`.
extra_flags = [
"-O3",
"-std=c++17",
"--ptxas-options=-v",
"-lineinfo",
# NOTE(review): presumably turns off fused multiply-add contraction so results
# stay numerically comparable with the reference implementation — confirm.
"--fmad", "false",
# NOTE(review): these undefine the __CUDA_NO_* macros, presumably to re-enable
# half/bfloat16 operators and conversions in the kernel — confirm.
"-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
]

Returns:
(torch.Tensor): shape (B, C, T)
"""
return Scan.apply(gates, tokens)
# Build and load the inline extension at import time; only the names listed in
# `functions` are bound into the resulting module.
module = load_inline(
name='warpscan',
cpp_sources=[cpp_source],
cuda_sources=[cuda_source],
functions=['warpscan_forward', 'warpscan_backward'],
verbose=True,
extra_cuda_cflags=extra_flags,
)
# Re-export the compiled bindings under module-level names.
warpscan_forward = module.warpscan_forward
warpscan_backward = module.warpscan_backward

def scan_forward(gates, tokens, reverse=False):
    """Allocate a zero-filled result buffer and run the compiled forward kernel.

    Arguments:
        gates (torch.Tensor): gate sequence.
        tokens (torch.Tensor): token sequence; the result matches its shape and dtype.
        reverse (bool): forwarded unchanged to ``warpscan_forward``.

    Returns:
        (torch.Tensor): the buffer that ``warpscan_forward`` was asked to fill.
    """
    result = torch.zeros_like(tokens)
    warpscan_forward(gates, tokens, result, reverse)
    return result


class Scan(torch.autograd.Function):
    """Autograd function that pairs the compiled forward and backward scan kernels."""

    @staticmethod
    def forward(ctx, gates, tokens):
        """Compute the scan states and stash what ``backward`` needs.

        Both inputs must be contiguous and share one 3-D shape.
        """
        batch, channels, seqlen = gates.shape
        assert tokens.shape == (batch, channels, seqlen)
        assert gates.is_contiguous()
        assert tokens.is_contiguous()

        states = scan_forward(gates, tokens)
        # The backward kernel consumes the forward output and the gates.
        ctx.save_for_backward(states, gates)
        return states

    # backward scan is a padded reverse scan
    # See https://arxiv.org/abs/1709.04057 Section 2.2
    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(d_gates, d_tokens)`` as filled in by ``warpscan_backward``."""
        states, gates = ctx.saved_tensors
        # Unpacking fails unless the saved gates are 3-D; the sizes are unused.
        _batch, _channels, _seqlen = gates.shape

        upstream = grad_output.contiguous()
        assert states.is_contiguous()
        assert gates.is_contiguous()

        # Both gradient buffers are allocated with the layout of `gates`.
        d_gates = torch.empty_like(gates)
        d_tokens = torch.empty_like(gates)
        warpscan_backward(gates, states, upstream, d_gates, d_tokens)

        return d_gates, d_tokens


def scan(gates, tokens):
    """Evaluate the first-order linear recurrence

    .. math::
        x_t = a_t x_{t-1} + b_t

    for sequences of vectors :math:`a_t` ("gates") and :math:`b_t` ("tokens"),
    with gradients provided by :class:`Scan`.

    Arguments:
        gates (torch.Tensor): shape (B, C, T), must be contiguous. T must be a power of 2.
        tokens (torch.Tensor): shape (B, C, T), must be contiguous. T must be a power of 2.

    Returns:
        (torch.Tensor): shape (B, C, T)
    """
    states = Scan.apply(gates, tokens)
    return states
29 changes: 29 additions & 0 deletions tests/test_eq.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,35 @@ def test_eq_backward(scan, seed, seqlen, dtype):
torch.testing.assert_close(tokens_grad, tokens_ref.grad, atol=atol[dtype], rtol=rtol[dtype])


@pytest.mark.parametrize("seed", seeds)
@pytest.mark.parametrize("seqlen", seqlens)
@pytest.mark.parametrize("dtype", dtypes)
@torch.inference_mode()
def test_eq_reverse(seed, seqlen, dtype):
    """The reverse scan from ``accelerated_scan.warp`` must match ``scan_ref``."""
    from accelerated_scan.warp import scan_forward

    gates, tokens = init(seed, seqlen=seqlen, dtype=dtype)
    actual = scan_forward(gates, tokens, reverse=True)
    expected = scan_ref(gates, tokens, reverse=True)

    max_abs_error = (actual - expected).abs().max().item()
    print('max abs error', max_abs_error, 'seqlen', seqlen, 'dtype', dtype)

    torch.testing.assert_close(
        actual,
        expected,
        atol=atol[dtype],
        rtol=rtol[dtype],
    )


def test_warpscan_bindings():
    """Exercise the low-level ``warpscan_*`` bindings on whichever backend is active.

    ``warpscan_forward`` must fill ``out`` with the reference result on every
    backend. ``warpscan_backward`` is a stub that raises ``NotImplementedError``
    only on ROCm; on CUDA it is the real compiled binding, so there the test
    checks that it runs to completion instead.
    """
    from accelerated_scan.warp import is_rocm, warpscan_backward, warpscan_forward

    gates, tokens = init(seeds[0], seqlen=128)
    out = torch.empty_like(tokens)
    warpscan_forward(gates, tokens, out, False)
    out_ref = scan_ref(gates, tokens)
    torch.testing.assert_close(out, out_ref, atol=atol[torch.float32], rtol=rtol[torch.float32])

    d_gates = torch.empty_like(gates)
    d_tokens = torch.empty_like(tokens)
    if is_rocm:
        with pytest.raises(NotImplementedError):
            warpscan_backward(gates, out, torch.empty_like(out), d_gates, d_tokens)
    else:
        # Feed a defined upstream gradient rather than uninitialized memory.
        warpscan_backward(gates, out, torch.ones_like(out), d_gates, d_tokens)
        # Kernel failures are reported asynchronously; synchronize so they
        # surface inside this test.
        torch.cuda.synchronize()


@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("seqlen", seqlens)
def test_eq_ref_reverse(seed, seqlen):
Expand Down