diff --git a/fbgemm_gpu/experimental/ikbo/benchmarks/ikbo_fa_bench.py b/fbgemm_gpu/experimental/ikbo/benchmarks/ikbo_fa_bench.py
new file mode 100644
index 0000000000..d046549a86
--- /dev/null
+++ b/fbgemm_gpu/experimental/ikbo/benchmarks/ikbo_fa_bench.py
@@ -0,0 +1,199 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+from functools import partial
+
+import torch
+import triton
+from ikbo.ops.tlx_ikbo_fa_ws import tlx_flash_attn_ikbo_tma_persistent
+from ikbo.ops.triton_ikbo_fa import triton_flash_attn_ikbo_tma
+from torch._inductor.utils import do_bench_using_profiling
+
+num_heads, n_seed, d_head = 2, 64, 128
+DEFAULT_B = 2048
+DEFAULT_CAND_TO_USER_RATIO = 64
+DEVICE = "cuda"
+DTYPE = torch.float16
+
+PROVIDERS = [
+    "Inductor SDPA",
+    "Broadcast + inductor SDPA",
+    "Triton IKBO FA2",
+    "TLX IKBO FA3 persistence generalized",
+]
+PROVIDER_NAMES = [
+    "Inductor SDPA",
+    "Broadcast + inductor SDPA",
+    "Triton IKBO FA2",
+    "TLX IKBO FA3 persistence generalized",
+]
+
+
+def pytorch_sdpa(query, key, value):
+    return torch.nn.functional.scaled_dot_product_attention(
+        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+    )
+
+
+def broadcast_sdpa(
+    query, key, value, cand_to_user_index, n_seed, num_heads, d_head, max_seq_len
+):
+    # Reference implementation for accuracy checks: broadcast each user's K/V
+    # to its candidates, then run regular SDPA.
+    query_sdpa = query.view(-1, n_seed, num_heads, d_head).permute(0, 2, 1, 3)
+    key_sdpa = key.view(-1, max_seq_len, num_heads, d_head)
+    key_sdpa_broadcast = torch.index_select(
+        key_sdpa, dim=0, index=cand_to_user_index
+    ).permute(0, 2, 1, 3)
+    value_sdpa = value.view(-1, max_seq_len, num_heads, d_head)
+    value_sdpa_broadcast = torch.index_select(
+        value_sdpa, dim=0, index=cand_to_user_index
+    ).permute(0, 2, 1, 3)
+    return pytorch_sdpa(query_sdpa, key_sdpa_broadcast, value_sdpa_broadcast)
+
+
+def prepare_inputs_by_config(
+    B: int,
+    n_seed: int,
+    num_heads: int,
+    d_head: int,
+    max_seq_len: int,
+    low_num_cands_per_user: int = DEFAULT_CAND_TO_USER_RATIO,
+    high_num_cands_per_user: int = DEFAULT_CAND_TO_USER_RATIO,
+):
+    def _generate_num_cands_per_user():
+        res = []
+        cum_sum = 0
+        cand_grid = []
+        while True:
+            # Odd and even candidate counts per user are equally likely.
+            cur = random.randint(
+                low_num_cands_per_user, high_num_cands_per_user
+            ) + random.randint(0, 1)
+            for grid in range(cum_sum, min(cum_sum + cur, B), 2):
+                cand_grid.append(grid)
+            if cum_sum + cur >= B:
+                res.append(B - cum_sum)
+                break
+            cum_sum += cur
+            res.append(cur)
+        return res, cand_grid
+
+    res = _generate_num_cands_per_user()
+    num_cands_per_user_tensor = torch.tensor(res[0])
+    cand_grid = torch.tensor(res[1], dtype=torch.int32, device=DEVICE)
+
+    cand_to_user_index = torch.repeat_interleave(
+        torch.arange(num_cands_per_user_tensor.size(0)),
+        num_cands_per_user_tensor,
+    ).to(dtype=torch.int32, device=DEVICE)
+    Bu = num_cands_per_user_tensor.size(0)
+
+    query = torch.randn((B * n_seed, num_heads, d_head), device=DEVICE, dtype=DTYPE)
+    key = torch.randn((Bu * max_seq_len, num_heads, d_head), device=DEVICE, dtype=DTYPE)
+    value = torch.randn(
+        (Bu * max_seq_len, num_heads, d_head), device=DEVICE, dtype=DTYPE
+    )
+    return (
+        query,
+        key,
+        value,
+        cand_to_user_index,
+        cand_grid,
+    )
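+
+
+# Worked example of _generate_num_cands_per_user (illustrative small numbers,
+# not the benchmark defaults): suppose B = 7 and the per-user draws come out
+# as 3, 2, 4. The users then own candidate ranges [0, 3), [3, 5), [5, 7), so
+# res = [3, 2, 2] (the last user is clipped at B) and
+# cand_grid = [0, 2, 3, 5] -- every other candidate index within each user's
+# range. Each cand_grid entry is the first member of a candidate pair that
+# one persistent-kernel instance processes; a user with an odd candidate
+# count contributes one unpaired trailing candidate.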
+
+
+def _run_provider(provider, seq_len):
+    torch.manual_seed(0)
+    q, k, v, cand_to_user_index, cand_grid = prepare_inputs_by_config(
+        B=DEFAULT_B,
+        n_seed=n_seed,
+        num_heads=num_heads,
+        d_head=d_head,
+        max_seq_len=seq_len,
+        low_num_cands_per_user=DEFAULT_CAND_TO_USER_RATIO,
+        high_num_cands_per_user=DEFAULT_CAND_TO_USER_RATIO,
+    )
+    q_sdpa = q.view(-1, n_seed, num_heads, d_head).permute(0, 2, 1, 3)
+    k_sdpa = k.view(-1, seq_len, num_heads, d_head)
+    k_broadcast = torch.index_select(k_sdpa, dim=0, index=cand_to_user_index).permute(
+        0, 2, 1, 3
+    )
+    v_sdpa = v.view(-1, seq_len, num_heads, d_head)
+    v_broadcast = torch.index_select(v_sdpa, dim=0, index=cand_to_user_index).permute(
+        0, 2, 1, 3
+    )
+
+    def flops(ms):
+        # Two matmuls (QK^T and PV) at 2 FLOPs per MAC: 4 * B * H * n_seed *
+        # seq_len * d_head FLOPs total; FLOP/ms * 1e-9 = TFLOP/s.
+        return (DEFAULT_B * num_heads * n_seed * d_head * seq_len * 4) / ms * 1e-9
+
+    if provider == "Inductor SDPA":
+        eager_fn = partial(pytorch_sdpa, q_sdpa, k_broadcast, v_broadcast)
+        fn = torch.compile(
+            eager_fn,
+            backend="inductor",
+            options={"max_autotune": True},
+        )
+    elif provider == "Broadcast + inductor SDPA":
+        eager_fn = partial(
+            broadcast_sdpa,
+            q,
+            k,
+            v,
+            cand_to_user_index,
+            n_seed,
+            num_heads,
+            d_head,
+            seq_len,
+        )
+        fn = torch.compile(
+            eager_fn,
+            backend="inductor",
+            options={"max_autotune": True},
+        )
+    elif provider == "Triton IKBO FA2":
+        fn = partial(
+            triton_flash_attn_ikbo_tma, q, k, v, cand_to_user_index, n_seed, seq_len
+        )
+    elif provider == "TLX IKBO FA3 persistence generalized":
+        fn = partial(
+            tlx_flash_attn_ikbo_tma_persistent,
+            q,
+            k,
+            v,
+            cand_to_user_index,
+            n_seed,
+            seq_len,
+            cand_grid,
+        )
+    else:
+        # Unknown provider: return a sentinel value instead of raising so the
+        # sweep can continue.
+        return 100
+
+    return flops(do_bench_using_profiling(fn))
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["seq_len"],
+        x_vals=[512, 1024, 2048, 4096, 8192, 16384],
+        line_arg="provider",
+        line_vals=PROVIDERS,
+        line_names=PROVIDER_NAMES,
+        ylabel="TFLOPS",
+        plot_name="IKBO FA TFLOPS - Sequence Length",
+        args={},
+    )
+)
+def benchmark_vary_seq(seq_len, provider):
+    return _run_provider(provider, seq_len)
+
+
+def main():
+    benchmark_vary_seq.run(show_plots=False, print_data=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/fbgemm_gpu/experimental/ikbo/benchmarks/ikbo_lce_bench.py b/fbgemm_gpu/experimental/ikbo/benchmarks/ikbo_lce_bench.py
index 0051fcb399..48c0a43405 100644
--- a/fbgemm_gpu/experimental/ikbo/benchmarks/ikbo_lce_bench.py
+++ b/fbgemm_gpu/experimental/ikbo/benchmarks/ikbo_lce_bench.py
@@ -18,7 +18,7 @@
 DTYPE = torch.float16
 PAD_UNIT = 8  # for fp16/bf16
 
-# Representative realistic dimensions.
+# Representative dimensions.
 # M is non-round because torch.compile fuses multiple LCE modules (with output
 # sizes like 128, 64, 32, ...) into one batched matmul; M is their sum.
 M, N, K_USER, K_CAND = 433, 256, 1178, 866
diff --git a/fbgemm_gpu/experimental/ikbo/ikbo/ops/tlx_ikbo_fa_ws.py b/fbgemm_gpu/experimental/ikbo/ikbo/ops/tlx_ikbo_fa_ws.py
new file mode 100644
index 0000000000..c60f5e8268
--- /dev/null
+++ b/fbgemm_gpu/experimental/ikbo/ikbo/ops/tlx_ikbo_fa_ws.py
@@ -0,0 +1,600 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
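+
+# Design overview (warp specialization): one producer warp group issues async
+# TMA loads of Q/K/V tiles into shared-memory ring buffers while
+# NUM_MMA_GROUPS consumer warp groups run the FlashAttention online-softmax
+# MMA pipeline. mbarriers ("full"/"empty" for q, k and v) hand buffers
+# between the groups, and named barriers 9/10 ping-pong the two consumers'
+# async_dot issue order. Each persistent kernel instance processes a pair of
+# candidates that share one user's K/V, so each K/V tile is loaded once and
+# consumed by both MMA groups.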
+
+import math
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+try:
+    import triton.language.extra.tlx as tlx
+except ImportError:
+    print("TLX not found!")
+
+from triton.tools.tensor_descriptor import TensorDescriptor
+
+
+def _host_descriptor_pre_hook_tlx_persistent(nargs):
+    BLOCK_M = nargs["BLOCK_M"]
+    BLOCK_N = nargs["BLOCK_N"]
+    if nargs.get("desc_q", None) is None or not isinstance(
+        nargs["desc_q"], TensorDescriptor
+    ):
+        return
+    BLOCK_D = nargs["BLOCK_D"]
+    NUM_MMA_GROUPS = nargs["NUM_MMA_GROUPS"]
+
+    BLOCK_M_SPLIT = BLOCK_M // NUM_MMA_GROUPS
+    nargs["desc_q"].block_shape = [BLOCK_M_SPLIT, BLOCK_D]
+    nargs["desc_o"].block_shape = [BLOCK_M_SPLIT, BLOCK_D]
+    nargs["desc_v"].block_shape = [BLOCK_N, BLOCK_D]
+    nargs["desc_k"].block_shape = [BLOCK_N, BLOCK_D]
+
+
+configs_tlx_persistent = [
+    triton.Config(
+        {
+            "BLOCK_M": 128,
+            "BLOCK_N": 128,
+            "NUM_BUFFERS": 2,
+            "NUM_MMA_WARPS": 8,
+            "NUM_MMA_GROUPS": 2,
+        },
+        num_warps=4,
+        num_stages=0,
+        pre_hook=_host_descriptor_pre_hook_tlx_persistent,
+    ),
+]
+
+
+@triton.jit
+def _get_bufidx_phase(accum_cnt, NUM_BUFFERS):
+    bufIdx = accum_cnt % NUM_BUFFERS
+    phase = (accum_cnt // NUM_BUFFERS) & 1
+    return bufIdx, phase
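+
+
+# Example of the index/phase scheme above (illustrative values): with
+# NUM_BUFFERS = 2, accum_cnt = 0, 1, 2, 3, 4, 5 yields (bufIdx, phase) =
+# (0, 0), (1, 0), (0, 1), (1, 1), (0, 0), (1, 0). The phase bit flips each
+# time the ring wraps around, which is how barrier_wait distinguishes a
+# fresh arrival on a buffer from a stale one left over from the previous
+# pass.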
+
+
+@triton.autotune(
+    configs=configs_tlx_persistent,
+    key=["q_seq_len", "H"],
+)
+@triton.jit  # pragma: no cover
+def _attn_fwd_tlx_tma_pipeline_persistent_general(
+    desc_q,
+    desc_k,
+    desc_v,
+    desc_o,
+    cand_to_user_mapping,
+    cand_grid,
+    q_stride0,
+    q_stride1,
+    q_stride2,
+    k_stride0,
+    k_stride1,
+    k_stride2,
+    v_stride0,
+    v_stride1,
+    v_stride2,
+    o_stride0,
+    o_stride1,
+    o_stride2,
+    q_seq_len,
+    max_seq_len,
+    sm_scale,
+    H,
+    cand_batch_launch_kernel_instance,
+    NUM_SMS,
+    num_cand,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+    NUM_MMA_WARPS: tl.constexpr,
+    NUM_MMA_GROUPS: tl.constexpr,
+    NUM_BUFFERS: tl.constexpr,
+):
+    """
+    Kernel for computing the attention: output = softmax(Q * K.T * sm_scale) * V
+    """
+    BLOCK_M_SPLIT: tl.constexpr = BLOCK_M // NUM_MMA_GROUPS
+    start_pid = tl.program_id(0)
+
+    num_tiles = (
+        (q_seq_len + BLOCK_M_SPLIT - 1)
+        // BLOCK_M_SPLIT
+        * cand_batch_launch_kernel_instance
+        * H
+    )
+
+    tiles_per_SM = num_tiles // NUM_SMS
+    if start_pid < num_tiles % NUM_SMS:
+        tiles_per_SM += 1
+
+    num_pid_B_seed = (
+        (q_seq_len + BLOCK_M_SPLIT - 1)
+        // BLOCK_M_SPLIT
+        * cand_batch_launch_kernel_instance
+    )
+
+    num_seq = (q_seq_len + BLOCK_M_SPLIT - 1) // BLOCK_M_SPLIT
+    # allocate buffers
+    q_tiles = tlx.local_alloc(
+        (BLOCK_M_SPLIT, BLOCK_D), tlx.dtype_of(desc_q), NUM_MMA_GROUPS
+    )
+    k_tiles = tlx.local_alloc((BLOCK_N, BLOCK_D), tlx.dtype_of(desc_k), NUM_BUFFERS)
+    v_tiles = tlx.local_alloc((BLOCK_N, BLOCK_D), tlx.dtype_of(desc_v), NUM_BUFFERS)
+    o_tiles = tlx.local_alloc(
+        (BLOCK_M_SPLIT, BLOCK_D), tlx.dtype_of(desc_o), NUM_MMA_GROUPS
+    )
+    # allocate mbarriers
+    q_empties = tlx.alloc_barriers(num_barriers=NUM_MMA_GROUPS, arrive_count=1)
+    q_fulls = tlx.alloc_barriers(num_barriers=NUM_MMA_GROUPS, arrive_count=1)
+    k_empties = tlx.alloc_barriers(
+        num_barriers=NUM_BUFFERS, arrive_count=NUM_MMA_GROUPS
+    )
+    k_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=1)
+    v_empties = tlx.alloc_barriers(
+        num_barriers=NUM_BUFFERS, arrive_count=NUM_MMA_GROUPS
+    )
+    v_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=1)
+
+    with tlx.async_tasks():
+        # == producer group == #
+        with tlx.async_task("default"):
+            # initialize buffer counters
+            q0_cnt = 0
+            q1_cnt = 1
+            kv_cnt = 0
+            for i in tl.range(tiles_per_SM):
+                # Decompose the flat persistent pid into (q-seq block,
+                # candidate-pair batch, head), q-seq fastest. Candidate
+                # batches b and b + 1 of one user feed the two MMA groups.
+                pid = start_pid + i * NUM_SMS
+                pid_q_cand_batch = pid % num_pid_B_seed
+
+                pid_q_seq = pid_q_cand_batch % (num_seq)
+                # cand_grid maps the compressed ("dummy") index to the first
+                # candidate of a pair; users with an odd candidate count
+                # leave an unpaired trailing candidate.
+                pid_cand_batch_dummy = pid_q_cand_batch // num_seq
+                pid_head = pid // num_pid_B_seed
+                odd_pid_cand = False
+                pid_cand_batch = tl.load(cand_grid + pid_cand_batch_dummy)
+                pid_user_batch = tl.load(cand_to_user_mapping + pid_cand_batch)
+
+                if pid_cand_batch + 1 >= num_cand:
+                    odd_pid_cand = True
+                else:
+                    pid_user_batch2 = tl.load(cand_to_user_mapping + pid_cand_batch + 1)
+
+                    # the next candidate belongs to another user, so this
+                    # pair has no second member
+                    if pid_user_batch2 != pid_user_batch:
+                        odd_pid_cand = True
+
+                seq_start_kv = pid_user_batch * max_seq_len
+                seq_end_kv = seq_start_kv + max_seq_len
+                qo_seq_offset = pid_cand_batch * q_seq_len + pid_q_seq * BLOCK_M_SPLIT
+
+                # load q0, k0, then q1, v0, k1, v1, etc. to pipeline the loads
+                # q0
+                q0_buf_id, q0_phase = _get_bufidx_phase(q0_cnt, NUM_MMA_GROUPS)
+                tlx.barrier_wait(q_empties[q0_buf_id], q0_phase ^ 1)
+                tlx.barrier_expect_bytes(
+                    q_fulls[q0_buf_id], 2 * BLOCK_M_SPLIT * BLOCK_D
+                )  # float16
+                tlx.async_descriptor_load(
+                    desc_q,
+                    q_tiles[q0_buf_id],
+                    [qo_seq_offset.to(tl.int32), pid_head * q_stride1],
+                    q_fulls[q0_buf_id],
+                )
+                q0_cnt += NUM_MMA_GROUPS
+
+                # k0
+                kv_buf_id, kv_phase = _get_bufidx_phase(kv_cnt, NUM_BUFFERS)
+                tlx.barrier_wait(k_empties[kv_buf_id], kv_phase ^ 1)
+                # load K
+                tlx.barrier_expect_bytes(
+                    k_fulls[kv_buf_id], 2 * BLOCK_N * BLOCK_D
+                )  # float16
+                tlx.async_descriptor_load(
+                    desc_k,
+                    k_tiles[kv_buf_id],
+                    [seq_start_kv.to(tl.int32), pid_head * k_stride1],
+                    k_fulls[kv_buf_id],
+                )
+
+                # q1
+                if not odd_pid_cand:
+                    q1_buf_id, q1_phase = _get_bufidx_phase(q1_cnt, NUM_MMA_GROUPS)
+                    tlx.barrier_wait(q_empties[q1_buf_id], q1_phase ^ 1)
+                    tlx.barrier_expect_bytes(
+                        q_fulls[q1_buf_id], 2 * BLOCK_M_SPLIT * BLOCK_D
+                    )  # float16
+                    qo_offset_split = (
+                        qo_seq_offset + q1_buf_id * BLOCK_M_SPLIT
+                    )  # queries of the paired candidate
+                    tlx.async_descriptor_load(
+                        desc_q,
+                        q_tiles[q1_buf_id],
+                        [qo_offset_split.to(tl.int32), pid_head * q_stride1],
+                        q_fulls[q1_buf_id],
+                    )
+
+                    q1_cnt += NUM_MMA_GROUPS
+
+                kv_cnt_start = kv_cnt
+
+                for start_n in tl.range(seq_start_kv + BLOCK_N, seq_end_kv, BLOCK_N):
+                    # k1
+                    kv_buf_id, kv_phase = _get_bufidx_phase(kv_cnt + 1, NUM_BUFFERS)
+                    # wait for the K buffer to be released by the consumer
+                    tlx.barrier_wait(k_empties[kv_buf_id], kv_phase ^ 1)
+                    tlx.barrier_expect_bytes(
+                        k_fulls[kv_buf_id], 2 * BLOCK_N * BLOCK_D
+                    )  # float16
+                    tlx.async_descriptor_load(
+                        desc_k,
+                        k_tiles[kv_buf_id],
+                        [start_n.to(tl.int32), pid_head * k_stride1],
+                        k_fulls[kv_buf_id],
+                    )
+
+                    # v0
+                    kv_buf_id, kv_phase = _get_bufidx_phase(kv_cnt, NUM_BUFFERS)
+                    # wait for the V buffer to be released by the consumer
+                    tlx.barrier_wait(v_empties[kv_buf_id], kv_phase ^ 1)
+                    # load V
+                    tlx.barrier_expect_bytes(
+                        v_fulls[kv_buf_id], 2 * BLOCK_N * BLOCK_D
+                    )  # float16
+                    tlx.async_descriptor_load(
+                        desc_v,
+                        v_tiles[kv_buf_id],
+                        [(start_n - BLOCK_N).to(tl.int32), pid_head * v_stride1],
+                        v_fulls[kv_buf_id],
+                    )
+                    kv_cnt += 1
+
+                start_n = (kv_cnt - kv_cnt_start) * BLOCK_N + seq_start_kv
+                kv_buf_id, kv_phase = _get_bufidx_phase(kv_cnt, NUM_BUFFERS)
+                # wait for the V buffer to be released by the consumer
+                tlx.barrier_wait(v_empties[kv_buf_id], kv_phase ^ 1)
+                # load the last V tile
+                tlx.barrier_expect_bytes(
+                    v_fulls[kv_buf_id], 2 * BLOCK_N * BLOCK_D
+                )  # float16
+                tlx.async_descriptor_load(
+                    desc_v,
+                    v_tiles[kv_buf_id],
+                    [start_n.to(tl.int32), pid_head * v_stride1],
+                    v_fulls[kv_buf_id],
+                )
+                kv_cnt += 1
+
+        # == consumer group == #
+        with tlx.async_task(
+            num_warps=NUM_MMA_WARPS // NUM_MMA_GROUPS,
+            registers=232,
+            replicate=NUM_MMA_GROUPS,
+        ):
+            q_cnt = 0
+            kv_cnt = 0
+
+            cid = tlx.async_task_replica_id()
+
+            if cid == 1:
+                tlx.named_barrier_arrive(9, 256)
+            for i in tl.range(tiles_per_SM):
+                kv_cnt_start = kv_cnt
+                # Decompose the flat persistent pid exactly as the producer
+                # does: (q-seq block, candidate-pair batch, head).
+                pid = start_pid + i * NUM_SMS
+                pid_q_cand_batch = pid % num_pid_B_seed
+
+                pid_q_seq = pid_q_cand_batch % (num_seq)
+                # cand_grid maps the compressed ("dummy") index to the first
+                # candidate of a pair
+                pid_cand_batch_dummy = pid_q_cand_batch // num_seq
+                pid_head = pid // num_pid_B_seed
+                odd_pid_cand = False
+                pid_cand_batch = tl.load(cand_grid + pid_cand_batch_dummy)
+                pid_user_batch = tl.load(cand_to_user_mapping + pid_cand_batch)
+
+                if pid_cand_batch + 1 >= num_cand:
+                    odd_pid_cand = True
+                else:
+                    pid_user_batch2 = tl.load(cand_to_user_mapping + pid_cand_batch + 1)
+
+                    # the next candidate belongs to another user: odd
+                    # candidate count, so the 2nd consumer warp group sits out
+                    if pid_user_batch2 != pid_user_batch:
+                        odd_pid_cand = True
+
+                seq_start_kv = pid_user_batch * max_seq_len
+                seq_end_kv = seq_start_kv + max_seq_len
+
+                # Skip the 2nd consumer warp group when the pair is unpaired
+                if not (odd_pid_cand and cid == 1):
+                    m_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) - float("inf")
+                    l_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) + 1.0
+                    acc = tl.zeros([BLOCK_M_SPLIT, BLOCK_D], dtype=tl.float32)
+
+                    # QKT prefetch
+                    q_buf_id, q_phase = _get_bufidx_phase(q_cnt, NUM_MMA_GROUPS)
+                    tlx.barrier_wait(q_fulls[cid], q_phase)
+                    q_cnt += NUM_MMA_GROUPS
+                    offset_seq_n = tl.arange(0, BLOCK_N)
+
+                    k_buf_id, k_phase = _get_bufidx_phase(kv_cnt, NUM_BUFFERS)
+                    tlx.barrier_wait(k_fulls[k_buf_id], k_phase)
+                    k_tile = tlx.local_trans(k_tiles[k_buf_id])
+                    if not odd_pid_cand:
+                        if cid == 0:
+                            # Consumer 0 waits for Consumer 1 to be ready
+                            # (prevents both issuing simultaneously)
+                            tlx.named_barrier_wait(9, 256)
+                        else:
+                            # Consumer 1 waits for Consumer 0 to finish its
+                            # async_dot
+                            tlx.named_barrier_wait(10, 256)
+
+                    qk = tlx.async_dot(q_tiles[cid], k_tile)
+                    if not odd_pid_cand:
+                        if cid == 0:
+                            # Consumer 0 done, signal Consumer 1 to proceed
+                            tlx.named_barrier_arrive(10, 256)
+                        else:
+                            # Consumer 1 done, signal Consumer 0 for the next
+                            # iteration
+                            tlx.named_barrier_arrive(9, 256)
+
+                    # wait for the MMA to complete
+                    qk = tlx.async_dot_wait(0, qk)
+                    tlx.barrier_arrive(k_empties[k_buf_id], 1)
+                    if seq_end_kv - seq_start_kv < BLOCK_N:
+                        mask_seq = offset_seq_n[None, :] < seq_end_kv - seq_start_kv
+                        qk = tl.where(mask_seq, qk, -1.0e10)
+                    if seq_end_kv - seq_start_kv <= BLOCK_N:
+                        tlx.barrier_arrive(q_empties[cid], 1)
+
+                    # -- compute m_i and l_i for the prefetched first tile ----
+                    m_i = tl.max(qk, 1) * sm_scale
+                    qk = qk * sm_scale - m_i[:, None]
+                    p = tl.math.exp2(qk)
+                    l_i = tl.sum(p, 1)
+                    kv_cnt += 1
+
+                    # == K loop ==
+                    for start_n in tl.range(
+                        seq_start_kv + BLOCK_N, seq_end_kv, BLOCK_N
+                    ):
+                        k_buf_id, k_phase = _get_bufidx_phase(kv_cnt, NUM_BUFFERS)
+                        # wait for the K buffer to be populated by the producer
+                        tlx.barrier_wait(k_fulls[k_buf_id], k_phase)
+
+                        k_tile = tlx.local_trans(k_tiles[k_buf_id])
+                        if not odd_pid_cand:
+                            if cid == 0:
+                                # Consumer 0 waits for Consumer 1 to be ready
+                                # (prevents both issuing simultaneously)
+                                tlx.named_barrier_wait(9, 256)
+                            else:
+                                # Consumer 1 waits for Consumer 0 to finish
+                                # its async_dot
+                                tlx.named_barrier_wait(10, 256)
+                        qk = tlx.async_dot(q_tiles[cid], k_tile)
+                        if not odd_pid_cand:
+                            if cid == 0:
+                                # Consumer 0 done, signal Consumer 1 to proceed
+                                tlx.named_barrier_arrive(10, 256)
+                            else:
+                                # Consumer 1 done, signal Consumer 0 for the
+                                # next iteration
+                                tlx.named_barrier_arrive(9, 256)
+
+                        # compute pv from the previous iteration
+                        # wait for the previous V buffer to be populated by
+                        # the producer
+                        v_buf_id, v_phase = _get_bufidx_phase((kv_cnt - 1), NUM_BUFFERS)
+                        tlx.barrier_wait(v_fulls[v_buf_id], v_phase)
+
+                        # prepare p and v for the dot
+                        p = p.to(tlx.dtype_of(desc_k))
+                        acc = tlx.async_dot(p, v_tiles[v_buf_id], acc)
+
+                        # wait for the current qk MMA to complete
+                        qk = tlx.async_dot_wait(1, qk)
+                        if start_n + BLOCK_N >= seq_end_kv:
+                            # release the Q buffer when the last QK is finished
+                            tlx.barrier_arrive(q_empties[cid], 1)
+                        if start_n + BLOCK_N > seq_end_kv:
+                            # masking logic
+                            mask_seq = offset_seq_n[None, :] < seq_end_kv - start_n
+                            qk = tl.where(mask_seq, qk, -1.0e10)
+
+                        # release the K buffer
+                        tlx.barrier_arrive(k_empties[k_buf_id], 1)
+
+                        # -- compute m_i and l_i ----
+                        m_ij = tl.maximum(m_i, tl.max(qk, 1) * sm_scale)
+                        qk = qk * sm_scale - m_ij[:, None]
+                        p = tl.math.exp2(qk)
+                        # -- compute correction factor
+                        alpha = tl.math.exp2(m_i - m_ij)
+                        l_ij = tl.sum(p, 1)
+                        # update m_i and l_i
+                        l_i = l_i * alpha + l_ij
+                        m_i = m_ij
+
+                        # -- update output accumulator --
+                        # wait for the previous pv MMA to complete
+                        acc = tlx.async_dot_wait(0, acc)
+                        # release the V buffer
+                        tlx.barrier_arrive(v_empties[v_buf_id], 1)
+                        acc = acc * alpha[:, None]
+                        kv_cnt += 1
+
+                    # == epilogue ==
+                    # compute pv from the last iteration
+                    # wait for the V buffer to be populated by the producer
+                    v_buf_id, v_phase = _get_bufidx_phase((kv_cnt - 1), NUM_BUFFERS)
+                    tlx.barrier_wait(v_fulls[v_buf_id], v_phase)
+                    # prepare p and v for the dot
+                    p = p.to(tlx.dtype_of(desc_k))
+                    acc = tlx.async_dot(p, v_tiles[v_buf_id], acc)
+
+                    # Overlap the reciprocal of l_i (CUDA cores) with the
+                    # epilogue MMA
+                    rcp_l_i = 1.0 / l_i
+
+                    # wait for the MMA to complete
+                    acc = tlx.async_dot_wait(0, acc)
+                    # release the V buffer
+                    tlx.barrier_arrive(v_empties[v_buf_id], 1)
+
+                    qo_seq_offset = (
+                        pid_cand_batch * q_seq_len + pid_q_seq * BLOCK_M_SPLIT
+                    )
+                    qo_offset_split = (
+                        qo_seq_offset + cid * BLOCK_M_SPLIT
+                    )  # offset to this consumer's candidate
+
+                    # normalize via reciprocal-and-multiply instead of a
+                    # divide; multiply is faster
+                    acc = acc * rcp_l_i[:, None]
+
+                    # == store output (async) ==
+                    output = acc.to(tlx.dtype_of(desc_o))
+                    tlx.async_descriptor_store_wait(0)
+                    tlx.local_store(o_tiles[cid], output)
+                    tlx.fence_async_shared()
+                    tlx.async_descriptor_store(
+                        desc_o,
+                        o_tiles[cid],
+                        [qo_offset_split.to(tl.int32), pid_head * o_stride1],
+                    )
+
+                # Always advance counters to stay in sync with the producer,
+                # even when the second consumer (cid == 1) skips the
+                # computation. The K/V buffers are shared between both warp
+                # groups, so the phase must be tracked even while one group
+                # is idle.
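+                # Idle-path example (illustrative, NUM_BUFFERS = 2 and
+                # max_seq_len = 3 * BLOCK_N): the producer fills K buffers
+                # 0, 1, 0 and V buffers 0, 1, 0 for this tile, so the idle
+                # consumer below must wait on and release exactly that
+                # sequence to keep every mbarrier phase aligned.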
+                if odd_pid_cand and cid == 1:
+                    # drain the first K tile (the producer loads k0 before
+                    # the inner loop)
+                    k_buf_id, k_phase = _get_bufidx_phase(kv_cnt, NUM_BUFFERS)
+                    tlx.barrier_wait(k_fulls[k_buf_id], k_phase)
+                    tlx.barrier_arrive(k_empties[k_buf_id], 1)
+                    kv_cnt += 1
+                    # drain inner-loop K/V tiles
+                    for _ in tl.range(seq_start_kv + BLOCK_N, seq_end_kv, BLOCK_N):
+                        # K tile (next buffer)
+                        k_buf_id, k_phase = _get_bufidx_phase(kv_cnt, NUM_BUFFERS)
+                        tlx.barrier_wait(k_fulls[k_buf_id], k_phase)
+                        tlx.barrier_arrive(k_empties[k_buf_id], 1)
+                        # V tile (current buffer)
+                        v_buf_id, v_phase = _get_bufidx_phase(kv_cnt - 1, NUM_BUFFERS)
+                        tlx.barrier_wait(v_fulls[v_buf_id], v_phase)
+                        tlx.barrier_arrive(v_empties[v_buf_id], 1)
+                        kv_cnt += 1
+
+                    # drain the last V tile
+                    v_buf_id, v_phase = _get_bufidx_phase(kv_cnt - 1, NUM_BUFFERS)
+                    tlx.barrier_wait(v_fulls[v_buf_id], v_phase)
+                    tlx.barrier_arrive(v_empties[v_buf_id], 1)
+
+
+def tlx_flash_attn_ikbo_tma_persistent(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    cand_to_user_mapping: torch.Tensor,
+    q_seq_len: int,
+    max_seq_len: int,
+    cand_grid: torch.Tensor,
+    scale: Optional[float] = None,
+) -> torch.Tensor:
+    """
+    Ba: candidate batch size, Bu: user batch size, H: num heads, D: head dim
+    query: [Ba * n_seeds, H, D] dense tensor
+    key: [Bu * max_seq_len, H, D] dense tensor (laid out like a jagged
+        tensor; a true jagged tensor would allow variable sequence lengths)
+    value: [Bu * max_seq_len, H, D] dense tensor
+    max_seq_len: int
+    cand_to_user_mapping: [Ba] tensor [0, 0, ..., 1, 1, ..., 2, 2, ...];
+        index: candidate batch id, value: user batch id
+    cand_grid: indices of the first candidate of each candidate pair; its
+        length determines how many kernel instances to launch per q block and
+        head (users with an odd candidate count leave unpaired candidates)
+    scale: float
+    output: [Ba * n_seeds, H, D] dense tensor
+    """
+
+    sm_scale = scale
+    B_seed, H, d_head = query.shape
+    Bu_max_seq_len, _, _ = key.shape
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(d_head)
+
+    # fold ln(2) into the scale so the kernel can use exp2 instead of exp
+    sm_scale = sm_scale / math.log(2.0)
+    BLOCK_D = triton.next_power_of_2(d_head)
+
+    output = torch.empty_like(query)
+    dummy_block = [1, 1]
+    desc_q = TensorDescriptor(
+        query,
+        shape=[B_seed, H * d_head],
+        strides=[H * d_head, 1],
+        block_shape=dummy_block,
+    )
+    desc_v = TensorDescriptor(
+        value,
+        shape=[Bu_max_seq_len, H * d_head],
+        strides=[H * d_head, 1],
+        block_shape=dummy_block,
+    )
+    desc_k = TensorDescriptor(
+        key,
+        shape=[Bu_max_seq_len, H * d_head],
+        strides=[H * d_head, 1],
+        block_shape=dummy_block,
+    )
+    desc_o = TensorDescriptor(
+        output,
+        shape=[B_seed, H * d_head],
+        strides=[H * d_head, 1],
+        block_shape=dummy_block,
+    )
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+
+    cand_batch_launch_kernel_instance = cand_grid.shape[0]
+
+    def grid(META):
+        return (
+            min(
+                NUM_SMS,
+                triton.cdiv(q_seq_len, META["BLOCK_M"] // META["NUM_MMA_GROUPS"])
+                * cand_batch_launch_kernel_instance
+                * H,
+            ),
+        )
+
+    _attn_fwd_tlx_tma_pipeline_persistent_general[grid](
+        desc_q,
+        desc_k,
+        desc_v,
+        desc_o,
+        cand_to_user_mapping,
+        cand_grid,
+        query.stride(0),
+        query.stride(1),
+        query.stride(2),
+        key.stride(0),
+        key.stride(1),
+        key.stride(2),
+        value.stride(0),
+        value.stride(1),
+        value.stride(2),
+        output.stride(0),
+        output.stride(1),
+        output.stride(2),
+        q_seq_len,
+        max_seq_len,
+        sm_scale,
+        H,
+        cand_batch_launch_kernel_instance,
+        NUM_SMS=NUM_SMS,
+        num_cand=B_seed // q_seq_len,  # = Ba, since B_seed = Ba * q_seq_len
+        BLOCK_D=BLOCK_D,
+    )
+
+    return output
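+
+
+# Usage sketch (hypothetical shapes, mirroring the benchmark's conventions;
+# this even candidate-to-user ratio makes every candidate pair complete):
+#
+#     Bu, ratio, n_seeds, H, D, max_seq_len = 4, 8, 64, 2, 128, 512
+#     Ba = Bu * ratio
+#     q = torch.randn(Ba * n_seeds, H, D, device="cuda", dtype=torch.float16)
+#     k = torch.randn(Bu * max_seq_len, H, D, device="cuda", dtype=torch.float16)
+#     v = torch.randn_like(k)
+#     cand_to_user = torch.arange(Bu, device="cuda", dtype=torch.int32)
+#     cand_to_user = cand_to_user.repeat_interleave(ratio)
+#     cand_grid = torch.arange(0, Ba, 2, device="cuda", dtype=torch.int32)
+#     out = tlx_flash_attn_ikbo_tma_persistent(
+#         q, k, v, cand_to_user, n_seeds, max_seq_len, cand_grid
+#     )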
diff --git a/fbgemm_gpu/experimental/ikbo/ikbo/ops/triton_ikbo_fa.py b/fbgemm_gpu/experimental/ikbo/ikbo/ops/triton_ikbo_fa.py
new file mode 100644
index 0000000000..a21b2059ce
--- /dev/null
+++ b/fbgemm_gpu/experimental/ikbo/ikbo/ops/triton_ikbo_fa.py
@@ -0,0 +1,263 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+from triton.tools.tensor_descriptor import TensorDescriptor
+
+
+def _host_descriptor_pre_hook(nargs):
+    BLOCK_M = nargs["BLOCK_M"]
+    BLOCK_N = nargs["BLOCK_N"]
+    if nargs.get("desc_q", None) is None or not isinstance(
+        nargs["desc_q"], TensorDescriptor
+    ):
+        return
+    BLOCK_D = nargs["BLOCK_D"]
+    nargs["desc_q"].block_shape = [BLOCK_M, BLOCK_D]
+    nargs["desc_v"].block_shape = [BLOCK_N, BLOCK_D]
+    nargs["desc_k"].block_shape = [BLOCK_N, BLOCK_D]
+
+
+def _get_fw_configs():
+    configs = [
+        triton.Config(
+            {
+                "BLOCK_M": bm,
+                "BLOCK_N": bn,
+            },
+            num_stages=ns,
+            num_warps=nw,
+            pre_hook=_host_descriptor_pre_hook,
+        )
+        for bm in [16, 32, 64, 128]
+        for bn in [16, 32, 64, 128]
+        for nw in [4, 8, 16, 32]
+        for ns in [1, 2, 3, 4, 5]
+    ]
+    return configs
+
+
+@triton.jit  # pragma: no cover
+def _attn_fwd_inner_tma(
+    output,
+    acc,
+    l_i,
+    m_i,
+    q,
+    desc_k,
+    desc_v,
+    pid_head,
+    k_stride1,
+    v_stride1,
+    seq_start_kv,
+    seq_end_kv,
+    qk_scale,
+    allow_tf32,
+    BLOCK_N: tl.constexpr,
+):
+    offset_seq_n = tl.arange(0, BLOCK_N)
+    for start_n in tl.range(0, seq_end_kv - seq_start_kv, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        k = desc_k.load([(start_n + seq_start_kv).to(tl.int32), pid_head * k_stride1])
+        # mask the jagged tail: the sequence length need not be a multiple of
+        # BLOCK_N, and unmasked padding would corrupt the softmax
+        mask_seq = offset_seq_n[None, :] < seq_end_kv - start_n - seq_start_kv
+        qk = tl.dot(q, tl.trans(k), allow_tf32=allow_tf32)
+        qk = tl.where(mask_seq, qk, -1.0e10)
+        m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
+        qk = qk * qk_scale - m_ij[:, None]
+        p = tl.math.exp2(qk)
+        alpha = tl.math.exp2(m_i - m_ij)
+        l_ij = tl.sum(p, 1)
+        acc = acc * alpha[:, None]
+        v = desc_v.load([(start_n + seq_start_kv).to(tl.int32), pid_head * v_stride1])
+        p = p.to(output.dtype.element_ty)
+        acc = tl.dot(p, v, acc, allow_tf32=allow_tf32)
+        l_i = l_i * alpha + l_ij
+        m_i = m_ij
+    return acc, l_i, m_i
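+
+
+# Note on the exp2 form used above: the host divides sm_scale by ln(2), so
+# softmax computes exp2(s * sm_scale' - m') with sm_scale' = sm_scale / ln(2),
+# since exp(x) = exp2(x / ln(2)) and exp2 maps directly to the hardware EX2
+# instruction. The running (m_i, l_i) pair is the standard FlashAttention
+# online softmax: alpha = exp2(m_i - m_ij) rescales the accumulator whenever
+# a new tile raises the running row maximum.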
+
+
+@triton.autotune(
+    configs=_get_fw_configs(),
+    key=["d_head", "q_seq_len"],
+)
+@triton.jit  # pragma: no cover
+def _attn_fwd_tma(
+    desc_q,
+    desc_k,
+    desc_v,
+    output,
+    cand_to_user_mapping,
+    q_stride0,
+    q_stride1,
+    q_stride2,
+    k_stride0,
+    k_stride1,
+    k_stride2,
+    v_stride0,
+    v_stride1,
+    v_stride2,
+    o_stride0,
+    o_stride1,
+    o_stride2,
+    q_seq_len,
+    max_seq_len,
+    sm_scale,
+    d_head,
+    allow_tf32,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """
+    Kernel for computing the attention: output = softmax(Q * K.T * sm_scale) * V
+    """
+    # Map each candidate batch (the slow index of Q) to its user's row in the
+    # jagged K/V layout, i.e. the first element of that user's sequence.
+    pid_q_seq = tl.program_id(0)  # sequence block of Q (units of BLOCK_M)
+    # candidate batch; ordered before heads so that candidates sharing one
+    # user's K/V are scheduled close together
+    pid_cand_batch = tl.program_id(1)
+    pid_head = tl.program_id(2)  # head
+    pid_user_batch = tl.load(cand_to_user_mapping + pid_cand_batch)
+
+    seq_start_kv = pid_user_batch * max_seq_len
+
+    O_block_ptr = tl.make_block_ptr(
+        base=output + pid_head * o_stride1 + pid_cand_batch * q_seq_len * o_stride0,
+        shape=(q_seq_len, d_head),
+        strides=(o_stride0, o_stride2),
+        offsets=(pid_q_seq * BLOCK_M, 0),
+        block_shape=(BLOCK_M, BLOCK_D),
+        order=(1, 0),
+    )
+    # running maximum of the qkT rows, to avoid float overflow
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    # running sum of exp(qkT - m_i), the softmax denominator
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
+    acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
+    qk_scale = sm_scale
+    qo_seq_offset = pid_cand_batch * q_seq_len + pid_q_seq * BLOCK_M
+    q = desc_q.load([qo_seq_offset.to(tl.int32), pid_head * q_stride1])
+
+    acc, l_i, m_i = _attn_fwd_inner_tma(
+        output,
+        acc,
+        l_i,
+        m_i,
+        q,
+        desc_k,
+        desc_v,
+        pid_head,
+        k_stride1,
+        v_stride1,
+        seq_start_kv,
+        seq_start_kv + max_seq_len,
+        qk_scale,
+        allow_tf32,
+        BLOCK_N,
+    )
+    acc = acc / l_i[:, None]
+    tl.store(O_block_ptr, acc.to(output.dtype.element_ty), boundary_check=[0, 1])
+
+
+def triton_flash_attn_ikbo_tma(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    cand_to_user_mapping: torch.Tensor,
+    q_seq_len: int,
+    max_seq_len: int,
+    scale: Optional[float] = None,
+) -> torch.Tensor:
+    """
+    Ba: candidate batch size, Bu: user batch size, H: num heads, D: head dim
+    query: [Ba * n_seeds, H, D] dense tensor
+    key: [Bu * max_seq_len, H, D] dense tensor (laid out like a jagged
+        tensor; a true jagged tensor would allow variable sequence lengths)
+    value: [Bu * max_seq_len, H, D] dense tensor
+    max_seq_len: int
+    cand_to_user_mapping: [Ba] tensor [0, 0, ..., 1, 1, ..., 2, 2, ...];
+        index: candidate batch id, value: user batch id
+    scale: float
+    output: [Ba * n_seeds, H, D] dense tensor
+    """
+
+    sm_scale = scale
+    B_seed, H, d_head = query.shape
+    Bu_max_seq_len, _, _ = key.shape
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(d_head)
+
+    # fold ln(2) into the scale so the kernel can use exp2 instead of exp
+    sm_scale = sm_scale / math.log(2.0)
+    BLOCK_D = triton.next_power_of_2(d_head)
+    output = torch.empty_like(query)
+    dummy_block = [1, 1]
+    desc_q = TensorDescriptor(
+        query,
+        shape=[B_seed, H * d_head],
+        strides=[H * d_head, 1],
+        block_shape=dummy_block,
+    )
+    desc_v = TensorDescriptor(
+        value,
+        shape=[Bu_max_seq_len, H * d_head],
+        strides=[H * d_head, 1],
+        block_shape=dummy_block,
+    )
+    desc_k = TensorDescriptor(
+        key,
+        shape=[Bu_max_seq_len, H * d_head],
+        strides=[H * d_head, 1],
+        block_shape=dummy_block,
+    )
+
+    def grid(META):
+        return (
+            triton.cdiv(q_seq_len, META["BLOCK_M"]),
+            B_seed // q_seq_len,
+            H,
+        )
+
+    _attn_fwd_tma[grid](
+        desc_q,
+        desc_k,
+        desc_v,
+        output,
+        cand_to_user_mapping,
+        query.stride(0),
+        query.stride(1),
+        query.stride(2),
+        key.stride(0),
+        key.stride(1),
+        key.stride(2),
+        value.stride(0),
+        value.stride(1),
+        value.stride(2),
+        output.stride(0),
+        output.stride(1),
+        output.stride(2),
+        q_seq_len,
+        max_seq_len,
+        sm_scale,
+        d_head=d_head,
+        allow_tf32=query.dtype == torch.float32,
+        BLOCK_D=BLOCK_D,
+    )
+
+    return output
diff --git a/fbgemm_gpu/experimental/ikbo/tests/ikbo_fa_test.py b/fbgemm_gpu/experimental/ikbo/tests/ikbo_fa_test.py
new file mode 100644
index 0000000000..93f5713b2e
--- /dev/null
+++ b/fbgemm_gpu/experimental/ikbo/tests/ikbo_fa_test.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import sys + +import pytest +import torch +from ikbo.benchmarks.ikbo_fa_bench import broadcast_sdpa, prepare_inputs_by_config +from ikbo.ops.tlx_ikbo_fa_ws import tlx_flash_attn_ikbo_tma_persistent +from ikbo.ops.triton_ikbo_fa import triton_flash_attn_ikbo_tma + +DEVICE = "cuda" +DTYPE = torch.float16 + + +@pytest.mark.parametrize("B", [512, 102, 2048]) +@pytest.mark.parametrize("n_seed", [64]) +@pytest.mark.parametrize("num_heads", [1, 2, 4, 6]) +@pytest.mark.parametrize("d_head", [128]) +@pytest.mark.parametrize("max_seq_len", [500, 512, 1000, 1024, 2000, 2048]) +@pytest.mark.parametrize("cand_to_user_ratio", [10, 70]) +def test_triton_ikbo_fa(B, n_seed, num_heads, d_head, max_seq_len, cand_to_user_ratio): + query, key, value, cand_to_user_index, cand_grid = prepare_inputs_by_config( + B, + n_seed, + num_heads, + d_head, + max_seq_len, + cand_to_user_ratio, + cand_to_user_ratio, + ) + triton_output = ( + triton_flash_attn_ikbo_tma( + query, key, value, cand_to_user_index, n_seed, max_seq_len + ) + .view(B, n_seed, num_heads, d_head) + .permute(0, 2, 1, 3) + ) + torch_output = broadcast_sdpa( + query, key, value, cand_to_user_index, n_seed, num_heads, d_head, max_seq_len + ) + torch.testing.assert_close(torch_output, triton_output, atol=1e-3, rtol=1e-4) + + +@pytest.mark.parametrize("B", [512, 102, 2048]) +@pytest.mark.parametrize("n_seed", [64]) +@pytest.mark.parametrize("num_heads", [1, 2, 4, 6]) +@pytest.mark.parametrize("d_head", [128]) +@pytest.mark.parametrize("max_seq_len", [500, 512, 1000, 1024, 2000, 2048]) +@pytest.mark.parametrize("cand_to_user_ratio", [10, 70]) +def test_tlx_ikbo_fa(B, n_seed, num_heads, d_head, max_seq_len, cand_to_user_ratio): + query, key, value, cand_to_user_index, cand_grid = prepare_inputs_by_config( + B, + n_seed, + num_heads, + d_head, + max_seq_len, + cand_to_user_ratio, + cand_to_user_ratio, + ) + tlx_output = ( + tlx_flash_attn_ikbo_tma_persistent( + query, key, value, cand_to_user_index, n_seed, max_seq_len, cand_grid + ) + .view(B, n_seed, num_heads, d_head) + .permute(0, 2, 1, 3) + ) + torch_output = broadcast_sdpa( + query, key, value, cand_to_user_index, n_seed, num_heads, d_head, max_seq_len + ) + torch.testing.assert_close(torch_output, tlx_output, atol=1e-3, rtol=1e-4) + + +def main(): + sys.exit(pytest.main([__file__, "-v"])) + + +if __name__ == "__main__": + main() diff --git a/fbgemm_gpu/experimental/ikbo/utility/ikbo_lce_analysis.py b/fbgemm_gpu/experimental/ikbo/utility/ikbo_lce_analysis.py index 623137c01c..d54e94f50b 100644 --- a/fbgemm_gpu/experimental/ikbo/utility/ikbo_lce_analysis.py +++ b/fbgemm_gpu/experimental/ikbo/utility/ikbo_lce_analysis.py @@ -18,7 +18,7 @@ DEVICE = "cuda" DTYPE = torch.float16 -# Representative realistic dimensions. +# Representative dimensions. B, M, N, K_USER, K_CAND = 1024, 433, 256, 1178, 866 CAND_TO_USER_RATIO = 70