ops: native HIP Flash Attention kernels for AMD RDNA3 (gfx11)#2043
Open
T0nd3 wants to merge 14 commits into
Open
ops: native HIP Flash Attention kernels for AMD RDNA3 (gfx11)#2043T0nd3 wants to merge 14 commits into
T0nd3 wants to merge 14 commits into
Conversation
The build previously aborted with a hard error when both WITH_HIP=ON and WITH_FLASH_ATTN=ON were set. The CUDA Flash Attention kernels rely on CUTLASS/CuTe (sm80-specific) and cannot be compiled for HIP directly. Replace the FATAL_ERROR with a cmake WARNING so that the build succeeds. Using flash_attention=True at runtime on a ROCm build already raises a clear std::invalid_argument (see model.cc CT2_USE_HIP guard). The warning explains that a native HIP implementation via AMD Composable Kernel is planned.
Implements scaled dot-product attention for the ROCm/HIP backend using three plain HIP kernels (QK^T, softmax, PV) with FP32 accumulators. No CUTLASS, CuTe, or AMD Composable Kernel dependency required. Supports FP16 and BF16 inputs; tested on gfx1100 (RX 7900 XTX). KV-cache (offset > 0), rotary embeddings, and ALiBi are expected to be pre-applied by the caller in this initial implementation. Changes: - flash_attention_gpu.cu: add #ifndef CT2_USE_HIP / #else / #endif guard so CUDA and HIP share one translation unit with a single FlashAttention::compute<Device::CUDA> specialisation each. CUTLASS/CuTe headers are excluded for HIP builds. - CMakeLists.txt: enable -DCT2_WITH_FLASH_ATTN for HIP when WITH_FLASH_ATTN=ON (no sm80 CUTLASS sources added). - model.cc: honour CT2_WITH_FLASH_ATTN on HIP builds so that flash_attention=True is accepted at runtime instead of raising "FlashAttention not supported on ROCm."
Builds on the initial 3-pass HIP kernels with three changes:
1. KV-cache write path (offset > 0) — adds hip_kv_cache_write_kernel
that stages new K/V tokens into the cached buffer before the attention
pass. Mirrors the CUTLASS path's append-then-attend semantics so
autoregressive decoding works without the layer mutating the cache.
2. Softmax-reduction block-size fix — the tree reduction inside
hip_attn_softmax_kernel requires blockDim.x to be a power of two.
The dispatcher previously used min(seqlen_k, 256) directly, so for
seqlen_k = 3 (e.g. the Whisper prompt prefill of three tokens) the
reduction silently dropped the third element, corrupting the softmax
and forcing generate() to always emit the same token regardless of
input. Now rounded up to the next power of two with extra threads
contributing the identity (-1e9 for max, 0 for sum).
3. Two new fast paths in the dispatcher:
- hip_flash_attn_fwd_tiled (Flash Attention 2): one block per
(q_tile, head, batch), Q in registers, K/V tiles streamed through
LDS, online-softmax state (m_i, l_i, acc[D]). S = Q@K^T is never
materialised in HBM. Specialised for D in {64, 80, 128} with
BM = BN = 64. Used when seqlen_q >= BM.
- hip_flash_decode_kernel: one block per (head, batch) with threads
parallelising over K. Phase 1 computes scores in LDS, phase 2
reduces, phase 3 accumulates output channels with V-tiling
(BLOCK threads stage a V_TILE-wide slab of V into LDS once, then
D channel threads sum against the cached tile). Used for
seqlen_q == 1 with D in {64, 128} and seqlen_k bounded by the
64 KiB per-block LDS budget.
The original 3-pass kernels remain as the correctness-oracle fallback
for unsupported head dimensions and the 2..BM-1 query-length gap.
Verified on Whisper-medium / RX 7900 XTX: all five test seeds produce
identical token sequences to the standard MultiHeadAttention path
(generate up to max_length=200).
Two related changes that together let the encoder's self-attention go through FlashMultiHeadAttention when use_flash_attention=true: - FlashMultiHeadAttention previously instantiated the FlashAttention op with its default is_causal=true, which only matches the decoder's autoregressive self-attention. When the layer is used by the encoder every query would mask out its successors, producing wrong outputs. The op is now constructed with is_causal=_is_decoder, so encoder usage is non-causal and decoder usage stays causal. - TransformerEncoder and WhisperEncoder build their layers without passing use_flash_attention through to TransformerEncoderLayer, so they always picked MultiHeadAttention regardless of the model's flag. Both constructors now forward model.use_flash_attention(). (Whisper has its own encoder class — patching only the generic TransformerEncoder is not enough.) Whisper-medium / RX 7900 XTX numbers (Sq = Sk = 1500, B = 1): encoder-only: 64.2 ms → 59.3 ms (1.08x) generate(30): 221 ms → 213 ms (1.04x) generate(200): 1012 ms → 965 ms (1.05x) generate(448): 2233 ms → 2154 ms (1.04x) Correctness preserved (FP16 rounding gives ~0.3% relative diff in the encoder output; generated token sequences match exactly up to at least max_length=200 across five random seeds).
Adds hip_flash_attn_wmma_fp16: a Flash Attention forward kernel that
uses the wave32 16x16x16 fp16-input / fp32-accumulator WMMA built-in
(__builtin_amdgcn_wmma_f32_16x16x16_f16_w32) for both Q·K^T and P·V.
Block layout (one wave32 per block):
BM_W = 16 query rows
BN = 16 key tokens per K/V tile
Q tile loaded once (pre-scaled), K and V tiles streamed in
Inner reduction over the head dimension via D/16 WMMA calls per
output tile; the P·V phase fans out to D/16 output fragments held
entirely in registers. Online softmax with per-row (m_i, l_i)
staged through LDS so the WMMA accumulator fragment layout is
decoupled from the row-reduction code.
Activated in the dispatcher for FP16 inputs whenever head_dim is a
multiple of 16 and seqlen_q >= 16. Falls through to the scalar
tiled kernel for BF16 and other shapes; correctness oracle (3-pass)
remains as the final fallback.
Also adds src/ops/wmma_probe.cu — a small companion file with a
probe kernel + C-ABI driver used to empirically verify the wave32
accumulator fragment layout on this hardware before writing the
real kernel. The probe is reusable for future ports (BF16 path,
gfx12, other tile shapes). Layout verified on gfx1100:
lane l, slot s -> C[2*s + (l >> 4), l mod 16]
Whisper-medium / RX 7900 XTX numbers (vs. the scalar tiled kernel
that was the previous best):
encoder B=1: 59.3 -> 57.6 ms (1.03x)
encoder B=4: 202.8 -> 182.6 ms (1.11x) -- biggest win, scales
with compute load
generate 30: 227 -> 209 ms (1.09x)
generate 100: 549 -> 522 ms (1.05x)
generate 200: 1031 -> 968 ms (1.07x)
generate 448: 2203 -> 2099 ms (1.05x)
Correctness preserved (FP16 rounding gives ~0.3% relative diff in
the encoder output; generated token sequences match exactly up to
max_length=200 across five random seeds).
Two complementary additions to the WMMA Flash Attention path:
1. BF16 support. The wave32 16x16x16 fragment layout is identical
between the FP16 and BF16 WMMA built-ins on RDNA3 — the only
difference is which intrinsic to call. The existing BM_W=16
kernel is now templated over the half-precision element type
(HalfT in {_Float16, __bf16}) and selects the right built-in
(__builtin_amdgcn_wmma_f32_16x16x16_{f16,bf16}_w32) via
if constexpr. The dispatcher matches scalar_t against
float16_t / bfloat16_t and reinterpret_casts the data pointers
to the underlying HalfT (same memory layout, the wrapper is just
a host-side ABI thing).
Whisper-medium with compute_type=bfloat16, RX 7900 XTX:
encoder B=1: 59.8 -> 55.0 ms (1.09x), matches the FP16 win.
2. New kernel hip_flash_attn_wmma_fp16_bm64<D>: 4 wave32 wavefronts
per block, BM_W = BN = 64, each K/V tile loaded once into LDS
and shared across all 64 query rows. Built but currently NOT
dispatched (kept for future tuning). On gfx1100 it is ~5-10%
slower than the BM_W=16 variant because:
- 48 KiB LDS allows only 1 block per CU vs. BM_W=16's 4 blocks
per CU -> halved occupancy
- Per-row softmax loop iterates BN=64 cols sequentially (4x more
work per softmax thread)
- The expected K/V HBM-reuse benefit doesn't materialise: Whisper
attention is compute/LDS-bound, not HBM-bound (~580 MB total
attention HBM traffic / 960 GB/s = 0.6 ms, vs. ~60 ms encoder)
The kernel structure (per-wave S/P scratch, cooperative tile load)
is the foundation for a future wave-shuffle-softmax variant that
could drop LDS enough to make BM_W=64 actually a win.
Correctness preserved across all five test seeds for FP16; BF16
matches 4/5 (the fifth differs in the last token, expected from
BF16's 8-bit mantissa — the relative encoder diff is 1.85%).
…Flash Attention
python/tests/test_flash_attention.py (CI-friendly pytest, 9 tests):
- test_flash_attention_encoder_matches_standard[float16|bfloat16]
Verifies Flash=ON encoder output matches Flash=OFF within
rounding tolerances (FP16: max_abs_diff <= 0.5, rel <= 0.5%;
BF16: <= 2.0, rel <= 3%). Also prints the per-layer score-
buffer size that the online-softmax path avoids materialising.
- test_flash_attention_generate_fp16_token_match[42|123|777|999|1234]
Five seeds, generates 20 tokens each, asserts byte-identical
token sequences between Flash=ON and Flash=OFF.
- test_softmax_block_size_regression
Specifically guards the bug from commit d9016a5: a 2^k
block-size requirement in the softmax tree reduction that
caused all five seeds to emit token 50411 as the first
generated token regardless of audio input. Asserts that the
set of first tokens across five distinct seeds has size > 1.
- test_flash_attention_score_buffer_savings
Pure documentation test that prints, for a range of common
transformer shapes (Whisper-medium encoder/decoder, hypothetical
4k and 32k LLM contexts), how much HBM the standard attention
path would allocate for the [B, H, Sq, Sk] FP32 score matrix.
Tests are skipped automatically if the faster-whisper-medium snapshot
isn't already in the local HuggingFace cache, so they don't pull
network in CI environments that haven't pre-populated models.
python/tests/benchmark_flash_attention.py (standalone, not pytest):
- Direct HBM measurement via ctypes -> hipMemGetInfo on the active
HIP device. Reports both the persistent model footprint and the
working-set growth after a representative generate() call,
Flash=ON vs Flash=OFF.
- Performance measurement of encoder (B=1, B=4) and generate
(max_length=30/100/200/448), GPU-synced before timing.
Measured on Whisper-medium, RX 7900 XTX, gfx1100:
HBM: model 2.87 GiB OFF -> 2.67 GiB ON (-200 MiB persistent);
working set +359 MiB OFF -> +342 MiB ON (-17 MiB peak).
Speed: encoder B=1 1.08x, B=4 1.11x; generate 1.03-1.07x across
max_length 30..448. Stable across runs.
Both files include a Windows local-build dev-loop helper that
attempts to add the ROCm SDK wheel's bin directories to the DLL
search path before importing ctranslate2. No-op in normal CI.
The wmma_probe kernel is a one-shot reverse-engineering tool for the RDNA3 wave32 WMMA fragment layout — it's how the layout used by hip_flash_attn_wmma_fp16 was discovered. Useful when porting to a new gfx target / precision / tile shape, but no business living in the production DLL where it just exports an unused C symbol (ct2_wmma_probe_run) and bloats the binary. Comment it out of CMakeLists.txt's CUDA_SOURCES and document how to re-enable it at the top of wmma_probe.cu. Source file stays in the tree so future ports of the WMMA kernels can rerun it.
The hip_flash_attn_wmma_fp16_bm64 kernel was added speculatively as a
"larger Q-tile so K/V loads from HBM amortise across more query rows"
experiment, but never dispatched because empirically on gfx1100 it is
~5-10% slower than the BM_W=16 variant:
- 48 KiB LDS allows only 1 block per CU vs. BM_W=16's 4 blocks per CU
-> roughly halved wave occupancy
- The per-row softmax loop sequentially iterates BN=64 cols, 4x the
work per softmax thread vs. BM_W=16
- The expected HBM-bandwidth-reduction benefit doesn't materialise
because Whisper attention is compute/LDS-bound, not HBM-bound:
total attention HBM traffic is ~580 MB which at 960 GB/s is 0.6 ms,
vs. ~60 ms of encoder time
The lesson is in the git history; carrying disabled 240-line code around
is just a maintenance liability. If a future wave-shuffle-softmax
refactor shrinks the LDS footprint below ~24 KiB, the kernel can be
brought back from this commit cleanly.
The dispatcher's BM_W=16 path remains the WMMA fast path for both
FP16 and BF16 inputs with head_dim in {64, 128} and seqlen_q >= 16.
Verified: all 9 tests in python/tests/test_flash_attention.py still
pass after the removal; Whisper-medium encoder/generate speedups
unchanged.
… cases
Two new parameterised tests:
- test_flash_attention_encoder_batched[2|4]
Runs the encoder with batch_size in {2, 4} and asserts Flash=ON
output matches Flash=OFF on every batch element. The previous
encoder test only covered B=1 — this catches any per-batch
indexing bugs in the WMMA tile loads / dispatch (grid.z == batch).
- test_flash_attention_variable_prompt_length[1|3|5|8]
Runs generate() with prompt lengths 1, 3, 5, and 8. This is the
case where the decoder's first forward pass has seqlen_q != 1 and
also seqlen_q < BM_W=16, so the dispatcher routes through the
3-pass fallback kernel rather than the WMMA or decode paths.
Specifically covers the dispatcher branch we never explicitly
exercised before, and re-checks the n_prompt=3 case that
originally triggered the softmax-block-size bug.
All 15 tests pass on Whisper-medium / RX 7900 XTX.
The HIP section's introductory comment was written when there were only two code paths (tiled + 3-pass). Update it to reflect the current four paths (WMMA, decode, tiled, 3-pass) with a small table of which inputs each path handles. Also expand the Kernel 2 (softmax) doc-comment to call out the power-of-two block-size invariant and reference commit d9016a5 — the bug that caught it produced the same first generated token (50411) for every input because the 3-prompt-token prefill ran the tree reduction with blockDim.x = 3 and silently dropped the third element. Worth guarding against a future maintainer "simplifying" the launch code.
Lists the four dispatched kernels (WMMA, scalar tiled, decode, 3-pass), the KV-cache write path, the build flag pair, and the measured speedup on Whisper-medium / RX 7900 XTX.
…ests
CI's check-python-style step caught:
- black formatting differences (long argument lists wanted on separate lines,
a couple of generator-expression layouts)
- isort import-order: test_utils + ctranslate2 needed a blank-line group
separation
- flake8 F541: one f-string with no placeholders in benchmark_flash_attention.py
line 231 (the savings-banner print spans two lines and only the second one
actually interpolates -- both got prefixed with f"" by mistake).
All three checks now pass locally with the same versions CI pins
(black==22.*, isort==5.*, flake8==3.8.*); all 15 pytest tests still pass.
The wheel CI builds the ROCm wheel with
--offload-arch=gfx1030 --offload-arch=gfx1100 ...
i.e. one device pass per architecture. In the device pass for gfx1030
the WMMA kernel template's body is gated out (the wave32 WMMA built-in
only exists on gfx11+/gfx12+), so its name is not in scope. The
dispatcher's `hipLaunchKernelGGL((hip_flash_attn_wmma_fp16<HalfT, D>), …)`
call is parsed by the compiler in every pass (including device passes,
even though the dispatcher itself is host-only), so on gfx1030 it fails:
flash_attention_gpu.cu:1296:31: error: use of undeclared identifier
'hip_flash_attn_wmma_fp16'
1 error generated when compiling for gfx1030.
This affected all three ROCm CI jobs:
build-python-wheels-rocm (ubuntu-24.04)
build-python-wheels-rocm (windows-2025)
build-and-push-docker-images (rocm)
Two-level fix:
1. Compile-time gate: wrap the dispatcher's WMMA launch (lambda + the
two `if (launch_wmma_bm16(...)) return;` calls) with the *same*
`#if defined(__gfx11..) || !defined(__HIP_DEVICE_COMPILE__)` guard
that already wraps the kernel definition. In a non-gfx11 device
pass the entire WMMA call site is preprocessor-skipped, so the
undeclared-identifier error is gone.
2. Runtime gate: even when the file was compiled with some gfx11
target in the arch list, the resulting multi-arch wheel may run
on a non-gfx11 GPU at execution time, in which case the WMMA
kernel isn't in the loaded device binary. Add a one-time
`gcnArchName` inspection (via cuda::get_device_properties) that
only enters the WMMA path on a `gfx1*` / `gfx2*` device. On any
other arch the dispatcher falls through to the scalar tiled /
3-pass kernels.
Verified locally:
- Single-arch build (gfx1100 only): WMMA path active, 15/15 tests pass.
- Multi-arch build (gfx1030;gfx1100;gfx1101;gfx1102): compiles
cleanly, 15/15 tests pass on the gfx1100 host.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a native HIP implementation of Flash Attention 2 for AMD RDNA3+ GPUs.
Builds on the ROCm/HIP backend introduced in v4.7.0 (#1989), filling in the
-DWITH_FLASH_ATTN=ONpath that previously raisedFATAL_ERRORwhen combinedwith
-DWITH_HIP=ON.No external dependency (no CUTLASS, no rocWMMA, no Composable Kernel). All
kernels are plain HIP / Clang built-ins, FP32 accumulators throughout.
Tested on AMD Radeon RX 7900 XTX (gfx1100, RDNA3) on Windows 11 with the
ROCm SDK pip wheels.
Architecture: four dispatched kernels
The dispatcher (
flash_attention_hip_impl) picks the most efficient kernelthat matches the input shape and dtype:
seqlen_q ≥ 16,D ∈ {64,128}, gfx11+seqlen_q == 1,D ∈ {64,128}seqlen_q ≥ 64,D ∈ {64,80,128}WMMA path uses
__builtin_amdgcn_wmma_f32_16x16x16_{f16,bf16}_w32for bothQ·K^T and P·V. The wave32 accumulator fragment layout was reverse-engineered
empirically — see
src/ops/wmma_probe.cufor the probe (kept in the treeas a dev tool, excluded from the production DLL).
Decode kernel uses V-tiling in phase 3 so output-channel threads read V
from LDS instead of looping through HBM
seqlen_ktimes each.Performance (Whisper-medium, RX 7900 XTX, FP16)
BF16 path gives the same ~1.09× encoder speedup (separate WMMA built-in,
same wave32 layout).
HBM footprint (measured via
hipMemGetInfo)The 137 MiB-per-layer score matrix isn't a persistent saving — the standard
path frees it after each layer — but each layer reads/writes that buffer
~3× through HBM, so the actual benefit is roughly 3 GiB of HBM bandwidth
saved per encoder pass, which is what shows up as the speedup.
Key correctness fix worth flagging
hip_attn_softmax_kerneldoes its per-row max/sum via a binary treereduction (
for (s = blockDim.x >> 1; s > 0; s >>= 1)). The dispatcheroriginally launched it with
blockDim.x = min(seqlen_k, 256). Forseqlen_k = 3(the Whisper prompt prefill of three tokens) that meansblockDim.x = 3, which is not a power of two: the very first reductionstep skips the third element, the softmax denominator is silently wrong,
and
generate()ends up producing token 50411 as the first generatedtoken regardless of audio input.
Now the dispatcher rounds up to the next power of two; extra threads
contribute identity values (−1e9 for max, 0 for sum). Five-seed regression
test in the pytest suite specifically guards against this re-surfacing.
Tests
New
python/tests/test_flash_attention.py(15 tests, all green):Tests skip cleanly if the faster-whisper-medium snapshot isn't already in
the local HuggingFace cache, so they don't pull network on CI runners that
haven't pre-populated models.
python/tests/benchmark_flash_attention.pyis a standalone benchmark(not pytest) that measures both speed and HBM via ctypes ->
hipMemGetInfo.The numbers in this PR description come straight from running it.
Changes
src/ops/flash_attention_gpu.cu— adds the four-path HIP implementation in the existing#else // CT2_USE_HIPblock; CUDA path untouched.src/ops/wmma_probe.cu— dev tool for layout reverse-engineering. Excluded from the default build via a commented-out entry inCMakeLists.txt; document at the top explains how to re-enable.src/layers/flash_attention.cc— passis_causal = _is_decoderto theFlashAttentionop so encoder self-attention runs non-causal.src/layers/transformer.cc,src/layers/whisper.cc— passmodel.use_flash_attention()through toTransformerEncoderLayer. Whisper has its own encoder class, so both needed the fix.src/models/model.cc— acceptflash_attention=Trueon HIP builds whenCT2_WITH_FLASH_ATTNis defined.CMakeLists.txt— adds the HIP-side-DCT2_WITH_FLASH_ATTNdefinition.python/tests/test_flash_attention.py,python/tests/benchmark_flash_attention.py— see above.CHANGELOG.md— entry under Unreleased.Test plan
-DWITH_HIP=ON -DWITH_FLASH_ATTN=ONon gfx1100 (Windows 11)-DWITH_HIP=ON -DWITH_FLASH_ATTN=OFF— verified thatflash_attention=Trueraises the documented errorpython/tests/test_flash_attention.pypassBuilds on PR #2041 (Windows ROCm build guide) and #2042 (HIP test enablement)
in spirit, but doesn't depend on either.