Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

### New features

* Native HIP Flash Attention for AMD RDNA3+ GPUs. Adds four dispatched
kernels (WMMA via the wave32 16x16x16 fp16/bf16 built-in, scalar tiled,
decode-optimised for `seqlen_q == 1`, and a 3-pass correctness oracle)
plus the KV-cache write path for autoregressive decoding. Enabled with
`-DWITH_HIP=ON -DWITH_FLASH_ATTN=ON`. Encoder and decoder self-attention
use this path when `flash_attention=True` is passed at model load time.
~1.08-1.11x encoder speedup and ~1.05-1.09x `generate()` speedup on
Whisper-medium / RX 7900 XTX vs. the standard MultiHeadAttention path.

### Fixes and improvements

## [v4.7.1](https://github.com/OpenNMT/CTranslate2/releases/tag/v4.7.1) (2026-02-04)
Expand Down
13 changes: 12 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,13 @@ set(CUDA_SOURCES
src/ops/conv1d_gpu.cu
src/ops/dequantize_gpu.cu
src/ops/flash_attention_gpu.cu
# src/ops/wmma_probe.cu — RDNA3 WMMA fragment-layout probe (dev tool).
# Builds a tiny kernel + C-ABI driver used to
# reverse-engineer the wave32 accumulator layout.
# Off by default to keep it out of the production
# DLL; uncomment to rebuild and rerun the layout
# probe (see python/tests/benchmark_flash_attention.py
# and the kernel's own header comment).
src/ops/gather_gpu.cu
src/ops/gumbel_max_gpu.cu
src/ops/layer_norm_gpu.cu
Expand Down Expand Up @@ -719,7 +726,11 @@ elseif(WITH_HIP)
src/ops/awq/dequantize_gpu.cu
)
if(WITH_FLASH_ATTN)
message(FATAL_ERROR "WITH_HIP=ON incompatible with WITH_FLASH_ATTN=ON")
# Native HIP Flash Attention implementation (three-pass: QK^T, softmax, PV).
# Does not require CUTLASS/CuTe; uses plain HIP kernels with FP32 accumulators.
# Supports FP16 and BF16 inputs; tested on gfx1100 (RDNA3 / RX 7900 XTX).
add_definitions(-DCT2_WITH_FLASH_ATTN)
message(STATUS "Building with HIP Flash Attention support")
endif()

set_source_files_properties(${CUDA_SOURCES} PROPERTIES LANGUAGE HIP)
Expand Down
299 changes: 299 additions & 0 deletions python/tests/benchmark_flash_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
"""Standalone benchmark script for the native HIP Flash Attention path.

Measures both speed and HBM footprint of `flash_attention=True` vs the
standard MultiHeadAttention oracle, on the encoder pass and on `generate`
at a few different max_length values.

This is *not* a pytest — it's run manually (`python benchmark_flash_attention.py`)
because the numbers are timing-sensitive and we don't want regression
flapping in CI. pytest tests for correctness live in test_flash_attention.py.

HBM is measured via the HIP runtime's `hipMemGetInfo`, called through
ctypes. On Windows the runtime DLL is `amdhip64_*.dll`; on Linux it's
`libamdhip64.so`.

Usage:
python benchmark_flash_attention.py [--model PATH] [--runs N]
"""

import argparse
import ctypes
import os
import sys
import time

import numpy as np

# ----------------------------------------------------------------------------
# Local-build DLL loader (mirrors test_flash_attention.py).
# ----------------------------------------------------------------------------
if sys.platform == "win32":
import site

for site_dir in site.getsitepackages() + [site.getusersitepackages()]:
for sub in ("_rocm_sdk_core/bin", "_rocm_sdk_libraries_custom/bin"):
cand = os.path.join(site_dir, *sub.split("/"))
if os.path.isdir(cand):
try:
os.add_dll_directory(cand)
except (FileNotFoundError, OSError):
pass

import ctranslate2 # noqa: E402


# ----------------------------------------------------------------------------
# HIP runtime memory query via ctypes.
# Returns (free_bytes, total_bytes) for the currently-active device.
# ----------------------------------------------------------------------------
def _load_hip_runtime():
if sys.platform == "win32":
# The DLL name on Windows has a major-version suffix that varies.
site_dir = next(
d
for d in (
os.path.join(s, "_rocm_sdk_core", "bin")
for s in (
__import__("site").getsitepackages()
+ [__import__("site").getusersitepackages()]
)
)
if os.path.isdir(d)
)
candidates = sorted(
f
for f in os.listdir(site_dir)
if f.startswith("amdhip64") and f.endswith(".dll")
)
if not candidates:
raise FileNotFoundError("amdhip64_*.dll not found in ROCm SDK bin")
return ctypes.CDLL(os.path.join(site_dir, candidates[-1]))
else:
return ctypes.CDLL("libamdhip64.so")


_hip = _load_hip_runtime()
_hip.hipMemGetInfo.argtypes = [
ctypes.POINTER(ctypes.c_size_t),
ctypes.POINTER(ctypes.c_size_t),
]
_hip.hipMemGetInfo.restype = ctypes.c_int
_hip.hipDeviceSynchronize.argtypes = []
_hip.hipDeviceSynchronize.restype = ctypes.c_int


def hbm_free_total():
"""(free_bytes, total_bytes) on the active HIP device."""
free = ctypes.c_size_t()
total = ctypes.c_size_t()
rc = _hip.hipMemGetInfo(ctypes.byref(free), ctypes.byref(total))
if rc != 0:
raise RuntimeError(f"hipMemGetInfo returned {rc}")
return free.value, total.value


def gpu_sync():
_hip.hipDeviceSynchronize()


