Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __PAGED_ATTENTION_PREFILL_BANG_H__
#define __PAGED_ATTENTION_PREFILL_BANG_H__

#include "../paged_attention_prefill.h"

DESCRIPTOR(bang)

#endif // __PAGED_ATTENTION_PREFILL_BANG_H__
Original file line number Diff line number Diff line change
@@ -0,0 +1,379 @@
#include "../../../devices/bang/common_bang.h"
#include "paged_attention_prefill_bang.h"

#include <cfloat>
#include <cstdint>
#include <type_traits>

namespace {

__nram__ char paged_attention_prefill_nram_buffer[NRAM_MAX_SIZE];

template <typename Tdata>
__mlu_device__ float toFloat(Tdata value) {
if constexpr (std::is_same_v<Tdata, half>) {
return __half2float(value);
} else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
return __bfloat162float(value);
} else {
return static_cast<float>(value);
}
}

template <typename Tdata>
__mlu_device__ Tdata fromFloat(float value) {
return static_cast<Tdata>(value);
}

template <typename Tindex>
__mlu_device__ int64_t toInt64(Tindex value) {
return static_cast<int64_t>(value);
}

template <typename Tdata>
__mlu_device__ void loadToFloat(float *dst, Tdata *tmp, const Tdata *src, size_t n) {
__memcpy(tmp, src, n * sizeof(Tdata), GDRAM2NRAM);
if constexpr (std::is_same_v<Tdata, half>) {
__bang_half2float(dst, tmp, n);
} else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
__bang_bfloat162float(dst, tmp, n);
} else {
__memcpy(dst, tmp, n * sizeof(float), NRAM2NRAM);
}
}

template <typename Tdata>
__mlu_device__ void storeFromFloat(Tdata *dst, Tdata *tmp, float *src, size_t n) {
if constexpr (std::is_same_v<Tdata, half>) {
__bang_float2half(tmp, src, n);
__memcpy(dst, tmp, n * sizeof(Tdata), NRAM2GDRAM);
} else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
__bang_float2bfloat16(tmp, src, n);
__memcpy(dst, tmp, n * sizeof(Tdata), NRAM2GDRAM);
} else {
__memcpy(dst, src, n * sizeof(float), NRAM2GDRAM);
}
}

template <typename Tdata>
__mlu_device__ float dotProduct(float *q_float, float *work_float, Tdata *tmp, const Tdata *k_ptr, size_t head_size) {
loadToFloat(work_float, tmp, k_ptr, head_size);
__bang_mul(work_float, work_float, q_float, head_size);
return __bang_sum(work_float, head_size);
}

template <typename Tdata>
__mlu_device__ void accumulateValue(float *acc_float, float *work_float, Tdata *tmp, const Tdata *v_ptr, float weight, size_t head_size) {
loadToFloat(work_float, tmp, v_ptr, head_size);
__bang_mul_scalar(work_float, work_float, weight, head_size);
__bang_add(acc_float, acc_float, work_float, head_size);
}

template <typename Tindex>
__mlu_device__ size_t findSeqId(size_t token_idx, const Tindex *cum_seq_lens_q, size_t num_seqs) {
size_t low = 0;
size_t high = num_seqs;
while (low + 1 < high) {
const size_t mid = (low + high) >> 1;
if (token_idx < static_cast<size_t>(toInt64(cum_seq_lens_q[mid]))) {
high = mid;
} else {
low = mid;
}
}
return low;
}

