diff --git a/CHANGELOG.md b/CHANGELOG.md index 86d8a12f3..62d1b7d11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,15 @@ ### New features +* Native HIP Flash Attention for AMD RDNA3+ GPUs. Adds four dispatched + kernels (WMMA via the wave32 16x16x16 fp16/bf16 built-in, scalar tiled, + decode-optimised for `seqlen_q == 1`, and a 3-pass correctness oracle) + plus the KV-cache write path for autoregressive decoding. Enabled with + `-DWITH_HIP=ON -DWITH_FLASH_ATTN=ON`. Encoder and decoder self-attention + use this path when `flash_attention=True` is passed at model load time. + ~1.08-1.11x encoder speedup and ~1.05-1.09x `generate()` speedup on + Whisper-medium / RX 7900 XTX vs. the standard MultiHeadAttention path. + ### Fixes and improvements ## [v4.7.1](https://github.com/OpenNMT/CTranslate2/releases/tag/v4.7.1) (2026-02-04) diff --git a/CMakeLists.txt b/CMakeLists.txt index cf80e37b5..3b8059370 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -238,6 +238,13 @@ set(CUDA_SOURCES src/ops/conv1d_gpu.cu src/ops/dequantize_gpu.cu src/ops/flash_attention_gpu.cu + # src/ops/wmma_probe.cu — RDNA3 WMMA fragment-layout probe (dev tool). + # Builds a tiny kernel + C-ABI driver used to + # reverse-engineer the wave32 accumulator layout. + # Off by default to keep it out of the production + # DLL; uncomment to rebuild and rerun the layout + # probe (see python/tests/benchmark_flash_attention.py + # and the kernel's own header comment). src/ops/gather_gpu.cu src/ops/gumbel_max_gpu.cu src/ops/layer_norm_gpu.cu @@ -719,7 +726,11 @@ elseif(WITH_HIP) src/ops/awq/dequantize_gpu.cu ) if(WITH_FLASH_ATTN) - message(FATAL_ERROR "WITH_HIP=ON incompatible with WITH_FLASH_ATTN=ON") + # Native HIP Flash Attention implementation (three-pass: QK^T, softmax, PV). + # Does not require CUTLASS/CuTe; uses plain HIP kernels with FP32 accumulators. + # Supports FP16 and BF16 inputs; tested on gfx1100 (RDNA3 / RX 7900 XTX). + add_definitions(-DCT2_WITH_FLASH_ATTN) + message(STATUS "Building with HIP Flash Attention support") endif() set_source_files_properties(${CUDA_SOURCES} PROPERTIES LANGUAGE HIP) diff --git a/python/tests/benchmark_flash_attention.py b/python/tests/benchmark_flash_attention.py new file mode 100644 index 000000000..bf5922408 --- /dev/null +++ b/python/tests/benchmark_flash_attention.py @@ -0,0 +1,299 @@ +"""Standalone benchmark script for the native HIP Flash Attention path. + +Measures both speed and HBM footprint of `flash_attention=True` vs the +standard MultiHeadAttention oracle, on the encoder pass and on `generate` +at a few different max_length values. + +This is *not* a pytest — it's run manually (`python benchmark_flash_attention.py`) +because the numbers are timing-sensitive and we don't want regression +flapping in CI. pytest tests for correctness live in test_flash_attention.py. + +HBM is measured via the HIP runtime's `hipMemGetInfo`, called through +ctypes. On Windows the runtime DLL is `amdhip64_*.dll`; on Linux it's +`libamdhip64.so`. + +Usage: + python benchmark_flash_attention.py [--model PATH] [--runs N] +""" + +import argparse +import ctypes +import os +import sys +import time + +import numpy as np + +# ---------------------------------------------------------------------------- +# Local-build DLL loader (mirrors test_flash_attention.py). +# ---------------------------------------------------------------------------- +if sys.platform == "win32": + import site + + for site_dir in site.getsitepackages() + [site.getusersitepackages()]: + for sub in ("_rocm_sdk_core/bin", "_rocm_sdk_libraries_custom/bin"): + cand = os.path.join(site_dir, *sub.split("/")) + if os.path.isdir(cand): + try: + os.add_dll_directory(cand) + except (FileNotFoundError, OSError): + pass + +import ctranslate2 # noqa: E402 + + +# ---------------------------------------------------------------------------- +# HIP runtime memory query via ctypes. +# Returns (free_bytes, total_bytes) for the currently-active device. +# ---------------------------------------------------------------------------- +def _load_hip_runtime(): + if sys.platform == "win32": + # The DLL name on Windows has a major-version suffix that varies. + site_dir = next( + d + for d in ( + os.path.join(s, "_rocm_sdk_core", "bin") + for s in ( + __import__("site").getsitepackages() + + [__import__("site").getusersitepackages()] + ) + ) + if os.path.isdir(d) + ) + candidates = sorted( + f + for f in os.listdir(site_dir) + if f.startswith("amdhip64") and f.endswith(".dll") + ) + if not candidates: + raise FileNotFoundError("amdhip64_*.dll not found in ROCm SDK bin") + return ctypes.CDLL(os.path.join(site_dir, candidates[-1])) + else: + return ctypes.CDLL("libamdhip64.so") + + +_hip = _load_hip_runtime() +_hip.hipMemGetInfo.argtypes = [ + ctypes.POINTER(ctypes.c_size_t), + ctypes.POINTER(ctypes.c_size_t), +] +_hip.hipMemGetInfo.restype = ctypes.c_int +_hip.hipDeviceSynchronize.argtypes = [] +_hip.hipDeviceSynchronize.restype = ctypes.c_int + + +def hbm_free_total(): + """(free_bytes, total_bytes) on the active HIP device.""" + free = ctypes.c_size_t() + total = ctypes.c_size_t() + rc = _hip.hipMemGetInfo(ctypes.byref(free), ctypes.byref(total)) + if rc != 0: + raise RuntimeError(f"hipMemGetInfo returned {rc}") + return free.value, total.value + + +def gpu_sync(): + _hip.hipDeviceSynchronize() + + +# ---------------------------------------------------------------------------- +# Bench helpers. +# ---------------------------------------------------------------------------- +def _mel(seed=0): + rng = np.random.default_rng(seed) + return rng.standard_normal((1, 80, 3000)).astype(np.float32) + + +def _bench(fn, runs=20, warmup=3): + for _ in range(warmup): + fn() + gpu_sync() + t0 = time.perf_counter() + for _ in range(runs): + fn() + gpu_sync() + return (time.perf_counter() - t0) / runs * 1000 # ms per call + + +def measure_hbm_delta(loader_fn, work_fn): + """Build a model with `loader_fn`, baseline its persistent HBM, run + `work_fn(model)` once (so the allocator pool grows), then measure the + additional HBM held. Returns (model_bytes, work_bytes_added).""" + gpu_sync() + free_empty, _ = hbm_free_total() + + model = loader_fn() + gpu_sync() + free_after_load, _ = hbm_free_total() + model_bytes = free_empty - free_after_load + + # Warm up enough to make the allocator pool reach its working-set peak. + for _ in range(3): + work_fn(model) + gpu_sync() + free_after_work, _ = hbm_free_total() + work_bytes = free_after_load - free_after_work + return model, model_bytes, work_bytes + + +def fmt_bytes(n): + n = abs(n) + if n < 1024 * 1024: + return f"{n/1024:.1f} KiB" + if n < 1024 * 1024 * 1024: + return f"{n/1024/1024:.1f} MiB" + return f"{n/1024/1024/1024:.2f} GiB" + + +# ---------------------------------------------------------------------------- +# Main benchmark. +# ---------------------------------------------------------------------------- +PROMPTS = [[50258, 50259, 50360]] + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + default=None, + help="Path to a converted faster-whisper-medium snapshot. " + "If omitted, the HuggingFace cache is searched.", + ) + parser.add_argument("--runs", type=int, default=20) + parser.add_argument( + "--compute-type", + default="float16", + choices=["float16", "bfloat16"], + ) + args = parser.parse_args() + + if args.model is None: + snaps = os.path.expanduser( + "~/.cache/huggingface/hub/" + "models--Systran--faster-whisper-medium/snapshots" + ) + if not os.path.isdir(snaps): + sys.exit("No --model given and faster-whisper-medium not cached.") + shas = [ + d + for d in os.listdir(snaps) + if os.path.isfile(os.path.join(snaps, d, "model.bin")) + ] + if not shas: + sys.exit("No converted snapshot found in HF cache.") + args.model = os.path.join(snaps, shas[0]) + + print(f"Model: {args.model}") + print(f"Compute type: {args.compute_type}, runs/case: {args.runs}\n") + + mel = _mel(seed=0) + + # ------------------------------------------------------------------ + # HBM measurement: load both variants, measure persistent + working set. + # ------------------------------------------------------------------ + print("=" * 70) + print("HBM footprint") + print("=" * 70) + for flash in (False, True): + label = "Flash=ON " if flash else "Flash=OFF" + + def load(): + return ctranslate2.models.Whisper( + args.model, + device="cuda", + compute_type=args.compute_type, + flash_attention=flash, + ) + + def work(m): + r = m.generate( + ctranslate2.StorageView.from_array(mel), + PROMPTS, + beam_size=1, + max_length=100, + ) + return r + + model, model_b, work_b = measure_hbm_delta(load, work) + print( + f" {label}: model = {fmt_bytes(model_b)}, " + f"working set (generate max_length=100) = +{fmt_bytes(work_b)}" + ) + del model + gpu_sync() + + # Theoretical Flash Attention savings on the encoder score buffer. + sq = sk = 1500 + h = 16 + saved_per_layer = sq * sk * h * 4 + print() + print( + " Score-matrix that the standard path materialises and Flash " + "avoids, per encoder layer:" + ) + print(f" {sq}x{sk}x{h} heads x FP32 = {fmt_bytes(saved_per_layer)} per layer") + print( + f" 24 layers => up to {fmt_bytes(24 * saved_per_layer)} of HBM " + f"traffic per encoder pass" + ) + print() + + # ------------------------------------------------------------------ + # Performance: encoder-only, then generate() at several lengths. + # ------------------------------------------------------------------ + print("=" * 70) + print("Performance") + print("=" * 70) + + m_off = ctranslate2.models.Whisper( + args.model, + device="cuda", + compute_type=args.compute_type, + flash_attention=False, + ) + m_on = ctranslate2.models.Whisper( + args.model, + device="cuda", + compute_type=args.compute_type, + flash_attention=True, + ) + + def enc_fn(m): + return lambda: m.encode(ctranslate2.StorageView.from_array(mel), to_cpu=False) + + for batch in (1, 4): + mel_b = np.stack([mel[0]] * batch, axis=0) + + def enc_b(m): + return lambda: m.encode( + ctranslate2.StorageView.from_array(mel_b), to_cpu=False + ) + + off_ms = _bench(enc_b(m_off), runs=args.runs) + on_ms = _bench(enc_b(m_on), runs=args.runs) + print( + f" encoder (B={batch}, Sq=Sk=1500): " + f"OFF={off_ms:6.2f} ms ON={on_ms:6.2f} ms speedup={off_ms/on_ms:.2f}x" + ) + + print() + for max_len in (30, 100, 200, 448): + + def gen_b(m, ml=max_len): + return lambda: m.generate( + ctranslate2.StorageView.from_array(mel), + PROMPTS, + beam_size=1, + max_length=ml, + ) + + off_ms = _bench(gen_b(m_off), runs=max(args.runs // 4, 3)) + on_ms = _bench(gen_b(m_on), runs=max(args.runs // 4, 3)) + print( + f" generate (max_length={max_len:3d}): " + f"OFF={off_ms:7.1f} ms ON={on_ms:7.1f} ms speedup={off_ms/on_ms:.2f}x" + ) + + +if __name__ == "__main__": + main() diff --git a/python/tests/test_flash_attention.py b/python/tests/test_flash_attention.py new file mode 100644 index 000000000..ba62ded0f --- /dev/null +++ b/python/tests/test_flash_attention.py @@ -0,0 +1,376 @@ +"""Correctness tests for the native HIP Flash Attention path. + +These tests exercise the FP16 and BF16 WMMA kernels, the decode kernel, and +the KV-cache write path against the standard MultiHeadAttention oracle. +A model snapshot (Systran/faster-whisper-medium) is required; the test is +skipped if it isn't already on disk so it can run in environments without +network access. + +The tests are written to be CI-friendly: only correctness is asserted; the +memory footprint of the materialised attention score buffer that the Flash +path avoids is printed for context, not asserted. +""" + +import os +import sys + +import numpy as np +import pytest + +# Local-build dev-loop helper: when the test is run against a non-installed +# ctranslate2 source tree (i.e. python/ctranslate2 imports a freshly-built +# DLL but the ROCm SDK is only available via the pip-installed +# _rocm_sdk_core/_rocm_sdk_libraries_custom wheels), Python 3.8+ no longer +# honours PATH for DLL search and the load fails before we get to import. +# Find the SDK directories among site-packages and add them up-front. In a +# normal CI setup this is a no-op (the SDK is already discoverable). +if sys.platform == "win32": + import site + + for site_dir in site.getsitepackages() + [site.getusersitepackages()]: + for sub in ("_rocm_sdk_core/bin", "_rocm_sdk_libraries_custom/bin"): + cand = os.path.join(site_dir, *sub.split("/")) + if os.path.isdir(cand): + try: + os.add_dll_directory(cand) + except (FileNotFoundError, OSError): + pass + +import test_utils + +import ctranslate2 + + +# ---------------------------------------------------------------------------- +# Model discovery — uses an existing CT2 Whisper snapshot if available. +# ---------------------------------------------------------------------------- +def _find_whisper_medium(): + """Return the path to a converted faster-whisper-medium snapshot, or None.""" + hf_cache = os.path.expanduser("~/.cache/huggingface/hub") + model_dir = os.path.join( + hf_cache, + "models--Systran--faster-whisper-medium", + "snapshots", + ) + if not os.path.isdir(model_dir): + return None + # snapshots//{model.bin, config.json, tokenizer.json, ...} + for sha in os.listdir(model_dir): + full = os.path.join(model_dir, sha) + if os.path.isfile(os.path.join(full, "model.bin")): + return full + return None + + +_MODEL_PATH = _find_whisper_medium() + +require_model = pytest.mark.skipif( + _MODEL_PATH is None, + reason="faster-whisper-medium snapshot not cached locally", +) + + +# ---------------------------------------------------------------------------- +# Helpers +# ---------------------------------------------------------------------------- +def _mel(seed=0): + """Synthetic 30-second mel spectrogram, deterministic for a given seed.""" + rng = np.random.default_rng(seed) + return rng.standard_normal((1, 80, 3000)).astype(np.float32) + + +def _encoder_output(model, mel): + """Run the encoder and return the result as a CPU FP32 numpy array.""" + features = ctranslate2.StorageView.from_array(mel) + out = model.encode(features, to_cpu=True) + # BF16 outputs aren't directly numpy-castable. + if out.dtype != ctranslate2.DataType.float32: + out = out.to(ctranslate2.DataType.float32) + return np.array(out) + + +def _score_buffer_bytes(seqlen_q, seqlen_k, num_heads=16, batch=1): + """Bytes the standard attention path would allocate for the + [B, H, Sq, Sk] FP32 score matrix. This is what Flash Attention's + online-softmax design avoids materialising in HBM.""" + return batch * num_heads * seqlen_q * seqlen_k * 4 + + +# ---------------------------------------------------------------------------- +# Encoder correctness — Flash=ON must match Flash=OFF within a small tolerance. +# ---------------------------------------------------------------------------- +@test_utils.require_cuda +@require_model +@pytest.mark.parametrize( + "compute_type,abs_tol,rel_tol", + [ + # FP16 path: 11-bit mantissa, accumulators are FP32. The diff is + # dominated by re-ordering of summations in WMMA vs. rocBLAS-GEMM. + ("float16", 0.5, 5e-3), + # BF16 path: only 8-bit mantissa, larger accumulated rounding. + ("bfloat16", 2.0, 3e-2), + ], +) +def test_flash_attention_encoder_matches_standard(compute_type, abs_tol, rel_tol): + """The Flash Attention encoder output must match the standard path + within float16/bfloat16 rounding noise.""" + mel = _mel(seed=0) + + m_off = ctranslate2.models.Whisper( + _MODEL_PATH, + device="cuda", + compute_type=compute_type, + flash_attention=False, + ) + m_on = ctranslate2.models.Whisper( + _MODEL_PATH, + device="cuda", + compute_type=compute_type, + flash_attention=True, + ) + + out_off = _encoder_output(m_off, mel) + out_on = _encoder_output(m_on, mel) + + assert out_off.shape == out_on.shape + diff = np.abs(out_off - out_on) + max_diff = float(diff.max()) + max_abs = float(np.abs(out_off).max()) + 1e-8 + rel_diff = max_diff / max_abs + + # Informative — what Flash Attention's online softmax saves us in HBM + # per layer (Whisper-medium encoder: Sq = Sk = 1500, 16 heads). + saved_bytes = _score_buffer_bytes(1500, 1500, num_heads=16, batch=1) + print( + f"\n[{compute_type}] encoder max_abs_diff={max_diff:.4f}, " + f"rel_diff={rel_diff*100:.3f}%; " + f"per-layer score-buffer avoided = {saved_bytes/1024/1024:.1f} MiB" + ) + + assert ( + max_diff <= abs_tol + ), f"{compute_type} encoder max diff {max_diff:.4f} exceeds {abs_tol}" + assert rel_diff <= rel_tol, ( + f"{compute_type} encoder rel diff {rel_diff*100:.3f}% exceeds " + f"{rel_tol*100:.3f}%" + ) + + +# ---------------------------------------------------------------------------- +# generate() correctness — exercises decode-kernel + KV-cache write path. +# FP16 should produce token-identical output to the standard path; BF16 may +# differ in the last few tokens due to the smaller mantissa. +# ---------------------------------------------------------------------------- +PROMPTS = [[50258, 50259, 50360]] + + +@test_utils.require_cuda +@require_model +@pytest.mark.parametrize("seed", [42, 123, 777, 999, 1234]) +def test_flash_attention_generate_fp16_token_match(seed): + """FP16 Flash=ON must produce identical token sequences to Flash=OFF + across multiple random inputs. + + This is also the regression test for the softmax-reduction block-size + bug: that bug caused the very first generated token to always be 50411 + regardless of the input. Five distinct seeds with distinct expected + tokens makes silent regression effectively impossible. + """ + mel = _mel(seed=seed) + + m_off = ctranslate2.models.Whisper( + _MODEL_PATH, + device="cuda", + compute_type="float16", + flash_attention=False, + ) + m_on = ctranslate2.models.Whisper( + _MODEL_PATH, + device="cuda", + compute_type="float16", + flash_attention=True, + ) + + feat_off = ctranslate2.StorageView.from_array(mel) + feat_on = ctranslate2.StorageView.from_array(mel) + + r_off = m_off.generate(feat_off, PROMPTS, beam_size=1, max_length=20) + r_on = m_on.generate(feat_on, PROMPTS, beam_size=1, max_length=20) + + tok_off = r_off[0].sequences_ids[0] + tok_on = r_on[0].sequences_ids[0] + assert ( + tok_off == tok_on + ), f"seed={seed}: Flash=ON produced {tok_on} but oracle is {tok_off}" + + +# ---------------------------------------------------------------------------- +# Regression test for the softmax-reduction-block-size bug. +# +# Symptom (before the fix in commit d9016a58): generate() with Flash=ON would +# emit token 50411 as the first generated token regardless of the audio input, +# because the per-row tree reduction in hip_attn_softmax_kernel ran with +# block = min(seqlen_k, 256), and for seqlen_k = 3 (the Whisper prompt +# prefill with three tokens) the reduction silently dropped the third score. +# +# This test reproduces the exact pre-fix configuration and asserts that the +# first generated token actually depends on the audio input. +# ---------------------------------------------------------------------------- +@test_utils.require_cuda +@require_model +def test_softmax_block_size_regression(): + """Different audio inputs must produce different first generated tokens + when Flash Attention is enabled. Guards against a re-emergence of the + softmax reduction block-size bug.""" + m_on = ctranslate2.models.Whisper( + _MODEL_PATH, + device="cuda", + compute_type="float16", + flash_attention=True, + ) + + first_tokens = set() + for seed in [42, 123, 777, 999, 1234]: + mel = _mel(seed=seed) + feat = ctranslate2.StorageView.from_array(mel) + # max_length must exceed the 3 prompt tokens for the generation step + # to actually emit a token; we only inspect index 0. + r = m_on.generate(feat, PROMPTS, beam_size=1, max_length=5) + seq = r[0].sequences_ids[0] + assert len(seq) >= 1, f"seed={seed}: no token generated" + first_tokens.add(seq[0]) + + assert len(first_tokens) > 1, ( + "Flash Attention generated the same first token for five distinct " + f"random inputs ({first_tokens}) — softmax-reduction block-size bug " + "may have re-surfaced (see commit d9016a58)." + ) + + +# ---------------------------------------------------------------------------- +# Batch > 1 — exercises the WMMA path with a non-trivial batch dimension. +# Whisper-medium encoder runs at Sq = Sk = 1500, so this also covers the +# largest tile-count case. Both flash and standard paths must agree per +# batch element. +# ---------------------------------------------------------------------------- +@test_utils.require_cuda +@require_model +@pytest.mark.parametrize("batch_size", [2, 4]) +def test_flash_attention_encoder_batched(batch_size): + """Encoder correctness for batch_size > 1.""" + mels = np.stack( + [ + np.random.default_rng(seed).standard_normal((80, 3000)).astype(np.float32) + for seed in range(batch_size) + ], + axis=0, + ) + + m_off = ctranslate2.models.Whisper( + _MODEL_PATH, + device="cuda", + compute_type="float16", + flash_attention=False, + ) + m_on = ctranslate2.models.Whisper( + _MODEL_PATH, + device="cuda", + compute_type="float16", + flash_attention=True, + ) + + out_off = _encoder_output(m_off, mels) + out_on = _encoder_output(m_on, mels) + assert out_off.shape == out_on.shape == (batch_size, 1500, 1024) + + diff = np.abs(out_off - out_on) + max_diff = float(diff.max()) + max_abs = float(np.abs(out_off).max()) + 1e-8 + rel = max_diff / max_abs + print( + f"\n[B={batch_size}] encoder max_abs_diff={max_diff:.4f}, " + f"rel_diff={rel*100:.3f}%" + ) + assert max_diff <= 0.5, f"B={batch_size} max diff {max_diff:.4f} > 0.5" + + +# ---------------------------------------------------------------------------- +# Prompt-prefill correctness across a few prefill lengths. +# This exercises the seqlen_q path that's neither pure decode (seqlen_q=1) +# nor the WMMA fast path (seqlen_q >= 16), forcing the scalar-tiled and +# 3-pass fallback kernels to also get a correctness pass. +# ---------------------------------------------------------------------------- +@test_utils.require_cuda +@require_model +@pytest.mark.parametrize("n_prompt", [1, 3, 5, 8]) +def test_flash_attention_variable_prompt_length(n_prompt): + """generate() must agree between Flash=ON and OFF for varying numbers + of prompt tokens — covers seqlen_q = 1, 3, 5, 8 in the prefill step, + which routes through the decode-kernel (1), 3-pass fallback (3, 5, 8) + pieces of the dispatcher.""" + mel = _mel(seed=0) + base = [50258, 50259, 50360, 50364, 1029, 290, 264, 7184] + prompt = [base[:n_prompt]] + + m_off = ctranslate2.models.Whisper( + _MODEL_PATH, + device="cuda", + compute_type="float16", + flash_attention=False, + ) + m_on = ctranslate2.models.Whisper( + _MODEL_PATH, + device="cuda", + compute_type="float16", + flash_attention=True, + ) + + r_off = m_off.generate( + ctranslate2.StorageView.from_array(mel), + prompt, + beam_size=1, + max_length=n_prompt + 5, + ) + r_on = m_on.generate( + ctranslate2.StorageView.from_array(mel), + prompt, + beam_size=1, + max_length=n_prompt + 5, + ) + tok_off = r_off[0].sequences_ids[0] + tok_on = r_on[0].sequences_ids[0] + assert ( + tok_off == tok_on + ), f"n_prompt={n_prompt}: Flash=ON {tok_on} vs Flash=OFF {tok_off}" + + +# ---------------------------------------------------------------------------- +# Informational test: what does Flash Attention save in peak HBM for +# attention score buffers? Reported, not asserted — it's hardware-independent +# and serves as documentation of the algorithmic benefit. +# ---------------------------------------------------------------------------- +def test_flash_attention_score_buffer_savings(): + """Print what Flash Attention's online softmax avoids materialising in + HBM at typical Whisper / LLM shapes. Useful context when reading the + timing numbers.""" + cases = [ + ("Whisper-medium encoder layer", 1, 16, 1500, 1500), + ("Whisper-medium decoder self-attn (max_length=200)", 1, 16, 1, 200), + ("Whisper-medium decoder cross-attn", 1, 16, 1, 1500), + ("Hypothetical LLM @ 4k context", 1, 32, 4096, 4096), + ("Hypothetical LLM @ 32k context", 1, 32, 32768, 32768), + ] + print( + "\nScore-matrix HBM footprint that Flash Attention's online softmax " + "avoids materialising (per attention call):" + ) + print(f" {'shape':45s} {'bytes':>10s}") + for name, b, h, sq, sk in cases: + n = _score_buffer_bytes(sq, sk, num_heads=h, batch=b) + if n < 1024 * 1024: + n_human = f"{n/1024:.1f} KiB" + elif n < 1024 * 1024 * 1024: + n_human = f"{n/1024/1024:.1f} MiB" + else: + n_human = f"{n/1024/1024/1024:.2f} GiB" + print(f" {name:45s} {n_human:>10s}") diff --git a/src/layers/flash_attention.cc b/src/layers/flash_attention.cc index fb13b90fc..7a9eec4fa 100644 --- a/src/layers/flash_attention.cc +++ b/src/layers/flash_attention.cc @@ -110,7 +110,11 @@ namespace ctranslate2 { // init output StorageView context(dtype, device); - ops::FlashAttention fl_attn_ops(_queries_scale, _sliding_window); + // Causal masking only applies to the decoder's autoregressive self-attention. + // Encoder self-attention (and any non-decoder usage of FlashMultiHeadAttention) + // must run NON-causal — otherwise each token would only see its predecessors, + // which is wrong for bidirectional contexts. + ops::FlashAttention fl_attn_ops(_queries_scale, _sliding_window, /*is_causal=*/_is_decoder); fl_attn_ops(queries_proj, keys_proj, values_proj, context, cached_keys, cached_values, attention, return_normalized_attention, rotary_cos, rotary_sin, rotary_interleaved, nullptr/*alibli*/, offset); diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc index 4fea80f7f..6a9905e9c 100644 --- a/src/layers/transformer.cc +++ b/src/layers/transformer.cc @@ -393,7 +393,8 @@ namespace ctranslate2 { scope + "/layer", _num_heads, model.get_flag_with_default(scope + "/pre_norm", true), - model.get_enum_value(scope + "/activation"))) + model.get_enum_value(scope + "/activation"), + _use_flash_attention)) , _position_encoder(_layers.front()->get_self_attention().has_positional_embeddings() ? nullptr : build_position_encoder(model, scope + "/position_encodings", _embeddings)) diff --git a/src/layers/whisper.cc b/src/layers/whisper.cc index 401c14677..763b47b28 100644 --- a/src/layers/whisper.cc +++ b/src/layers/whisper.cc @@ -17,7 +17,8 @@ namespace ctranslate2 { scope + "/layer", _num_heads, /*pre_norm=*/true, - ops::ActivationType::GELU)) + ops::ActivationType::GELU, + model.use_flash_attention())) , _output_norm(model, scope + "/layer_norm") { } diff --git a/src/models/model.cc b/src/models/model.cc index 1d8295b0a..0caf284a5 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -840,14 +840,20 @@ namespace ctranslate2 { int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA); auto dprops = ctranslate2::cuda::get_device_properties(device_id); #ifdef CT2_USE_HIP +# ifdef CT2_WITH_FLASH_ATTN + supports_flash_attention = true; // native HIP kernels available +# else supports_flash_attention = false; +# endif #else supports_flash_attention = dprops.major >= 8; #endif } if (use_flash_attention && !supports_flash_attention) { #ifdef CT2_USE_HIP - throw std::invalid_argument("FlashAttention not supported on ROCm."); + throw std::invalid_argument( + "FlashAttention is not available for this ROCm build. " + "Rebuild CTranslate2 with -DWITH_HIP=ON -DWITH_FLASH_ATTN=ON."); #else throw std::invalid_argument("FlashAttention only supports Ampere GPUs or newer."); #endif diff --git a/src/ops/flash_attention_gpu.cu b/src/ops/flash_attention_gpu.cu index 24bcbf878..2f3a77919 100644 --- a/src/ops/flash_attention_gpu.cu +++ b/src/ops/flash_attention_gpu.cu @@ -1,10 +1,11 @@ #include "ctranslate2/ops/flash_attention.h" -#ifdef CT2_WITH_FLASH_ATTN +#if defined(CT2_WITH_FLASH_ATTN) && !defined(CT2_USE_HIP) #include "ctranslate2/ops/flash-attention/flash.h" #include "ctranslate2/ops/flash-attention/static_switch.h" #endif #include "ctranslate2/ops/transpose.h" #include "cuda/utils.h" +#include "cuda/helpers.h" #include "dispatch.h" @@ -14,6 +15,9 @@ namespace ctranslate2 { namespace ops { + +#ifndef CT2_USE_HIP // CUDA-only: CUTLASS/CuTe Flash Attention kernels + #ifdef CT2_WITH_FLASH_ATTN static void set_params_fprop(Flash_fwd_params ¶ms, // sizes @@ -364,5 +368,1097 @@ namespace ctranslate2 { throw std::runtime_error("Flash attention 2 is not supported"); #endif } - } -} + +#else // CT2_USE_HIP — native HIP Flash Attention implementation + +// --------------------------------------------------------------------------- +// HIP Flash Attention — native implementation for AMD GPUs (gfx1100 / RDNA3+) +// +// Four code paths. The dispatcher (flash_attention_hip_impl, below) picks +// the most efficient one that matches the input shape and dtype: +// +// Path Used when FP-paths LDS S? +// ───────────────────────────────────────────────────────────────────────── +// WMMA seqlen_q >= 16, D % 16 == 0, fp16, bf16 no +// D in {64, 128}, RDNA3+ (gfx11) +// Decode seqlen_q == 1, D in {64, 128}, fp16, bf16 fits in +// seqlen_k fits in LDS LDS +// Tiled (scalar) seqlen_q >= 64 (BM), D in {64,80,128} fp16, bf16 no +// (covers BF16 and head_dim=80 where WMMA isn't built) +// 3-pass everything else fp16, bf16 materialised +// (also serves as correctness oracle) +// +// All paths use the same [batch, seqlen, nheads, head_dim] memory layout +// and FP32 accumulators. Q and K/V come from the layer split-heads in +// that shape; the KV-cache (cached_keys, cached_values) shares the same +// layout but with a larger second-dim allocation (cache_size >= offset + +// seqlen_new) to amortise the cost of growing it. +// +// Limitations: +// - Inference only — no backward pass. +// - Rotary embeddings and ALiBi must be pre-applied by the layer. +// - WMMA path is FP16+BF16 only and gfx11+ (uses the wave32 16x16x16 +// WMMA built-in). +// --------------------------------------------------------------------------- + + + // ------------------------------------------------------------------- + // Kernel 0 — KV-cache write + // Copies new_kv[b, t, h, d] → cache[b, write_offset+t, h, d]. + // Grid: (seqlen_new, nheads, batch) + // Block: (head_dim, 1, 1) + // ------------------------------------------------------------------- + template + __global__ void hip_kv_cache_write_kernel( + const scalar_t* __restrict__ new_kv, // [batch, seqlen_new, nheads, head_dim] + scalar_t* __restrict__ cache, // [batch, cache_size, nheads, head_dim] + const int seqlen_new, + const int cache_size, + const int nheads, + const int head_dim, + const int write_offset) + { + const int b = blockIdx.z; + const int t = blockIdx.y; + const int h = blockIdx.x; + const int d = threadIdx.x; + if (d >= head_dim) return; + + const int src = b * seqlen_new * nheads * head_dim + t * nheads * head_dim + h * head_dim + d; + const int dst = b * cache_size * nheads * head_dim + + (write_offset + t) * nheads * head_dim + h * head_dim + d; + cache[dst] = new_kv[src]; + } + + // ------------------------------------------------------------------- + // Kernel 1 — Q @ K^T + // Grid: (ceildiv(seqlen_q * seqlen_k, 256), nheads, batch) + // Block: (256, 1, 1) + // Each thread computes one element of S[b, h, q, k]. + // + // k_time_stride: stride (in elements) between consecutive K time steps + // within one batch. For a K tensor of shape [batch, K_seqlen, nheads, head_dim] + // this equals K_seqlen * nheads * head_dim. When K is a view into a + // larger KV-cache buffer the allocation seqlen may differ from the + // logically active seqlen, so the stride must be passed explicitly. + // ------------------------------------------------------------------- + template + __global__ void hip_attn_qk_kernel( + const scalar_t* __restrict__ Q, // [batch, seqlen_q, nheads, head_dim] + const scalar_t* __restrict__ K, // [batch, K_alloc, nheads, head_dim] + float* __restrict__ S, // [batch, nheads, seqlen_q, seqlen_k] + const int seqlen_q, + const int seqlen_k, + const int k_time_stride, // K_alloc * nheads * head_dim + const int nheads, + const int head_dim, + const float scale) + { + const int b = blockIdx.z; + const int h = blockIdx.y; + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int q = idx / seqlen_k; + const int k = idx % seqlen_k; + if (q >= seqlen_q) return; + + float dot = 0.f; + const int q_base = b * seqlen_q * nheads * head_dim + q * nheads * head_dim + h * head_dim; + const int k_base = b * k_time_stride + k * nheads * head_dim + h * head_dim; + for (int d = 0; d < head_dim; ++d) + dot += static_cast(Q[q_base + d]) * static_cast(K[k_base + d]); + + S[b * nheads * seqlen_q * seqlen_k + h * seqlen_q * seqlen_k + q * seqlen_k + k] = + dot * scale; + } + + // ------------------------------------------------------------------- + // Kernel 2 — row-wise softmax with optional causal mask + // Grid: (seqlen_q, nheads, batch) + // Block: next-power-of-2 up to min(seqlen_k, 256) + // + // The block size MUST be a power of two: the per-row max/sum tree + // reduction (`for (s = blockDim.x >> 1; s > 0; s >>= 1)`) drops + // elements at odd boundaries otherwise. Extra threads beyond + // seqlen_k contribute identity values (-1e9 for max, 0 for sum) and + // are harmless. See commit d9016a58 — the bug this caught caused + // generate() to always emit token 50411 because the 3-token prompt + // prefill ran the reduction with blockDim.x = 3. + // ------------------------------------------------------------------- + __global__ void hip_attn_softmax_kernel( + float* __restrict__ S, // [batch, nheads, seqlen_q, seqlen_k] — modified in place + const int nheads, + const int seqlen_q, + const int seqlen_k, + const bool is_causal, + const int q_offset) // position of q[0] in the full sequence (for KV-cache) + { + const int b = blockIdx.z; + const int h = blockIdx.y; + const int q = blockIdx.x; + + float* row = S + (b * nheads + h) * seqlen_q * seqlen_k + q * seqlen_k; + + extern __shared__ float smem[]; // [blockDim.x] for reduction + + // --- apply causal mask --- + const int q_pos = q + q_offset; + for (int k = threadIdx.x; k < seqlen_k; k += blockDim.x) + if (is_causal && k > q_pos) row[k] = -1e9f; + __syncthreads(); + + // --- find row max --- + float local_max = -1e9f; + for (int k = threadIdx.x; k < seqlen_k; k += blockDim.x) + local_max = fmaxf(local_max, row[k]); + smem[threadIdx.x] = local_max; + __syncthreads(); + for (int s = blockDim.x >> 1; s > 0; s >>= 1) { + if (threadIdx.x < s) smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + s]); + __syncthreads(); + } + const float row_max = smem[0]; + __syncthreads(); + + // --- exp and row sum --- + float local_sum = 0.f; + for (int k = threadIdx.x; k < seqlen_k; k += blockDim.x) { + const float e = expf(row[k] - row_max); + row[k] = e; + local_sum += e; + } + smem[threadIdx.x] = local_sum; + __syncthreads(); + for (int s = blockDim.x >> 1; s > 0; s >>= 1) { + if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s]; + __syncthreads(); + } + const float row_sum = smem[0]; + __syncthreads(); + + // --- normalise --- + const float inv_sum = (row_sum > 0.f) ? 1.f / row_sum : 0.f; + for (int k = threadIdx.x; k < seqlen_k; k += blockDim.x) + row[k] *= inv_sum; + } + + // ------------------------------------------------------------------- + // Kernel 3 — O = P @ V + // Grid: (seqlen_q, nheads, batch) + // Block: (head_dim, 1, 1) — head_dim <= 1024 + // Each thread accumulates one output channel O[b, q, h, d]. + // + // v_time_stride: same concept as k_time_stride in Kernel 1. + // ------------------------------------------------------------------- + template + __global__ void hip_attn_ov_kernel( + const float* __restrict__ P, // [batch, nheads, seqlen_q, seqlen_k] + const scalar_t* __restrict__ V, // [batch, V_alloc, nheads, head_dim] + scalar_t* __restrict__ O, // [batch, seqlen_q, nheads, head_dim] + const int seqlen_k, + const int v_time_stride, // V_alloc * nheads * head_dim + const int nheads, + const int head_dim) + { + const int b = blockIdx.z; + const int h = blockIdx.y; + const int q = blockIdx.x; + const int d = threadIdx.x; + if (d >= head_dim) return; + + const int seqlen_q = gridDim.x; + const float* p_row = + P + b * nheads * seqlen_q * seqlen_k + h * seqlen_q * seqlen_k + q * seqlen_k; + + float out = 0.f; + for (int k = 0; k < seqlen_k; ++k) + out += p_row[k] * + static_cast(V[b * v_time_stride + k * nheads * head_dim + h * head_dim + d]); + + O[b * seqlen_q * nheads * head_dim + q * nheads * head_dim + h * head_dim + d] = + static_cast(out); + } + +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || !defined(__HIP_DEVICE_COMPILE__) + // ------------------------------------------------------------------- + // RDNA3 WMMA Flash Attention forward kernel — FP16 only. + // + // Uses the wave32 16x16x16 fp16-input/fp32-accumulator WMMA built-in + // (__builtin_amdgcn_wmma_f32_16x16x16_f16_w32) for Q·K^T and P·V. + // Other paths in this file (3-pass / scalar tiled) remain for BF16, + // for non-multiple-of-16 head dimensions, and as the correctness oracle. + // + // Layout (wave32, RDNA3 — verified empirically via wmma_probe.cu): + // A-fragment (16x16 fp16, row-major): each lane holds one full row + // a_frag[j] = A[lane % 16, j], for j = 0..15 + // (Lanes 16..31 carry a duplicate of rows 0..15.) + // B-fragment (16x16 fp16, col-major): each lane holds one full column + // b_frag[i] = B[i, lane % 16], for i = 0..15 + // C-fragment (16x16 fp32, accumulator): each lane holds 8 elements + // c_frag[s] = C[2*s + (lane >> 4), lane % 16], for s = 0..7 + // (Lanes 0..15 hold even rows, lanes 16..31 hold odd rows.) + // + // S = Q · K^T is mapped to D = A · B with: + // A[i][k] = Q[i][k] (Q tile, row-major in LDS) + // B[k][j] = K[j][k] (K tile transposed view) + // + // Block layout: + // grid: (ceildiv(seqlen_q, BM), nheads, batch_size) + // block: 32 threads (one wave32) + // BM = 16 query rows per block, BN = 16 key tokens per K/V tile + // + // LDS per block (D = 64): + // q_lds[BM][D] fp16 — Q tile, scaled by softmax-scale on load + // k_lds[BN][D] fp16 — current K tile + // v_lds[BN][D] fp16 — current V tile + // p_lds[BM][BN] fp16 — softmaxed S (used as A-fragment for P·V) + // s_scratch[BM][BN] fp32 — temporary S before softmax + // m_lds[BM] fp32 — running max per query row + // l_lds[BM] fp32 — running sum (denominator) per query row + // alpha_lds[BM] fp32 — softmax correction factor per row + // Total ≈ 18 KiB / block, well within the 64 KiB budget. + // + // O is held entirely in registers as `o_frag[D/16]` (4 v8f32 fragments + // for D = 64) per lane. Each fragment covers a 16-column slice of O. + // ------------------------------------------------------------------- + // HalfT is the wave-level fp16/bf16 element type and determines which + // WMMA built-in we dispatch (f16 vs bf16). Both variants share the + // same wave32 accumulator-fragment layout on RDNA3. + template + __global__ void hip_flash_attn_wmma_fp16( + const HalfT* __restrict__ Q, + const HalfT* __restrict__ K, + const HalfT* __restrict__ V, + HalfT* __restrict__ O, + const int seqlen_q, + const int seqlen_k, + const int k_time_stride, + const int v_time_stride, + const int nheads, + const float scale, + const bool is_causal, + const int q_offset) + { + static_assert(D % 16 == 0, "head_dim must be a multiple of 16 for WMMA"); + constexpr int BM = 16; + constexpr int BN = 16; + constexpr int DT = D / 16; // number of D-tiles + + using v16f16 = HalfT __attribute__((ext_vector_type(16))); + using v8f32 = float __attribute__((ext_vector_type(8))); + + const int lane = threadIdx.x; // 0..31 + const int b = blockIdx.z; + const int h = blockIdx.y; + const int q_tile = blockIdx.x; + const int q_row_0 = q_tile * BM; + + __shared__ HalfT q_lds[BM][D]; + __shared__ HalfT k_lds[BN][D]; + __shared__ HalfT v_lds[BN][D]; + __shared__ HalfT p_lds[BM][BN]; + __shared__ float s_scratch[BM][BN]; + __shared__ float m_lds[BM]; + __shared__ float l_lds[BM]; + __shared__ float alpha_lds[BM]; + + // ---- Load Q tile, pre-scaled (so S = Q·K^T already has softmax-scale) ---- + for (int idx = lane; idx < BM * D; idx += 32) { + const int row = idx / D; + const int col = idx % D; + const int q_row = q_row_0 + row; + float v = 0.f; + if (q_row < seqlen_q) + v = static_cast(Q[b * seqlen_q * nheads * D + + q_row * nheads * D + h * D + col]) * scale; + q_lds[row][col] = static_cast(v); + } + + // ---- Initialise running state and output accumulators ---- + if (lane < BM) { + m_lds[lane] = -1e30f; + l_lds[lane] = 0.f; + } + v8f32 o_frag[DT]; + #pragma unroll + for (int t = 0; t < DT; ++t) + #pragma unroll + for (int s = 0; s < 8; ++s) o_frag[t][s] = 0.f; + + __syncthreads(); + + // ---- Iterate over K/V tiles ---- + const int num_k_tiles = (seqlen_k + BN - 1) / BN; + for (int kt = 0; kt < num_k_tiles; ++kt) { + const int k_row_0 = kt * BN; + + // -- Load K and V tiles -- + for (int idx = lane; idx < BN * D; idx += 32) { + const int row = idx / D; + const int col = idx % D; + const int k_row = k_row_0 + row; + HalfT kv_k = static_cast(0); + HalfT kv_v = static_cast(0); + if (k_row < seqlen_k) { + kv_k = K[b * k_time_stride + k_row * nheads * D + h * D + col]; + kv_v = V[b * v_time_stride + k_row * nheads * D + h * D + col]; + } + k_lds[row][col] = kv_k; + v_lds[row][col] = kv_v; + } + __syncthreads(); + + // -- Compute S[BM=16][BN=16] = Q · K^T using WMMA -- + v8f32 s_frag; + #pragma unroll + for (int s = 0; s < 8; ++s) s_frag[s] = 0.f; + + // Inner reduction over D in 16-element chunks. + // A (Q row, row-major): a_frag[j] = q_lds[lane & 15][inner*16 + j] + // B (K transposed, col-major): b_frag[i] = k_lds[lane & 15][inner*16 + i] + // (col-major B's column index = lane mod 16, mapping to the K row j; + // inner reduction index k maps to the D-position within the chunk.) + #pragma unroll + for (int inner = 0; inner < DT; ++inner) { + v16f16 a_frag, b_frag; + const int a_row = lane & 15; + const int b_col = lane & 15; + #pragma unroll + for (int x = 0; x < 16; ++x) { + a_frag[x] = q_lds[a_row][inner * 16 + x]; + b_frag[x] = k_lds[b_col][inner * 16 + x]; + } + if constexpr (std::is_same::value) + s_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, s_frag); + else + s_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, s_frag); + } + + // -- Write S to LDS with causal/OOB masking baked in -- + { + const int s_col = lane & 15; + const int k_col_g = k_row_0 + s_col; + #pragma unroll + for (int s = 0; s < 8; ++s) { + const int s_row = 2 * s + (lane >> 4); + const int q_row_g = q_row_0 + s_row; + const int q_pos = q_row_g + q_offset; + const bool oob = (k_col_g >= seqlen_k) || (q_row_g >= seqlen_q); + const bool masked = is_causal && k_col_g > q_pos; + s_scratch[s_row][s_col] = (oob || masked) ? -1e30f : s_frag[s]; + } + } + __syncthreads(); + + // -- Per-row softmax + online update (one thread per row, 16 rows / 32 threads) -- + if (lane < BM) { + const int row = lane; + const float m_old = m_lds[row]; + const float l_old = l_lds[row]; + + float m_tile = -1e30f; + #pragma unroll + for (int c = 0; c < BN; ++c) + m_tile = fmaxf(m_tile, s_scratch[row][c]); + + const float m_new = fmaxf(m_old, m_tile); + const float alpha = (m_old == -1e30f) ? 0.f : __expf(m_old - m_new); + + float l_tile = 0.f; + #pragma unroll + for (int c = 0; c < BN; ++c) { + const float v = s_scratch[row][c]; + const float e = (v <= -1e29f) ? 0.f : __expf(v - m_new); + p_lds[row][c] = static_cast(e); // P for the P·V WMMA + l_tile += e; + } + + m_lds[row] = m_new; + l_lds[row] = alpha * l_old + l_tile; + alpha_lds[row] = alpha; + } + __syncthreads(); + + // -- Scale o_frag by per-row alpha -- + { + #pragma unroll + for (int t = 0; t < DT; ++t) { + #pragma unroll + for (int s = 0; s < 8; ++s) { + const int o_row = 2 * s + (lane >> 4); + o_frag[t][s] *= alpha_lds[o_row]; + } + } + } + + // -- Compute O += P · V using WMMA (one V-N-tile per D-tile of O) -- + // A (P, row-major): a_frag[j] = p_lds[lane & 15][j] (one row of P) + // B (V, col-major): b_frag[i] = v_lds[i][t*16 + lane&15] (column t*16+col of V) + // Inner reduction goes over BN = 16 (the K-token dimension), in one step. + { + v16f16 a_frag; + const int a_row = lane & 15; + #pragma unroll + for (int x = 0; x < 16; ++x) + a_frag[x] = p_lds[a_row][x]; + + #pragma unroll + for (int t = 0; t < DT; ++t) { + v16f16 b_frag; + const int b_col = lane & 15; + #pragma unroll + for (int i = 0; i < 16; ++i) + b_frag[i] = v_lds[i][t * 16 + b_col]; + + if constexpr (std::is_same::value) + o_frag[t] = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, o_frag[t]); + else + o_frag[t] = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, o_frag[t]); + } + } + __syncthreads(); + } + + // ---- Normalise and store O ---- + if (lane < BM) + alpha_lds[lane] = (l_lds[lane] > 0.f) ? 1.f / l_lds[lane] : 0.f; // reuse alpha_lds as inv_l + __syncthreads(); + + #pragma unroll + for (int t = 0; t < DT; ++t) { + const int o_col_base = t * 16 + (lane & 15); + #pragma unroll + for (int s = 0; s < 8; ++s) { + const int o_row = 2 * s + (lane >> 4); + const int q_row_g = q_row_0 + o_row; + if (q_row_g < seqlen_q && o_col_base < D) { + const float val = o_frag[t][s] * alpha_lds[o_row]; + O[b * seqlen_q * nheads * D + q_row_g * nheads * D + h * D + o_col_base] + = static_cast(val); + } + } + } + } +#endif // gfx11 + + // ------------------------------------------------------------------- + // Tiled fused forward attention (Flash Attention 2 algorithm) + // + // Each grid block processes BM contiguous query rows of one (batch, head). + // Grid: (ceildiv(seqlen_q, BM), nheads, batch) + // Block: (BM, 1, 1) — one thread per query row. + // + // Per-thread state (kept in registers, never spilled to HBM): + // q_reg[D] : the thread's Q row, FP32 + // acc[D] : running output accumulator, FP32 + // m_i, l_i : running max / normaliser of the online softmax + // + // Per-block shared memory: + // s_k[BN][D] : current K tile + // s_v[BN][D] : current V tile + // (BM threads cooperatively load BN·D elements per tile.) + // + // For every K/V tile we compute s_tile[BN] = q · k_t for the local Q row, + // apply the bounds + causal mask, then perform the standard online-softmax + // update of (m_i, l_i, acc). After all tiles we divide the accumulator + // by l_i and store it back to HBM. + // + // Template parameters allow the compiler to fully unroll the inner D loops + // and to keep q_reg/acc entirely in registers. Supported D values are + // currently 64, 80, 128 (covering Whisper / common transformer heads). + // ------------------------------------------------------------------- + template + __global__ void hip_flash_attn_fwd_tiled( + const scalar_t* __restrict__ Q, // [batch, seqlen_q, nheads, D] + const scalar_t* __restrict__ K, // [batch, K_alloc, nheads, D] + const scalar_t* __restrict__ V, // [batch, V_alloc, nheads, D] + scalar_t* __restrict__ O, // [batch, seqlen_q, nheads, D] + const int seqlen_q, + const int seqlen_k, + const int k_time_stride, // K_alloc * nheads * D (batch stride) + const int v_time_stride, // V_alloc * nheads * D + const int nheads, + const float scale, + const bool is_causal, + const int q_offset) // absolute position of q[0] + { + const int b = blockIdx.z; + const int h = blockIdx.y; + const int q_tile = blockIdx.x; + const int tid = threadIdx.x; // 0 .. BM-1 + const int q = q_tile * BM + tid; // global query row + + __shared__ scalar_t s_k[BN][D]; + __shared__ scalar_t s_v[BN][D]; + + // -- Load this thread's Q row into registers (FP32, scaled once) -- + float q_reg[D]; + const bool q_valid = (q < seqlen_q); + if (q_valid) { + const int q_base = b * seqlen_q * nheads * D + q * nheads * D + h * D; + #pragma unroll + for (int d = 0; d < D; ++d) + q_reg[d] = static_cast(Q[q_base + d]) * scale; + } else { + #pragma unroll + for (int d = 0; d < D; ++d) q_reg[d] = 0.f; + } + + // -- Online softmax state -- + float m_i = -1e30f; + float l_i = 0.f; + float acc[D]; + #pragma unroll + for (int d = 0; d < D; ++d) acc[d] = 0.f; + + const int num_tiles = (seqlen_k + BN - 1) / BN; + const int q_pos = q + q_offset; + + for (int t = 0; t < num_tiles; ++t) { + const int k_start = t * BN; + + // -- Cooperatively load K and V tiles into LDS -- + // BM threads load BN*D elements each (== BN*D / BM per thread). + for (int idx = tid; idx < BN * D; idx += BM) { + const int kk = idx / D; + const int dd = idx % D; + const int kpos = k_start + kk; + if (kpos < seqlen_k) { + const int k_base = b * k_time_stride + kpos * nheads * D + h * D; + const int v_base = b * v_time_stride + kpos * nheads * D + h * D; + s_k[kk][dd] = K[k_base + dd]; + s_v[kk][dd] = V[v_base + dd]; + } else { + s_k[kk][dd] = static_cast(0); + s_v[kk][dd] = static_cast(0); + } + } + __syncthreads(); + + if (q_valid) { + // -- s_tile[kk] = q · k_t, with bounds + causal mask -- + float s_tile[BN]; + float m_tile = -1e30f; + #pragma unroll + for (int kk = 0; kk < BN; ++kk) { + const int kpos = k_start + kk; + const bool oob = kpos >= seqlen_k; + const bool masked = is_causal && kpos > q_pos; + if (oob || masked) { + s_tile[kk] = -1e30f; + } else { + float dot = 0.f; + #pragma unroll + for (int d = 0; d < D; ++d) + dot += q_reg[d] * static_cast(s_k[kk][d]); + s_tile[kk] = dot; + if (dot > m_tile) m_tile = dot; + } + } + + // -- Online softmax update -- + const float m_new = fmaxf(m_i, m_tile); + const float alpha = (m_i == -1e30f) ? 0.f : __expf(m_i - m_new); + + float l_tile = 0.f; + #pragma unroll + for (int kk = 0; kk < BN; ++kk) { + // exp(-inf - finite) == 0, but __expf is undefined for -inf + // on some HIP targets, so guard explicitly. + float e = (s_tile[kk] <= -1e29f) ? 0.f : __expf(s_tile[kk] - m_new); + s_tile[kk] = e; + l_tile += e; + } + + // acc = alpha * acc + s_tile @ V_tile + #pragma unroll + for (int d = 0; d < D; ++d) { + float a = alpha * acc[d]; + #pragma unroll + for (int kk = 0; kk < BN; ++kk) + a += s_tile[kk] * static_cast(s_v[kk][d]); + acc[d] = a; + } + l_i = alpha * l_i + l_tile; + m_i = m_new; + } + __syncthreads(); + } + + // -- Final normalisation and store -- + if (q_valid) { + const float inv_l = (l_i > 0.f) ? 1.f / l_i : 0.f; + const int o_base = b * seqlen_q * nheads * D + q * nheads * D + h * D; + #pragma unroll + for (int d = 0; d < D; ++d) + O[o_base + d] = static_cast(acc[d] * inv_l); + } + } + + // ------------------------------------------------------------------- + // Decode-optimised kernel (seqlen_q == 1) + // + // The tiled forward kernel above uses one thread per Q row. During + // autoregressive generation seqlen_q == 1, so that design would leave + // BM-1 of BM threads idle. This kernel inverts the parallelisation: + // a single block handles one (batch, head), and the threads cooperate + // along the K dimension instead. + // + // Grid: (1, nheads, batch) + // Block: BLOCK threads (BLOCK >= D, both powers of two) + // + // Phase 1 Compute all seqlen_k scores S[k] = Q . K[k] (each thread + // handles k = tid, tid+BLOCK, …) and stash them in LDS. + // Phase 2 Block-wide tree reduction for row max → exp(S - max) → sum. + // Phase 3 Output: thread tid (tid < D) accumulates one channel + // O[d] = Σ_k P[k] · V[k][d] / sum. V[k][.] is loaded with + // the natural [k, d] memory layout, so the BLOCK threads + // read coalesced D-wide vectors per k. + // + // LDS layout (one dynamic shared array; scratch is reused across + // phases — reduce buffer in phase 2, V-tile in phase 3): + // [0 .. D) q_lds (FP32 scaled Q) + // [D .. D+seqlen_k) s_lds (FP32 scores) + // [D+seqlen_k .. +max(BLOCK, V_TILE*D)) scratch + // + // V_TILE controls Phase-3 V-tiling: instead of every output channel + // streaming all seqlen_k V values from HBM with no reuse, the BLOCK + // threads cooperatively stage a V_TILE-wide slab of V into LDS once, + // then the D output-channel threads accumulate against the cached + // tile. HBM reads for V drop by ~BLOCK/D for that phase. + // + // gfx1100 LDS budget is 64 KiB. For D=64, BLOCK=64, V_TILE=64 the + // scratch region is 64*64*4 = 16 KiB, so seqlen_k can reach ~12 k. + // ------------------------------------------------------------------- + template + __global__ void hip_flash_decode_kernel( + const scalar_t* __restrict__ Q, // [batch, 1, nheads, D] + const scalar_t* __restrict__ K, // [batch, K_alloc, nheads, D] + const scalar_t* __restrict__ V, // [batch, V_alloc, nheads, D] + scalar_t* __restrict__ O, // [batch, 1, nheads, D] + const int seqlen_k, + const int k_time_stride, + const int v_time_stride, + const int nheads, + const float scale, + const bool is_causal, + const int q_offset) + { + const int b = blockIdx.z; + const int h = blockIdx.y; + const int tid = threadIdx.x; + + extern __shared__ float smem[]; + float* q_lds = smem; // D floats + float* s_lds = smem + D; // seqlen_k floats + float* scratch = s_lds + seqlen_k; // reduce_buf OR v_tile + + // ---- Load Q (FP32, scaled once) into LDS ---- + for (int d = tid; d < D; d += BLOCK) + q_lds[d] = static_cast(Q[b * nheads * D + h * D + d]) * scale; + __syncthreads(); + + // ---- Phase 1: S[k] = q · k_t (+ causal mask) ---- + const int q_pos = q_offset; // seqlen_q == 1, so q_pos == offset + for (int k = tid; k < seqlen_k; k += BLOCK) { + if (is_causal && k > q_pos) { + s_lds[k] = -1e30f; + continue; + } + const int k_base = b * k_time_stride + k * nheads * D + h * D; + float dot = 0.f; + #pragma unroll + for (int d = 0; d < D; ++d) + dot += q_lds[d] * static_cast(K[k_base + d]); + s_lds[k] = dot; + } + __syncthreads(); + + // ---- Phase 2a: row max via block-wide reduction (scratch=reduce_buf) ---- + float local_max = -1e30f; + for (int k = tid; k < seqlen_k; k += BLOCK) + local_max = fmaxf(local_max, s_lds[k]); + scratch[tid] = local_max; + __syncthreads(); + for (int s = BLOCK >> 1; s > 0; s >>= 1) { + if (tid < s) scratch[tid] = fmaxf(scratch[tid], scratch[tid + s]); + __syncthreads(); + } + const float row_max = scratch[0]; + __syncthreads(); + + // ---- Phase 2b: exp(S - max) and row sum ---- + float local_sum = 0.f; + for (int k = tid; k < seqlen_k; k += BLOCK) { + const float e = (s_lds[k] <= -1e29f) ? 0.f : __expf(s_lds[k] - row_max); + s_lds[k] = e; + local_sum += e; + } + scratch[tid] = local_sum; + __syncthreads(); + for (int s = BLOCK >> 1; s > 0; s >>= 1) { + if (tid < s) scratch[tid] += scratch[tid + s]; + __syncthreads(); + } + const float row_sum = scratch[0]; + const float inv_sum = (row_sum > 0.f) ? 1.f / row_sum : 0.f; + __syncthreads(); + + // ---- Phase 3: O[d] = Σ_k P[k] · V[k][d] / sum, with V-tiling ---- + // scratch is reinterpreted as v_tile (row-major [V_TILE][D] FP32). + float* v_tile = scratch; + float acc = 0.f; + + for (int kt = 0; kt < seqlen_k; kt += V_TILE) { + // -- Cooperative load: BLOCK threads stage V_TILE * D elements -- + for (int idx = tid; idx < V_TILE * D; idx += BLOCK) { + const int kk = idx / D; + const int dd = idx % D; + const int kpos = kt + kk; + float v_val = 0.f; + if (kpos < seqlen_k) { + const int v_idx = b * v_time_stride + kpos * nheads * D + h * D + dd; + v_val = static_cast(V[v_idx]); + } + v_tile[idx] = v_val; + } + __syncthreads(); + + // -- Accumulate: each output-channel thread sums over this tile -- + if (tid < D) { + const int kk_max = min((int)V_TILE, seqlen_k - kt); + for (int kk = 0; kk < kk_max; ++kk) { + const float p = s_lds[kt + kk]; + acc += p * v_tile[kk * D + tid]; + } + } + __syncthreads(); + } + + if (tid < D) + O[b * nheads * D + h * D + tid] = static_cast(acc * inv_sum); + } + + // ------------------------------------------------------------------- + // Dispatcher called from FlashAttention::compute below. + // + // KV-cache semantics (mirrors the CUDA CUTLASS path): + // offset == 0 — prefilling / encoder self-attention: + // keys/values contain the full context; cached_keys/values are + // filled or updated by the layer *before* this call. + // offset > 0 — autoregressive decoder step: + // keys/values contain only the NEW tokens (seqlen_new, typically 1). + // cached_keys/values hold tokens [0 .. offset-1] and have capacity + // for at least offset+seqlen_new tokens. + // We must: 1) write new tokens into the cache at position offset, + // 2) run attention with Q vs. full cache [0 .. offset+seqlen_new). + // ------------------------------------------------------------------- + template + static void flash_attention_hip_impl( + StorageView& queries, + StorageView& keys, + StorageView& values, + StorageView& output, + StorageView* cached_keys, + StorageView* cached_values, + float queries_scale, + bool is_causal, + dim_t offset) + { + const dim_t batch_size = queries.dim(0); + const dim_t seqlen_q = queries.dim(1); + const dim_t num_heads = queries.dim(2); + const dim_t head_dim = queries.dim(3); + const dim_t seqlen_new = keys.dim(1); // NEW tokens this step + + using DevT = typename cuda::DeviceType::type; + const DevT* q_ptr = reinterpret_cast(queries.data()); + + hipStream_t stream = cuda::get_cuda_stream(); + + // Determine K/V pointers and their batch-stride depending on whether + // we are using the KV-cache or the raw keys/values tensors. + const DevT* k_ptr; + const DevT* v_ptr; + dim_t seqlen_k; // number of KEY tokens to attend over + dim_t k_time_stride; // K_alloc * nheads * head_dim (batch stride for K) + dim_t v_time_stride; // V_alloc * nheads * head_dim + + if (offset == 0) { + // --- Prefilling / encoder --- + k_ptr = reinterpret_cast(keys.data()); + v_ptr = reinterpret_cast(values.data()); + seqlen_k = seqlen_new; + k_time_stride = seqlen_k * num_heads * head_dim; + v_time_stride = seqlen_k * num_heads * head_dim; + } else { + // --- Autoregressive decode step --- + // 1. Write new keys/values into the cache at position `offset`. + const dim_t cache_size = cached_keys->dim(1); + { + // grid: (num_heads, seqlen_new, batch) + // blockIdx.x = head index → matches kernel's h = blockIdx.x + // blockIdx.y = time step → matches kernel's t = blockIdx.y + // blockIdx.z = batch index → matches kernel's b = blockIdx.z + dim3 grid(num_heads, seqlen_new, batch_size); + const int block = min((int)head_dim, 1024); + hipLaunchKernelGGL( + hip_kv_cache_write_kernel, + grid, block, 0, stream, + reinterpret_cast(keys.data()), + reinterpret_cast(cached_keys->data()), + (int)seqlen_new, (int)cache_size, + (int)num_heads, (int)head_dim, (int)offset); + hipLaunchKernelGGL( + hip_kv_cache_write_kernel, + grid, block, 0, stream, + reinterpret_cast(values.data()), + reinterpret_cast(cached_values->data()), + (int)seqlen_new, (int)cache_size, + (int)num_heads, (int)head_dim, (int)offset); + } + // 2. Attend over the full (now updated) cache. + k_ptr = reinterpret_cast(cached_keys->data()); + v_ptr = reinterpret_cast(cached_values->data()); + seqlen_k = offset + seqlen_new; + k_time_stride = cache_size * num_heads * head_dim; + v_time_stride = cache_size * num_heads * head_dim; + } + + output.resize(queries.shape()); + DevT* o_ptr = reinterpret_cast(output.data()); + + // ---------------------------------------------------------------- + // Fast path A: decode-optimised kernel for seqlen_q == 1. + // Used for every autoregressive generation step. One block per + // (batch, head) — threads parallelise across the K dimension. + // ---------------------------------------------------------------- + if (seqlen_q == 1) { + // LDS budget: D + seqlen_k + max(BLOCK, V_TILE*D) FP32 elements. + // gfx1100 has 64 KiB per block. V_TILE is chosen per head_dim to + // maximise per-sync reuse without busting LDS. + const dim_t lds_budget_floats = 64 * 1024 / sizeof(float); + auto launch_decode = + [&](auto head_dim_const, auto block_const, auto vtile_const) -> bool { + constexpr int D = decltype(head_dim_const)::value; + constexpr int BLOCK = decltype(block_const)::value; + constexpr int V_TILE = decltype(vtile_const)::value; + if (head_dim != D) return false; + const size_t scratch = + std::max((size_t)BLOCK, (size_t)V_TILE * D); + const size_t lds_floats = (size_t)D + (size_t)seqlen_k + scratch; + if ((dim_t)lds_floats > lds_budget_floats) return false; + dim3 grid(1, num_heads, batch_size); + dim3 block(BLOCK); + const size_t lds_bytes = lds_floats * sizeof(float); + hipLaunchKernelGGL((hip_flash_decode_kernel), + grid, block, lds_bytes, stream, + q_ptr, k_ptr, v_ptr, o_ptr, + (int)seqlen_k, + (int)k_time_stride, (int)v_time_stride, + (int)num_heads, queries_scale, + is_causal, (int)offset); + return true; + }; + // BLOCK > D so the V-tile load is co-operative across more threads + // than there are output channels — gives ~BLOCK/D HBM-bandwidth + // reduction on V in phase 3 plus more parallelism in phase 1. + // V_TILE picked large to amortise the sync cost over more useful + // multiply-adds (D=64 -> 128 K-tokens per tile -> 32 KiB scratch). + if (launch_decode(std::integral_constant{}, + std::integral_constant{}, + std::integral_constant{})) return; + if (launch_decode(std::integral_constant{}, + std::integral_constant{}, + std::integral_constant{})) return; + // else: fall through to 3-pass for very long sequences or + // unsupported head dimensions. + } + + // ---------------------------------------------------------------- + // Fast path B: WMMA-accelerated kernel on RDNA3 (gfx11+). + // One wave32 per block (BM_W = BN = 16), used for both FP16 and BF16 + // inputs. Wave32 fragment layout is identical between fp16 and bf16 + // on gfx11; only the WMMA built-in differs (selected via if constexpr). + // + // Two gates here are deliberate: + // + // 1. The runtime gate (`wmma_supported_at_runtime`) inspects the + // active device's gcnArchName so multi-arch wheels (e.g. CI builds + // against gfx1030;gfx1100;…) only dispatch WMMA when running on a + // GPU that actually has the WMMA built-in. On other archs we + // fall through to the scalar tiled / 3-pass kernels. + // + // 2. The compile-time gate (`#if defined(__gfx11..) || !__HIP_DEVICE_COMPILE__`) + // matches the gate around `hip_flash_attn_wmma_fp16`'s definition. + // In a non-gfx11 device pass the kernel template isn't visible, so + // even though `hipLaunchKernelGGL` only generates host-side code, + // the kernel name has to resolve for the compiler to parse the + // call site. Wrapping the launch with the same guard means the + // call site simply isn't present in that pass. + // ---------------------------------------------------------------- + if constexpr (std::is_same::value + || std::is_same::value) { + using HalfT = std::conditional_t< + std::is_same::value, _Float16, __bf16>; + + // Runtime arch check — must be true to enter the WMMA path even + // when this TU was compiled with a gfx11 target somewhere in the + // arch list (multi-arch wheels). + bool wmma_supported_at_runtime = false; + { + const int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA); + const auto& dprops = ctranslate2::cuda::get_device_properties(device_id); + const char* arch = dprops.gcnArchName; + // gcnArchName looks like "gfx1100" or "gfx1100:sramecc-:xnack-". + // WMMA is available on gfx11* and gfx12* (RDNA3 / RDNA4). + wmma_supported_at_runtime = + arch && arch[0]=='g' && arch[1]=='f' && arch[2]=='x' && arch[3]=='1' + && (arch[4]=='1' || arch[4]=='2'); + } + +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || !defined(__HIP_DEVICE_COMPILE__) + if (wmma_supported_at_runtime) { + auto launch_wmma_bm16 = [&](auto head_dim_const) -> bool { + constexpr int D = decltype(head_dim_const)::value; + constexpr int BM_W = 16; + if (head_dim != D) return false; + if (D % 16 != 0) return false; + if (seqlen_q < BM_W) return false; + dim3 grid((seqlen_q + BM_W - 1) / BM_W, num_heads, batch_size); + dim3 block(32); // one wave32 + hipLaunchKernelGGL((hip_flash_attn_wmma_fp16), + grid, block, 0, stream, + reinterpret_cast(q_ptr), + reinterpret_cast(k_ptr), + reinterpret_cast(v_ptr), + reinterpret_cast(o_ptr), + (int)seqlen_q, (int)seqlen_k, + (int)k_time_stride, (int)v_time_stride, + (int)num_heads, queries_scale, + is_causal, (int)offset); + return true; + }; + if (launch_wmma_bm16(std::integral_constant{})) return; + if (launch_wmma_bm16(std::integral_constant{})) return; + } +#else + (void)wmma_supported_at_runtime; +#endif + } + + // Fast path C: scalar tiled Flash Attention 2 kernel. + // Fallback for BF16 and head dimensions WMMA doesn't cover. + // Used when seqlen_q >= BM (one thread per query row). + // ---------------------------------------------------------------- + constexpr int BM = 64; + constexpr int BN = 64; + if (seqlen_q >= BM) { + auto launch_tiled = [&](auto head_dim_const) -> bool { + constexpr int D = decltype(head_dim_const)::value; + if (head_dim != D) return false; + dim3 grid((seqlen_q + BM - 1) / BM, num_heads, batch_size); + dim3 block(BM); + hipLaunchKernelGGL((hip_flash_attn_fwd_tiled), + grid, block, 0, stream, + q_ptr, k_ptr, v_ptr, o_ptr, + (int)seqlen_q, (int)seqlen_k, + (int)k_time_stride, (int)v_time_stride, + (int)num_heads, queries_scale, + is_causal, (int)offset); + return true; + }; + + if (launch_tiled(std::integral_constant{})) return; + if (launch_tiled(std::integral_constant{})) return; + if (launch_tiled(std::integral_constant{})) return; + } + + // ---------------------------------------------------------------- + // Fallback: three-pass reference implementation. Used when head_dim + // is not one of the specialised values above. Materialises the full + // [batch, nheads, seqlen_q, seqlen_k] score buffer in HBM. + // ---------------------------------------------------------------- + StorageView scores_buf({batch_size, num_heads, seqlen_q, seqlen_k}, + DataType::FLOAT32, Device::CUDA); + float* s_ptr = scores_buf.data(); + + // --- Pass 1: Q @ K^T --- + { + const int total = seqlen_q * seqlen_k; + const int block = 256; + dim3 grid((total + block - 1) / block, num_heads, batch_size); + hipLaunchKernelGGL(hip_attn_qk_kernel, grid, block, 0, stream, + q_ptr, k_ptr, s_ptr, + (int)seqlen_q, (int)seqlen_k, (int)k_time_stride, + (int)num_heads, (int)head_dim, + queries_scale); + } + + // --- Pass 2: row-wise softmax (with causal mask) --- + { + // The tree reduction inside hip_attn_softmax_kernel requires + // blockDim.x to be a power of 2. Round up min(seqlen_k, 256) to + // the next power of 2 (256 itself is already a power of 2, so the + // cap never breaks the invariant). Extra threads contribute the + // identity values (-1e9 / 0.0) and are harmless. + int block = 1; + while (block < (int)std::min(seqlen_k, (dim_t)256)) block <<= 1; + dim3 grid(seqlen_q, num_heads, batch_size); + hipLaunchKernelGGL(hip_attn_softmax_kernel, grid, block, + block * sizeof(float), // dynamic shared mem + stream, + s_ptr, (int)num_heads, (int)seqlen_q, (int)seqlen_k, + is_causal, (int)offset); + } + + // --- Pass 3: P @ V --- + { + const int block = min((int)head_dim, 1024); + dim3 grid(seqlen_q, num_heads, batch_size); + hipLaunchKernelGGL(hip_attn_ov_kernel, grid, block, 0, stream, + s_ptr, v_ptr, o_ptr, + (int)seqlen_k, (int)v_time_stride, + (int)num_heads, (int)head_dim); + } + } + + template <> + void FlashAttention::compute( + StorageView& queries, + StorageView& keys, + StorageView& values, + StorageView& output, + StorageView* cached_keys, + StorageView* cached_values, + StorageView* /*attention*/, + bool /*return_normalized_attention*/, + StorageView* /*rotary_cos*/, + StorageView* /*rotary_sin*/, + const bool /*rotary_interleave*/, + StorageView* /*alibi*/, + dim_t offset) const + { + const DataType dtype = queries.dtype(); + switch (dtype) { + case DataType::FLOAT16: + flash_attention_hip_impl( + queries, keys, values, output, + cached_keys, cached_values, + _queries_scale, _is_causal, offset); + break; + case DataType::BFLOAT16: + flash_attention_hip_impl( + queries, keys, values, output, + cached_keys, cached_values, + _queries_scale, _is_causal, offset); + break; + default: + throw std::invalid_argument( + "Flash Attention HIP only supports float16 and bfloat16 inputs."); + } + } + +#endif // CT2_USE_HIP / !CT2_USE_HIP + + } // namespace ops +} // namespace ctranslate2 diff --git a/src/ops/wmma_probe.cu b/src/ops/wmma_probe.cu new file mode 100644 index 000000000..fabc3f27a --- /dev/null +++ b/src/ops/wmma_probe.cu @@ -0,0 +1,136 @@ +// ----------------------------------------------------------------------------- +// WMMA probe kernel — verifies the wave32 RDNA3 fragment layout empirically. +// +// NOT built by default. This file is excluded from CMakeLists.txt's +// CUDA_SOURCES list (see the comment next to flash_attention_gpu.cu in +// CMakeLists.txt). It is kept in the source tree as a dev tool: when porting +// the WMMA Flash Attention kernels to a new gfx target, a new precision +// (e.g. bf16, fp8), or a new tile shape, run this probe first to discover +// the wave-level fragment layout the WMMA built-in produces — saves days of +// trial-and-error. +// +// To rebuild and run: +// 1) Uncomment `src/ops/wmma_probe.cu` in CMakeLists.txt's CUDA_SOURCES. +// 2) Re-cmake and rebuild ctranslate2.dll. +// 3) Run `python python/tests/benchmark_flash_attention.py` (or call +// ct2_wmma_probe_run via ctypes from any Python script). +// +// Strategy: feed a known A and B = identity into one 16x16x16 WMMA, then have +// every lane write its 8 fp32 accumulator elements to an output buffer indexed +// by lane and slot. Comparing the result on the host to a reference GEMM +// reveals which (row, col) of C each (lane, slot) pair corresponds to. +// +// Only compiled for HIP gfx11* targets. Plain HIP / Clang built-ins, no +// rocWMMA dependency. +// ----------------------------------------------------------------------------- +#ifdef CT2_USE_HIP + +#include + +namespace ctranslate2 { + namespace ops { + namespace wmma_probe { + + using v16f16 = _Float16 __attribute__((ext_vector_type(16))); + using v8f32 = float __attribute__((ext_vector_type(8))); + // __builtin_amdgcn_wmma_f32_16x16x16_f16_w32 is a Clang-internal + // __device__ built-in for gfx11. No forward declaration needed. + + // --------------------------------------------------------------------- + // Probe kernel. One block, one wave (32 threads). + // + // Inputs: + // A[16][16] row-major fp16 -- caller fills with A[i][j] = i*16 + j + // B[16][16] col-major fp16 -- caller fills with identity (B[i][j] = i==j) + // + // Output: + // out[lane*8 + slot] = the FP32 accumulator value held by `lane` at + // slot `slot` (0..7) after WMMA. + // + // With B = I, the WMMA result D = A * I = A, so out[lane*8 + slot] tells + // us which A[row][col] that (lane, slot) corresponds to. This is the + // ground-truth layout map we need to design real WMMA kernels. + // --------------------------------------------------------------------- + __global__ void wmma_probe_kernel(const _Float16* __restrict__ A, + const _Float16* __restrict__ B, + float* __restrict__ out) + { + const int lane = threadIdx.x; // 0 .. 31 + + // -- Load A row for this lane (rows 0..15; lanes 16..31 duplicate) -- + v16f16 a_frag; + const int row = lane % 16; + #pragma unroll + for (int j = 0; j < 16; ++j) + a_frag[j] = A[row * 16 + j]; + + // -- Load B column for this lane (cols 0..15; lanes 16..31 duplicate) -- + v16f16 b_frag; + const int col = lane % 16; + #pragma unroll + for (int i = 0; i < 16; ++i) + b_frag[i] = B[i * 16 + col]; + + // -- Zero the accumulator -- + v8f32 c_frag; + #pragma unroll + for (int s = 0; s < 8; ++s) c_frag[s] = 0.f; + + // -- The matmul -- + c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag); + + // -- Spill every lane's 8 accumulator slots so the host can inspect -- + #pragma unroll + for (int s = 0; s < 8; ++s) + out[lane * 8 + s] = c_frag[s]; + } + + } // namespace wmma_probe + } // namespace ops +} // namespace ctranslate2 + +// --------------------------------------------------------------------------- +// C-ABI driver — invoked via ctypes from a small Python script. +// Allocates a 16x16 fp16 A (A[i][j] = i*16 + j) and identity B, runs the +// probe kernel on the current device, and copies the per-lane accumulator +// dump into the caller-provided buffer. +// out must point to at least 32*8 floats of host memory. +// Returns 0 on success, a hipError_t code otherwise. +// --------------------------------------------------------------------------- +extern "C" __declspec(dllexport) +int ct2_wmma_probe_run(float* out_host) +{ + using ctranslate2::ops::wmma_probe::wmma_probe_kernel; + + const int N = 16 * 16; + _Float16 A_host[N]; + _Float16 B_host[N]; + for (int i = 0; i < 16; ++i) + for (int j = 0; j < 16; ++j) { + A_host[i * 16 + j] = static_cast<_Float16>(i * 16 + j); + B_host[i * 16 + j] = static_cast<_Float16>(i == j ? 1 : 0); + } + + _Float16 *A_d = nullptr, *B_d = nullptr; + float* out_d = nullptr; + hipError_t err; + + err = hipMalloc(&A_d, N * sizeof(_Float16)); if (err) return (int)err; + err = hipMalloc(&B_d, N * sizeof(_Float16)); if (err) return (int)err; + err = hipMalloc(&out_d, 32 * 8 * sizeof(float)); if (err) return (int)err; + + err = hipMemcpy(A_d, A_host, N * sizeof(_Float16), hipMemcpyHostToDevice); + if (err) return (int)err; + err = hipMemcpy(B_d, B_host, N * sizeof(_Float16), hipMemcpyHostToDevice); + if (err) return (int)err; + + hipLaunchKernelGGL(wmma_probe_kernel, dim3(1), dim3(32), 0, 0, A_d, B_d, out_d); + err = hipDeviceSynchronize(); + if (err) return (int)err; + + err = hipMemcpy(out_host, out_d, 32 * 8 * sizeof(float), hipMemcpyDeviceToHost); + hipFree(A_d); hipFree(B_d); hipFree(out_d); + return (int)err; +} + +#endif // CT2_USE_HIP