From 00a3117274c814f9c5953d61fb67cda7e10c03af Mon Sep 17 00:00:00 2001
From: wooway777
Date: Tue, 9 Jun 2026 15:57:09 +0800
Subject: [PATCH] issue/1221 - cambricon paged attention prefill naive

---
 .../bang/paged_attention_prefill_bang.h       |   8 +
 .../bang/paged_attention_prefill_bang.mlu     | 442 ++++++++++++++++++
 .../ops/paged_attention_prefill/operator.cc   |  15 +
 3 files changed, 465 insertions(+)
 create mode 100644 src/infiniop/ops/paged_attention_prefill/bang/paged_attention_prefill_bang.h
 create mode 100644 src/infiniop/ops/paged_attention_prefill/bang/paged_attention_prefill_bang.mlu

diff --git a/src/infiniop/ops/paged_attention_prefill/bang/paged_attention_prefill_bang.h b/src/infiniop/ops/paged_attention_prefill/bang/paged_attention_prefill_bang.h
new file mode 100644
index 000000000..7d4763fe7
--- /dev/null
+++ b/src/infiniop/ops/paged_attention_prefill/bang/paged_attention_prefill_bang.h
@@ -0,0 +1,8 @@
+#ifndef __PAGED_ATTENTION_PREFILL_BANG_H__
+#define __PAGED_ATTENTION_PREFILL_BANG_H__
+
+#include "../paged_attention_prefill.h"
+
+DESCRIPTOR(bang)
+
+#endif // __PAGED_ATTENTION_PREFILL_BANG_H__
diff --git a/src/infiniop/ops/paged_attention_prefill/bang/paged_attention_prefill_bang.mlu b/src/infiniop/ops/paged_attention_prefill/bang/paged_attention_prefill_bang.mlu
new file mode 100644
index 000000000..dfdb781cd
--- /dev/null
+++ b/src/infiniop/ops/paged_attention_prefill/bang/paged_attention_prefill_bang.mlu
@@ -0,0 +1,442 @@
+#include "../../../devices/bang/common_bang.h"
+#include "paged_attention_prefill_bang.h"
+
+// NOTE(review): these three system header names are assumed from the symbols
+// used below (FLT_MAX, std::shared_ptr, std::is_same_v) -- confirm.
+#include <cfloat>
+#include <memory>
+#include <type_traits>
+
+namespace {
+
+// NRAM scratch space that pagedAttentionPrefillKernel carves into regions.
+__nram__ char paged_attention_prefill_nram_buffer[NRAM_MAX_SIZE];
+
+// Rounds `value` up to the next multiple of ALIGN_SIZE.
+__mlu_device__ size_t alignUp(size_t value) {
+    return ((value + ALIGN_SIZE - 1) / ALIGN_SIZE) * ALIGN_SIZE;
+}
+
+// Converts a single Tdata value to float.
+template <typename Tdata>
+__mlu_device__ float toFloat(Tdata value) {
+    if constexpr (std::is_same_v<Tdata, half>) {
+        return __half2float(value);
+    } else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
+        return __bfloat162float(value);
+    } else {
+        return static_cast<float>(value);
+    }
+}
+
+// Converts a float to Tdata via static_cast.
+template <typename Tdata>
+__mlu_device__ Tdata fromFloat(float value) {
+    return static_cast<Tdata>(value);
+}
+
+// Widens an index value to int64_t.
+template <typename Tindex>
+__mlu_device__ int64_t toInt64(Tindex value) {
+    return static_cast<int64_t>(value);
+}
+
+// Copies `n` elements from GDRAM `src` into the NRAM staging buffer `tmp`, then
+// writes them to NRAM `dst` as float.
+template <typename Tdata>
+__mlu_device__ void loadToFloat(float *dst, Tdata *tmp, const Tdata *src, size_t n) {
+    __memcpy(tmp, src, n * sizeof(Tdata), GDRAM2NRAM);
+    if constexpr (std::is_same_v<Tdata, half>) {
+        __bang_half2float(dst, tmp, n);
+    } else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
+        __bang_bfloat162float(dst, tmp, n);
+    } else {
+        __memcpy(dst, tmp, n * sizeof(float), NRAM2NRAM);
+    }
+}
+
+// Writes `n` floats from NRAM `src` to GDRAM `dst` as Tdata. half/bfloat16 are
+// converted through the NRAM staging buffer `tmp`; other types are copied directly.
+template <typename Tdata>
+__mlu_device__ void storeFromFloat(Tdata *dst, Tdata *tmp, float *src, size_t n) {
+    if constexpr (std::is_same_v<Tdata, half>) {
+        __bang_float2half(tmp, src, n);
+        __memcpy(dst, tmp, n * sizeof(Tdata), NRAM2GDRAM);
+    } else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
+        __bang_float2bfloat16(tmp, src, n);
+        __memcpy(dst, tmp, n * sizeof(Tdata), NRAM2GDRAM);
+    } else {
+        __memcpy(dst, src, n * sizeof(float), NRAM2GDRAM);
+    }
+}
+
+// Returns the dot product of `q_float` and the `head_size` elements at `k_ptr`.
+// Overwrites `work_float` and `tmp`.
+template <typename Tdata>
+__mlu_device__ float dotProduct(float *q_float, float *work_float, Tdata *tmp, const Tdata *k_ptr, size_t head_size) {
+    loadToFloat(work_float, tmp, k_ptr, head_size);
+    __bang_mul(work_float, work_float, q_float, head_size);
+    return __bang_sum(work_float, head_size);
+}
+
+// Adds weight * v_ptr[i] to acc_float[i] for `head_size` elements.
+// Overwrites `work_float` and `tmp`.
+template <typename Tdata>
+__mlu_device__ void accumulateValue(float *acc_float, float *work_float, Tdata *tmp, const Tdata *v_ptr, float weight, size_t head_size) {
+    loadToFloat(work_float, tmp, v_ptr, head_size);
+    __bang_mul_scalar(work_float, work_float, weight, head_size);
+    __bang_add(acc_float, acc_float, work_float, head_size);
+}
+
+// Binary search over cum_seq_lens_q (assumed non-decreasing) for the last index
+// in [0, num_seqs) whose value is <= token_idx. Index 0 is returned when no
+// index >= 1 qualifies.
+template <typename Tindex>
+__mlu_device__ size_t findSeqId(size_t token_idx, const Tindex *cum_seq_lens_q, size_t num_seqs) {
+    size_t low = 0;
+    size_t high = num_seqs;
+    while (low + 1 < high) {
+        const size_t mid = (low + high) >> 1;
+        if (token_idx < static_cast<size_t>(toInt64(cum_seq_lens_q[mid]))) {
+            high = mid;
+        } else {
+            low = mid;
+        }
+    }
+    return low;
+}
+
+// Element offset of logical KV position `token_idx` in a paged cache: the block
+// table entry selects the physical block, then head and in-block row strides apply.
+template <typename Tindex>
+__mlu_device__ ptrdiff_t pagedCacheOffset(
+    const Tindex *block_table,
+    size_t token_idx,
+    size_t page_block_size,
+    size_t kv_head_idx,
+    ptrdiff_t batch_stride,
+    ptrdiff_t row_stride,
+    ptrdiff_t head_stride) {
+    const size_t logical_block = token_idx / page_block_size;
+    const size_t token_in_block = token_idx - logical_block * page_block_size;
+    const ptrdiff_t physical_block = static_cast<ptrdiff_t>(toInt64(block_table[logical_block]));
+    return physical_block * batch_stride
+         + static_cast<ptrdiff_t>(kv_head_idx) * head_stride
+         + static_cast<ptrdiff_t>(token_in_block) * row_stride;
+}
+
+// Attention logit for one key: scale * (q . k), plus alibi_slope * rel_pos when
+// the slope is non-zero. Overwrites `work_float` and `tmp`.
+template <typename Tdata>
+__mlu_device__ float attentionScore(
+    float *q_float,
+    float *work_float,
+    Tdata *tmp,
+    const Tdata *k_ptr,
+    size_t head_size,
+    float scale,
+    float alibi_slope,
+    int64_t rel_pos) {
+    float score = dotProduct(q_float, work_float, tmp, k_ptr, head_size) * scale;
+    if (alibi_slope != 0.0f) {
+        score += alibi_slope * static_cast<float>(rel_pos);
+    }
+    return score;
+}
+
+// Naive causal paged-attention prefill. Work items are (query token, head) pairs
+// strided over tasks by taskId; each item runs a max pass, a sum-of-exp pass and
+// a weighted-V pass over its visible KV positions. Logits are kept in NRAM when
+// they fit and are recomputed from K otherwise.
+// NOTE(review): max_num_blocks_per_seq is accepted but not used for bounds checks.
+template <typename Tdata, typename Tindex>
+__mlu_global__ void pagedAttentionPrefillKernel(
+    Tdata *out,
+    const Tdata *q,
+    const Tdata *k_cache,
+    const Tdata *v_cache,
+    const Tindex *block_tables,
+    const Tindex *total_kv_lens,
+    const Tindex *cum_seq_lens_q,
+    const float *alibi_slopes,
+    size_t num_seqs,
+    size_t total_q_tokens,
+    size_t num_heads,
+    size_t num_kv_heads,
+    size_t head_size,
+    float scale,
+    size_t max_num_blocks_per_seq,
+    size_t page_block_size,
+    ptrdiff_t block_table_batch_stride,
+    ptrdiff_t q_stride,
+    ptrdiff_t q_head_stride,
+    ptrdiff_t k_batch_stride,
+    ptrdiff_t k_row_stride,
+    ptrdiff_t k_head_stride,
+    ptrdiff_t v_batch_stride,
+    ptrdiff_t v_row_stride,
+    ptrdiff_t v_head_stride,
+    ptrdiff_t o_stride,
+    ptrdiff_t o_head_stride) {
+
+    const size_t task_num = taskDimX * taskDimY;
+    const size_t total = total_q_tokens * num_heads;
+    if (total == 0) {
+        return;
+    }
+    // Query heads per KV head; invariant across work items.
+    const size_t num_queries_per_kv = num_heads / num_kv_heads;
+
+    // NRAM layout: [tmp_data | q_float | work_float | acc_float | logits...].
+    // Every region starts on an ALIGN_SIZE boundary, and the logits capacity is
+    // measured from the start of the buffer so alignment padding is accounted for.
+    char *const nram_begin = paged_attention_prefill_nram_buffer;
+    char *const nram_base = reinterpret_cast<char *>(alignUp(reinterpret_cast<size_t>(nram_begin)));
+    const size_t tmp_bytes = alignUp(head_size * sizeof(Tdata));
+    const size_t vec_bytes = alignUp(head_size * sizeof(float));
+    Tdata *tmp_data = reinterpret_cast<Tdata *>(nram_base);
+    float *q_float = reinterpret_cast<float *>(nram_base + tmp_bytes);
+    float *work_float = reinterpret_cast<float *>(nram_base + tmp_bytes + vec_bytes);
+    float *acc_float = reinterpret_cast<float *>(nram_base + tmp_bytes + 2 * vec_bytes);
+    float *logits = reinterpret_cast<float *>(nram_base + tmp_bytes + 3 * vec_bytes);
+    const size_t used_bytes = static_cast<size_t>(reinterpret_cast<char *>(logits) - nram_begin);
+    const size_t logits_capacity = (used_bytes < NRAM_MAX_SIZE) ? ((NRAM_MAX_SIZE - used_bytes) / sizeof(float)) : 0;
+
+    for (size_t linear = taskId; linear < total; linear += task_num) {
+        const size_t global_token_idx = linear / num_heads;
+        const size_t head_idx = linear - global_token_idx * num_heads;
+        const size_t seq_idx = findSeqId(global_token_idx, cum_seq_lens_q, num_seqs);
+
+        const int64_t q_start = toInt64(cum_seq_lens_q[seq_idx]);
+        const int64_t q_end = toInt64(cum_seq_lens_q[seq_idx + 1]);
+        const int64_t q_len = q_end - q_start;
+        const int64_t q_token_idx = static_cast<int64_t>(global_token_idx) - q_start;
+        const int64_t total_kv_len = toInt64(total_kv_lens[seq_idx]);
+        // Last KV position visible to this query: the q_len queries line up with
+        // the final q_len of the sequence's total_kv_len KV positions.
+        const int64_t causal_limit = total_kv_len - q_len + q_token_idx;
+
+        Tdata *out_ptr = out + static_cast<ptrdiff_t>(global_token_idx) * o_stride + static_cast<ptrdiff_t>(head_idx) * o_head_stride;
+        if (causal_limit < 0) {
+            // No visible KV position: emit zeros.
+            __bang_write_value(acc_float, head_size, 0.0f);
+            storeFromFloat(out_ptr, tmp_data, acc_float, head_size);
+            continue;
+        }
+
+        const size_t attend_len = static_cast<size_t>(causal_limit + 1);
+        const bool cache_logits = attend_len <= logits_capacity;
+        const size_t kv_head_idx = head_idx / num_queries_per_kv;
+        const float alibi_slope = (alibi_slopes == nullptr) ? 0.0f : alibi_slopes[head_idx];
+        const Tindex *block_table = block_tables + static_cast<ptrdiff_t>(seq_idx) * block_table_batch_stride;
+        const Tdata *q_ptr = q + static_cast<ptrdiff_t>(global_token_idx) * q_stride + static_cast<ptrdiff_t>(head_idx) * q_head_stride;
+        loadToFloat(q_float, tmp_data, q_ptr, head_size);
+
+        // Pass 1: maximum logit (logits are cached in NRAM when they fit).
+        float max_score = -FLT_MAX;
+        for (size_t token_idx = 0; token_idx < attend_len; ++token_idx) {
+            const Tdata *k_ptr = k_cache + pagedCacheOffset(block_table, token_idx, page_block_size, kv_head_idx, k_batch_stride, k_row_stride, k_head_stride);
+            const float score = attentionScore(q_float, work_float, tmp_data, k_ptr, head_size, scale, alibi_slope, static_cast<int64_t>(token_idx) - causal_limit);
+            if (cache_logits) {
+                logits[token_idx] = score;
+            }
+            if (score > max_score) {
+                max_score = score;
+            }
+        }
+
+        // Pass 2: softmax denominator; cached logits are replaced by exp(logit - max).
+        float sum_exp = 0.0f;
+        if (cache_logits) {
+            for (size_t token_idx = 0; token_idx < attend_len; ++token_idx) {
+                const float prob = expf(logits[token_idx] - max_score);
+                logits[token_idx] = prob;
+                sum_exp += prob;
+            }
+        } else {
+            for (size_t token_idx = 0; token_idx < attend_len; ++token_idx) {
+                const Tdata *k_ptr = k_cache + pagedCacheOffset(block_table, token_idx, page_block_size, kv_head_idx, k_batch_stride, k_row_stride, k_head_stride);
+                const float score = attentionScore(q_float, work_float, tmp_data, k_ptr, head_size, scale, alibi_slope, static_cast<int64_t>(token_idx) - causal_limit);
+                sum_exp += expf(score - max_score);
+            }
+        }
+        const float inv_sum = 1.0f / (sum_exp + 1e-6f);
+
+        // Pass 3: accumulate the probability-weighted V rows and store the result.
+        __bang_write_value(acc_float, head_size, 0.0f);
+        for (size_t token_idx = 0; token_idx < attend_len; ++token_idx) {
+            float prob = 0.0f;
+            if (cache_logits) {
+                prob = logits[token_idx] * inv_sum;
+            } else {
+                const Tdata *k_ptr = k_cache + pagedCacheOffset(block_table, token_idx, page_block_size, kv_head_idx, k_batch_stride, k_row_stride, k_head_stride);
+                const float score = attentionScore(q_float, work_float, tmp_data, k_ptr, head_size, scale, alibi_slope, static_cast<int64_t>(token_idx) - causal_limit);
+                prob = expf(score - max_score) * inv_sum;
+            }
+            const Tdata *v_ptr = v_cache + pagedCacheOffset(block_table, token_idx, page_block_size, kv_head_idx, v_batch_stride, v_row_stride, v_head_stride);
+            accumulateValue(acc_float, work_float, tmp_data, v_ptr, prob, head_size);
+        }
+        storeFromFloat(out_ptr, tmp_data, acc_float, head_size);
+    }
+}
+
+// Launches the kernel for one (Tdata, Tindex) instantiation on `queue` with a
+// core_per_cluster x cluster_count task grid, then calls cnrtQueueSync.
+// The cnrtQueueSync result is not checked; INFINI_STATUS_SUCCESS is always returned.
+template <typename Tdata, typename Tindex>
+infiniStatus_t launchPagedAttentionPrefill(
+    cnrtQueue_t queue,
+    void *out,
+    const void *q,
+    const void *k_cache,
+    const void *v_cache,
+    const void *block_tables,
+    const void *total_kv_lens,
+    const void *cum_seq_lens_q,
+    const void *alibi_slopes,
+    const op::paged_attention_prefill::PagedAttentionPrefillInfo &info,
+    int core_per_cluster,
+    int cluster_count) {
+
+    cnrtDim3_t kernel_dim;
+    kernel_dim.x = core_per_cluster;
+    kernel_dim.y = cluster_count;
+    kernel_dim.z = 1;
+
+    // NOTE(review): the launch function type is assumed to be UNION1 -- confirm.
+    pagedAttentionPrefillKernel<Tdata, Tindex><<<kernel_dim, CNRT_FUNC_TYPE_UNION1, queue>>>(
+        static_cast<Tdata *>(out),
+        static_cast<const Tdata *>(q),
+        static_cast<const Tdata *>(k_cache),
+        static_cast<const Tdata *>(v_cache),
+        static_cast<const Tindex *>(block_tables),
+        static_cast<const Tindex *>(total_kv_lens),
+        static_cast<const Tindex *>(cum_seq_lens_q),
+        static_cast<const float *>(alibi_slopes),
+        info.num_seqs,
+        info.total_q_tokens,
+        info.num_heads,
+        info.num_kv_heads,
+        info.head_size,
+        info.scale,
+        info.max_num_blocks_per_seq,
+        info.page_block_size,
+        info.block_table_batch_stride,
+        info.q_stride,
+        info.q_head_stride,
+        info.k_batch_stride,
+        info.k_row_stride,
+        info.k_head_stride,
+        info.v_batch_stride,
+        info.v_row_stride,
+        info.v_head_stride,
+        info.o_stride,
+        info.o_head_stride);
+    cnrtQueueSync(queue);
+    return INFINI_STATUS_SUCCESS;
+}
+
+// Selects the Tdata instantiation from `dtype` (F16 or BF16); any other dtype
+// yields INFINI_STATUS_BAD_TENSOR_DTYPE.
+template <typename Tindex>
+infiniStatus_t dispatchPagedAttentionPrefillDtype(
+    infiniDtype_t dtype,
+    cnrtQueue_t queue,
+    void *out,
+    const void *q,
+    const void *k_cache,
+    const void *v_cache,
+    const void *block_tables,
+    const void *total_kv_lens,
+    const void *cum_seq_lens_q,
+    const void *alibi_slopes,
+    const op::paged_attention_prefill::PagedAttentionPrefillInfo &info,
+    int core_per_cluster,
+    int cluster_count) {
+    switch (dtype) {
+    case INFINI_DTYPE_F16:
+        return launchPagedAttentionPrefill<half, Tindex>(queue, out, q, k_cache, v_cache, block_tables, total_kv_lens, cum_seq_lens_q, alibi_slopes, info, core_per_cluster, cluster_count);
+    case INFINI_DTYPE_BF16:
+        return launchPagedAttentionPrefill<bfloat16_t, Tindex>(queue, out, q, k_cache, v_cache, block_tables, total_kv_lens, cum_seq_lens_q, alibi_slopes, info, core_per_cluster, cluster_count);
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+}
+
+} // namespace
+
+namespace op::paged_attention_prefill::bang {
+
+// NOTE(review): the device handle type names used in this namespace are assumed -- confirm.
+struct Descriptor::Opaque {
+    std::shared_ptr<device::bang::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+// Builds the op info with PagedAttentionPrefillInfo::create, checks it with
+// CHECK_RESULT, and allocates the descriptor (0 is presumably the workspace size).
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    infiniopTensorDescriptor_t q_desc,
+    infiniopTensorDescriptor_t k_cache_desc,
+    infiniopTensorDescriptor_t v_cache_desc,
+    infiniopTensorDescriptor_t block_tables_desc,
+    infiniopTensorDescriptor_t total_kv_lens_desc,
+    infiniopTensorDescriptor_t cum_seqlens_q_desc,
+    const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
+    float scale) {
+
+    auto handle = reinterpret_cast<device::bang::cambricon::Handle *>(handle_);
+    auto info = PagedAttentionPrefillInfo::create(
+        out_desc, q_desc, k_cache_desc, v_cache_desc,
+        block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc,
+        alibi_slopes_desc, scale);
+    CHECK_RESULT(info);
+
+    *desc_ptr = new Descriptor(
+        new Opaque{static_cast<device::bang::Handle *>(handle)->internal()},
+        info.take(), 0, handle->device, handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+// Runs the prefill on `stream`, picking the index-type instantiation from
+// _info.index_dtype (I64, I32 or U32). The workspace arguments are unused.
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void *out,
+    const void *q,
+    const void *k_cache,
+    const void *v_cache,
+    const void *block_tables,
+    const void *total_kv_lens,
+    const void *cum_seq_lens_q,
+    const void *alibi_slopes,
+    void *stream) const {
+
+    (void)workspace;
+    (void)workspace_size;
+
+    auto queue = reinterpret_cast<cnrtQueue_t>(stream);
+    const int core_per_cluster = _opaque->internal->getCorePerCluster();
+    const int cluster_count = _opaque->internal->getClusterCount();
+
+    if (_info.index_dtype == INFINI_DTYPE_I64) {
+        return dispatchPagedAttentionPrefillDtype<int64_t>(_info.dtype, queue, out, q, k_cache, v_cache, block_tables, total_kv_lens, cum_seq_lens_q, alibi_slopes, _info, core_per_cluster, cluster_count);
+    }
+    if (_info.index_dtype == INFINI_DTYPE_I32) {
+        return dispatchPagedAttentionPrefillDtype<int32_t>(_info.dtype, queue, out, q, k_cache, v_cache, block_tables, total_kv_lens, cum_seq_lens_q, alibi_slopes, _info, core_per_cluster, cluster_count);
+    }
+    if (_info.index_dtype == INFINI_DTYPE_U32) {
+        return dispatchPagedAttentionPrefillDtype<uint32_t>(_info.dtype, queue, out, q, k_cache, v_cache, block_tables, total_kv_lens, cum_seq_lens_q, alibi_slopes, _info, core_per_cluster, cluster_count);
+    }
+    return INFINI_STATUS_BAD_TENSOR_DTYPE;
+}
+
+} // namespace op::paged_attention_prefill::bang
diff --git a/src/infiniop/ops/paged_attention_prefill/operator.cc b/src/infiniop/ops/paged_attention_prefill/operator.cc
index 36804cfff..2dc1c0723 100644
--- a/src/infiniop/ops/paged_attention_prefill/operator.cc
+++ b/src/infiniop/ops/paged_attention_prefill/operator.cc
@@ -8,6 +8,9 @@
 #ifdef ENABLE_METAX_API
 #include "metax/paged_attention_prefill_metax.h"
 #endif
+#ifdef ENABLE_CAMBRICON_API
+#include "bang/paged_attention_prefill_bang.h"
+#endif
 #ifdef ENABLE_MOORE_API
 #include "moore/paged_attention_prefill_moore.h"
 #endif
@@ -50,6 +53,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
 #endif
 #ifdef ENABLE_MOORE_API
         CREATE(INFINI_DEVICE_MOORE, moore)
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        CREATE(INFINI_DEVICE_CAMBRICON, bang)
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -80,6 +86,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
 #endif
 #ifdef ENABLE_MOORE_API
         GET(INFINI_DEVICE_MOORE, moore)
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        GET(INFINI_DEVICE_CAMBRICON, bang)
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -117,6 +126,9 @@ __INFINI_C infiniStatus_t infiniopPagedAttentionPrefill(
 #endif
 #ifdef ENABLE_MOORE_API
         CALCULATE(INFINI_DEVICE_MOORE, moore)
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        CALCULATE(INFINI_DEVICE_CAMBRICON, bang)
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -146,6 +158,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
 #endif
 #ifdef ENABLE_MOORE_API
         DESTROY(INFINI_DEVICE_MOORE, moore)
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        DESTROY(INFINI_DEVICE_CAMBRICON, bang)
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;