# ----------------------------------------------------------------------------
# Bench helpers.
# ----------------------------------------------------------------------------
def _mel(seed=0):
rng = np.random.default_rng(seed)
return rng.standard_normal((1, 80, 3000)).astype(np.float32)


def _bench(fn, runs=20, warmup=3):
for _ in range(warmup):
fn()
gpu_sync()
t0 = time.perf_counter()
for _ in range(runs):
fn()
gpu_sync()
return (time.perf_counter() - t0) / runs * 1000 # ms per call


def measure_hbm_delta(loader_fn, work_fn):
"""Build a model with `loader_fn`, baseline its persistent HBM, run
`work_fn(model)` once (so the allocator pool grows), then measure the
additional HBM held. Returns (model_bytes, work_bytes_added)."""
gpu_sync()
free_empty, _ = hbm_free_total()

model = loader_fn()
gpu_sync()
free_after_load, _ = hbm_free_total()
model_bytes = free_empty - free_after_load

# Warm up enough to make the allocator pool reach its working-set peak.
for _ in range(3):
work_fn(model)
gpu_sync()
free_after_work, _ = hbm_free_total()
work_bytes = free_after_load - free_after_work
return model, model_bytes, work_bytes


def fmt_bytes(n):
n = abs(n)
if n < 1024 * 1024:
return f"{n/1024:.1f} KiB"
if n < 1024 * 1024 * 1024:
return f"{n/1024/1024:.1f} MiB"
return f"{n/1024/1024/1024:.2f} GiB"


# ----------------------------------------------------------------------------
# Main benchmark.
# ----------------------------------------------------------------------------
PROMPTS = [[50258, 50259, 50360]]


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
default=None,
help="Path to a converted faster-whisper-medium snapshot. "
"If omitted, the HuggingFace cache is searched.",
)
parser.add_argument("--runs", type=int, default=20)
parser.add_argument(
"--compute-type",
default="float16",
choices=["float16", "bfloat16"],
)
args = parser.parse_args()

if args.model is None:
snaps = os.path.expanduser(
"~/.cache/huggingface/hub/"
"models--Systran--faster-whisper-medium/snapshots"
)
if not os.path.isdir(snaps):
sys.exit("No --model given and faster-whisper-medium not cached.")
shas = [
d
for d in os.listdir(snaps)
if os.path.isfile(os.path.join(snaps, d, "model.bin"))
]
if not shas:
sys.exit("No converted snapshot found in HF cache.")
args.model = os.path.join(snaps, shas[0])

print(f"Model: {args.model}")
print(f"Compute type: {args.compute_type}, runs/case: {args.runs}\n")

mel = _mel(seed=0)

# ------------------------------------------------------------------
# HBM measurement: load both variants, measure persistent + working set.
# ------------------------------------------------------------------
print("=" * 70)
print("HBM footprint")
print("=" * 70)
for flash in (False, True):
label = "Flash=ON " if flash else "Flash=OFF"

def load():
return ctranslate2.models.Whisper(
args.model,
device="cuda",
compute_type=args.compute_type,
flash_attention=flash,
)

def work(m):
r = m.generate(
ctranslate2.StorageView.from_array(mel),
PROMPTS,
beam_size=1,
max_length=100,
)
return r

model, model_b, work_b = measure_hbm_delta(load, work)
print(
f" {label}: model = {fmt_bytes(model_b)}, "
f"working set (generate max_length=100) = +{fmt_bytes(work_b)}"
)
del model
gpu_sync()

# Theoretical Flash Attention savings on the encoder score buffer.
sq = sk = 1500
h = 16
saved_per_layer = sq * sk * h * 4
print()
print(
" Score-matrix that the standard path materialises and Flash "
"avoids, per encoder layer:"
)
print(f" {sq}x{sk}x{h} heads x FP32 = {fmt_bytes(saved_per_layer)} per layer")
print(
f" 24 layers => up to {fmt_bytes(24 * saved_per_layer)} of HBM "
f"traffic per encoder pass"
)
print()

# ------------------------------------------------------------------
# Performance: encoder-only, then generate() at several lengths.
# ------------------------------------------------------------------
print("=" * 70)
print("Performance")
print("=" * 70)

m_off = ctranslate2.models.Whisper(
args.model,
device="cuda",
compute_type=args.compute_type,
flash_attention=False,
)
m_on = ctranslate2.models.Whisper(
args.model,
device="cuda",
compute_type=args.compute_type,
flash_attention=True,
)

def enc_fn(m):
return lambda: m.encode(ctranslate2.StorageView.from_array(mel), to_cpu=False)

for batch in (1, 4):
mel_b = np.stack([mel[0]] * batch, axis=0)

def enc_b(m):
return lambda: m.encode(
ctranslate2.StorageView.from_array(mel_b), to_cpu=False
)

off_ms = _bench(enc_b(m_off), runs=args.runs)
on_ms = _bench(enc_b(m_on), runs=args.runs)
print(
f" encoder (B={batch}, Sq=Sk=1500): "
f"OFF={off_ms:6.2f} ms ON={on_ms:6.2f} ms speedup={off_ms/on_ms:.2f}x"
)

print()
for max_len in (30, 100, 200, 448):

def gen_b(m, ml=max_len):
return lambda: m.generate(
ctranslate2.StorageView.from_array(mel),
PROMPTS,
beam_size=1,
max_length=ml,
)

off_ms = _bench(gen_b(m_off), runs=max(args.runs // 4, 3))
on_ms = _bench(gen_b(m_on), runs=max(args.runs // 4, 3))
print(
f" generate (max_length={max_len:3d}): "
f"OFF={off_ms:7.1f} ms ON={on_ms:7.1f} ms speedup={off_ms/on_ms:.2f}x"
)


if __name__ == "__main__":
main()
Loading
Loading