template <typename Tdata, typename Tindex>
__mlu_global__ void pagedAttentionPrefillKernel(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *total_kv_lens,
const Tindex *cum_seq_lens_q,
const float *alibi_slopes,
size_t num_seqs,
size_t total_q_tokens,
size_t num_heads,
size_t num_kv_heads,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {

const size_t task_num = taskDimX * taskDimY;
const size_t total = total_q_tokens * num_heads;

char *nram_base = reinterpret_cast<char *>(((reinterpret_cast<size_t>(paged_attention_prefill_nram_buffer) + ALIGN_SIZE - 1) / ALIGN_SIZE) * ALIGN_SIZE);
Tdata *tmp_data = reinterpret_cast<Tdata *>(nram_base);
float *q_float = reinterpret_cast<float *>(tmp_data + head_size);
float *work_float = q_float + head_size;
float *acc_float = work_float + head_size;
float *logits = acc_float + head_size;
const size_t header_bytes = reinterpret_cast<char *>(logits) - nram_base;
const size_t logits_capacity = (header_bytes < NRAM_MAX_SIZE) ? ((NRAM_MAX_SIZE - header_bytes) / sizeof(float)) : 0;

for (size_t linear = taskId; linear < total; linear += task_num) {
const size_t global_token_idx = linear / num_heads;
const size_t head_idx = linear - global_token_idx * num_heads;
const size_t seq_idx = findSeqId(global_token_idx, cum_seq_lens_q, num_seqs);

const int64_t q_start_i64 = toInt64(cum_seq_lens_q[seq_idx]);
const int64_t q_end_i64 = toInt64(cum_seq_lens_q[seq_idx + 1]);
const int64_t q_len_i64 = q_end_i64 - q_start_i64;
const int64_t q_token_idx_i64 = static_cast<int64_t>(global_token_idx) - q_start_i64;
const int64_t total_kv_len_i64 = toInt64(total_kv_lens[seq_idx]);
const int64_t causal_limit_i64 = total_kv_len_i64 - q_len_i64 + q_token_idx_i64;

Tdata *out_ptr = out + global_token_idx * o_stride + head_idx * o_head_stride;
if (causal_limit_i64 < 0) {
__bang_write_value(acc_float, head_size, 0.0f);
storeFromFloat(out_ptr, tmp_data, acc_float, head_size);
continue;
}

const size_t attend_len = static_cast<size_t>(causal_limit_i64 + 1);
const bool cache_logits = attend_len <= logits_capacity;

const size_t num_queries_per_kv = num_heads / num_kv_heads;
const size_t kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes == nullptr) ? 0.0f : alibi_slopes[head_idx];
const Tindex *block_table = block_tables + seq_idx * block_table_batch_stride;
const Tdata *q_ptr = q + global_token_idx * q_stride + head_idx * q_head_stride;
loadToFloat(q_float, tmp_data, q_ptr, head_size);

float max_score = -FLT_MAX;
for (size_t token_idx = 0; token_idx < attend_len; ++token_idx) {
const size_t logical_block = token_idx / page_block_size;
const size_t token_in_block = token_idx - logical_block * page_block_size;
const int64_t physical_block = toInt64(block_table[logical_block]);
const Tdata *k_ptr = k_cache + physical_block * k_batch_stride
+ kv_head_idx * k_head_stride
+ token_in_block * k_row_stride;

float score = dotProduct(q_float, work_float, tmp_data, k_ptr, head_size);
score *= scale;
if (alibi_slope != 0.0f) {
score += alibi_slope * static_cast<float>(static_cast<int64_t>(token_idx) - causal_limit_i64);
}
if (cache_logits) {
logits[token_idx] = score;
}
if (score > max_score) {
max_score = score;
}
}

float sum_exp = 0.0f;
if (cache_logits) {
for (size_t token_idx = 0; token_idx < attend_len; ++token_idx) {
const float prob = expf(logits[token_idx] - max_score);
logits[token_idx] = prob;
sum_exp += prob;
}
} else {
for (size_t token_idx = 0; token_idx < attend_len; ++token_idx) {
const size_t logical_block = token_idx / page_block_size;
const size_t token_in_block = token_idx - logical_block * page_block_size;
const int64_t physical_block = toInt64(block_table[logical_block]);
const Tdata *k_ptr = k_cache + physical_block * k_batch_stride
+ kv_head_idx * k_head_stride
+ token_in_block * k_row_stride;
float score = dotProduct(q_float, work_float, tmp_data, k_ptr, head_size) * scale;
if (alibi_slope != 0.0f) {
score += alibi_slope * static_cast<float>(static_cast<int64_t>(token_idx) - causal_limit_i64);
}
sum_exp += expf(score - max_score);
}
}
const float inv_sum = 1.0f / (sum_exp + 1e-6f);

__bang_write_value(acc_float, head_size, 0.0f);
for (size_t token_idx = 0; token_idx < attend_len; ++token_idx) {
const size_t logical_block = token_idx / page_block_size;
const size_t token_in_block = token_idx - logical_block * page_block_size;
const int64_t physical_block = toInt64(block_table[logical_block]);
float prob = 0.0f;
if (cache_logits) {
prob = logits[token_idx] * inv_sum;
} else {
const Tdata *k_ptr = k_cache + physical_block * k_batch_stride
+ kv_head_idx * k_head_stride
+ token_in_block * k_row_stride;
float score = dotProduct(q_float, work_float, tmp_data, k_ptr, head_size) * scale;
if (alibi_slope != 0.0f) {
score += alibi_slope * static_cast<float>(static_cast<int64_t>(token_idx) - causal_limit_i64);
}
prob = expf(score - max_score) * inv_sum;
}
const Tdata *v_ptr = v_cache + physical_block * v_batch_stride
+ kv_head_idx * v_head_stride
+ token_in_block * v_row_stride;
accumulateValue(acc_float, work_float, tmp_data, v_ptr, prob, head_size);
}
storeFromFloat(out_ptr, tmp_data, acc_float, head_size);
}
}

template <typename Tindex, typename Tdata>
infiniStatus_t launchPagedAttentionPrefill(
cnrtQueue_t queue,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
const void *block_tables,
const void *total_kv_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
const op::paged_attention_prefill::PagedAttentionPrefillInfo &info,
int core_per_cluster,
int cluster_count) {

cnrtDim3_t kernel_dim;
kernel_dim.x = core_per_cluster;
kernel_dim.y = cluster_count;
kernel_dim.z = 1;

pagedAttentionPrefillKernel<Tdata, Tindex><<<kernel_dim, cnrtFuncTypeUnion1, queue>>>(
static_cast<Tdata *>(out),
static_cast<const Tdata *>(q),
static_cast<const Tdata *>(k_cache),
static_cast<const Tdata *>(v_cache),
static_cast<const Tindex *>(block_tables),
static_cast<const Tindex *>(total_kv_lens),
static_cast<const Tindex *>(cum_seq_lens_q),
static_cast<const float *>(alibi_slopes),
info.num_seqs,
info.total_q_tokens,
info.num_heads,
info.num_kv_heads,
info.head_size,
info.scale,
info.max_num_blocks_per_seq,
info.page_block_size,
info.block_table_batch_stride,
info.q_stride,
info.q_head_stride,
info.k_batch_stride,
info.k_row_stride,
info.k_head_stride,
info.v_batch_stride,
info.v_row_stride,
info.v_head_stride,
info.o_stride,
info.o_head_stride);
cnrtQueueSync(queue);
return INFINI_STATUS_SUCCESS;
}

template <typename Tindex>
infiniStatus_t dispatchPagedAttentionPrefillDtype(
infiniDtype_t dtype,
cnrtQueue_t queue,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
const void *block_tables,
const void *total_kv_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
const op::paged_attention_prefill::PagedAttentionPrefillInfo &info,
int core_per_cluster,
int cluster_count) {
switch (dtype) {
case INFINI_DTYPE_F16:
return launchPagedAttentionPrefill<Tindex, half>(queue, out, q, k_cache, v_cache, block_tables, total_kv_lens, cum_seq_lens_q, alibi_slopes, info, core_per_cluster, cluster_count);
case INFINI_DTYPE_BF16:
return launchPagedAttentionPrefill<Tindex, bfloat16_t>(queue, out, q, k_cache, v_cache, block_tables, total_kv_lens, cum_seq_lens_q, alibi_slopes, info, core_per_cluster, cluster_count);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}

} // namespace

namespace op::paged_attention_prefill::bang {

struct Descriptor::Opaque {
std::shared_ptr<device::bang::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
delete _opaque;
}

infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t total_kv_lens_desc,
infiniopTensorDescriptor_t cum_seqlens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {

auto handle = reinterpret_cast<device::bang::cambricon::Handle *>(handle_);
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc,
alibi_slopes_desc, scale);
CHECK_RESULT(info);

*desc_ptr = new Descriptor(
new Opaque{static_cast<device::bang::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);

return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
const void *block_tables,
const void *total_kv_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
void *stream) const {

(void)workspace;
(void)workspace_size;

auto queue = reinterpret_cast<cnrtQueue_t>(stream);
const int core_per_cluster = _opaque->internal->getCorePerCluster();
const int cluster_count = _opaque->internal->getClusterCount();

if (_info.index_dtype == INFINI_DTYPE_I64) {
return dispatchPagedAttentionPrefillDtype<int64_t>(_info.dtype, queue, out, q, k_cache, v_cache, block_tables, total_kv_lens, cum_seq_lens_q, alibi_slopes, _info, core_per_cluster, cluster_count);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
return dispatchPagedAttentionPrefillDtype<int32_t>(_info.dtype, queue, out, q, k_cache, v_cache, block_tables, total_kv_lens, cum_seq_lens_q, alibi_slopes, _info, core_per_cluster, cluster_count);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
return dispatchPagedAttentionPrefillDtype<uint32_t>(_info.dtype, queue, out, q, k_cache, v_cache, block_tables, total_kv_lens, cum_seq_lens_q, alibi_slopes, _info, core_per_cluster, cluster_count);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

} // namespace op::paged_attention_prefill::bang
Loading
Loading