Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/infiniop/ops/paged_attention/bang/paged_attention_bang.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __PAGED_ATTENTION_BANG_H__
#define __PAGED_ATTENTION_BANG_H__

#include "../paged_attention.h"

DESCRIPTOR(bang)

#endif // __PAGED_ATTENTION_BANG_H__
344 changes: 344 additions & 0 deletions src/infiniop/ops/paged_attention/bang/paged_attention_bang.mlu
Original file line number Diff line number Diff line change
@@ -0,0 +1,344 @@
#include "../../../devices/bang/common_bang.h"
#include "paged_attention_bang.h"

#include <cfloat>
#include <cstdint>
#include <type_traits>

namespace {

__nram__ char paged_attention_nram_buffer[NRAM_MAX_SIZE];

template <typename Tdata>
__mlu_device__ float toFloat(Tdata value) {
if constexpr (std::is_same_v<Tdata, half>) {
return __half2float(value);
} else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
return __bfloat162float(value);
} else {
return static_cast<float>(value);
}
}

template <typename Tdata>
__mlu_device__ Tdata fromFloat(float value) {
return static_cast<Tdata>(value);
}

template <typename Tindex>
__mlu_device__ int64_t toInt64(Tindex value) {
return static_cast<int64_t>(value);
}

template <typename Tdata>
__mlu_device__ void loadToFloat(float *dst, Tdata *tmp, const Tdata *src, size_t n) {
__memcpy(tmp, src, n * sizeof(Tdata), GDRAM2NRAM);
if constexpr (std::is_same_v<Tdata, half>) {
__bang_half2float(dst, tmp, n);
} else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
__bang_bfloat162float(dst, tmp, n);
} else {
__memcpy(dst, tmp, n * sizeof(float), NRAM2NRAM);
}
}

template <typename Tdata>
__mlu_device__ void storeFromFloat(Tdata *dst, Tdata *tmp, float *src, size_t n) {
if constexpr (std::is_same_v<Tdata, half>) {
__bang_float2half(tmp, src, n);
__memcpy(dst, tmp, n * sizeof(Tdata), NRAM2GDRAM);
} else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
__bang_float2bfloat16(tmp, src, n);
__memcpy(dst, tmp, n * sizeof(Tdata), NRAM2GDRAM);
} else {
__memcpy(dst, src, n * sizeof(float), NRAM2GDRAM);
}
}

template <typename Tdata>
__mlu_device__ float dotProduct(float *q_float, float *work_float, Tdata *tmp, const Tdata *k_ptr, size_t head_size) {
loadToFloat(work_float, tmp, k_ptr, head_size);
__bang_mul(work_float, work_float, q_float, head_size);
return __bang_sum(work_float, head_size);
}

template <typename Tdata>
__mlu_device__ void accumulateValue(float *acc_float, float *work_float, Tdata *tmp, const Tdata *v_ptr, float weight, size_t head_size) {
loadToFloat(work_float, tmp, v_ptr, head_size);
__bang_mul_scalar(work_float, work_float, weight, head_size);
__bang_add(acc_float, acc_float, work_float, head_size);
}

template <typename Tdata, typename Tindex>
__mlu_global__ void pagedAttentionKernel(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *seq_lens,
const float *alibi_slopes,
size_t num_seqs,
size_t num_heads,
size_t num_kv_heads,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t block_table_batch_stride,
ptrdiff_t cache_lens_stride) {

const size_t task_num = taskDimX * taskDimY;
const size_t total = num_seqs * num_heads;

char *nram_base = reinterpret_cast<char *>(((reinterpret_cast<size_t>(paged_attention_nram_buffer) + ALIGN_SIZE - 1) / ALIGN_SIZE) * ALIGN_SIZE);
Tdata *tmp_data = reinterpret_cast<Tdata *>(nram_base);
float *q_float = reinterpret_cast<float *>(tmp_data + head_size);
float *work_float = q_float + head_size;
float *acc_float = work_float + head_size;
float *logits = acc_float + head_size;
const size_t header_bytes = reinterpret_cast<char *>(logits) - nram_base;
const size_t logits_capacity = (header_bytes < NRAM_MAX_SIZE) ? ((NRAM_MAX_SIZE - header_bytes) / sizeof(float)) : 0;

for (size_t linear = taskId; linear < total; linear += task_num) {
const size_t seq_idx = linear / num_heads;
const size_t head_idx = linear - seq_idx * num_heads;
const int64_t seq_len_i64 = toInt64(seq_lens[seq_idx * cache_lens_stride]);
Tdata *out_ptr = out + seq_idx * o_stride + head_idx * head_size;

if (seq_len_i64 <= 0) {
__bang_write_value(acc_float, head_size, 0.0f);
storeFromFloat(out_ptr, tmp_data, acc_float, head_size);
continue;
}

const size_t seq_len = static_cast<size_t>(seq_len_i64);
const bool cache_logits = seq_len <= logits_capacity;

const size_t num_queries_per_kv = num_heads / num_kv_heads;
const size_t kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes == nullptr) ? 0.0f : alibi_slopes[head_idx];
const Tindex *block_table = block_tables + seq_idx * block_table_batch_stride;
const Tdata *q_ptr = q + seq_idx * q_stride + head_idx * head_size;
loadToFloat(q_float, tmp_data, q_ptr, head_size);

float max_score = -FLT_MAX;
for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) {
const size_t logical_block = token_idx / page_block_size;
const size_t token_in_block = token_idx - logical_block * page_block_size;
const int64_t physical_block = toInt64(block_table[logical_block]);
const Tdata *k_ptr = k_cache + physical_block * k_batch_stride
+ kv_head_idx * k_head_stride
+ token_in_block * k_row_stride;

float score = dotProduct(q_float, work_float, tmp_data, k_ptr, head_size);
score *= scale;
if (alibi_slope != 0.0f) {
score += alibi_slope * static_cast<float>(static_cast<int64_t>(token_idx) - seq_len_i64 + 1);
}
if (cache_logits) {
logits[token_idx] = score;
}
if (score > max_score) {
max_score = score;
}
}

float sum_exp = 0.0f;
if (cache_logits) {
for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) {
const float prob = expf(logits[token_idx] - max_score);
logits[token_idx] = prob;
sum_exp += prob;
}
} else {
for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) {
const size_t logical_block = token_idx / page_block_size;
const size_t token_in_block = token_idx - logical_block * page_block_size;
const int64_t physical_block = toInt64(block_table[logical_block]);
const Tdata *k_ptr = k_cache + physical_block * k_batch_stride
+ kv_head_idx * k_head_stride
+ token_in_block * k_row_stride;
float score = dotProduct(q_float, work_float, tmp_data, k_ptr, head_size) * scale;
if (alibi_slope != 0.0f) {
score += alibi_slope * static_cast<float>(static_cast<int64_t>(token_idx) - seq_len_i64 + 1);
}
sum_exp += expf(score - max_score);
}
}
const float inv_sum = 1.0f / (sum_exp + 1e-6f);

__bang_write_value(acc_float, head_size, 0.0f);
for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) {
const size_t logical_block = token_idx / page_block_size;
const size_t token_in_block = token_idx - logical_block * page_block_size;
const int64_t physical_block = toInt64(block_table[logical_block]);
float prob = 0.0f;
if (cache_logits) {
prob = logits[token_idx] * inv_sum;
} else {
const Tdata *k_ptr = k_cache + physical_block * k_batch_stride
+ kv_head_idx * k_head_stride
+ token_in_block * k_row_stride;
float score = dotProduct(q_float, work_float, tmp_data, k_ptr, head_size) * scale;
if (alibi_slope != 0.0f) {
score += alibi_slope * static_cast<float>(static_cast<int64_t>(token_idx) - seq_len_i64 + 1);
}
prob = expf(score - max_score) * inv_sum;
}
const Tdata *v_ptr = v_cache + physical_block * v_batch_stride
+ kv_head_idx * v_head_stride
+ token_in_block * v_row_stride;
accumulateValue(acc_float, work_float, tmp_data, v_ptr, prob, head_size);
}
storeFromFloat(out_ptr, tmp_data, acc_float, head_size);
}
}

template <typename Tindex, typename Tdata>
infiniStatus_t launchPagedAttention(
cnrtQueue_t queue,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *alibi_slopes,
const op::paged_attention::PagedAttentionInfo &info,
int core_per_cluster,
int cluster_count) {

cnrtDim3_t kernel_dim;
kernel_dim.x = core_per_cluster;
kernel_dim.y = cluster_count;
kernel_dim.z = 1;

pagedAttentionKernel<Tdata, Tindex><<<kernel_dim, cnrtFuncTypeUnion1, queue>>>(
static_cast<Tdata *>(out),
static_cast<const Tdata *>(q),
static_cast<const Tdata *>(k_cache),
static_cast<const Tdata *>(v_cache),
static_cast<const Tindex *>(block_tables),
static_cast<const Tindex *>(seq_lens),
static_cast<const float *>(alibi_slopes),
info.num_seqs,
info.num_heads,
info.num_kv_heads,
info.head_size,
info.scale,
info.max_num_blocks_per_seq,
info.page_block_size,
info.q_stride,
info.k_batch_stride,
info.k_row_stride,
info.k_head_stride,
info.v_batch_stride,
info.v_row_stride,
info.v_head_stride,
info.o_stride,
info.block_table_batch_stride,
info.cache_lens_stride);
cnrtQueueSync(queue);
return INFINI_STATUS_SUCCESS;
}

template <typename Tindex>
infiniStatus_t dispatchPagedAttentionDtype(
infiniDtype_t dtype,
cnrtQueue_t queue,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *alibi_slopes,
const op::paged_attention::PagedAttentionInfo &info,
int core_per_cluster,
int cluster_count) {
switch (dtype) {
case INFINI_DTYPE_F16:
return launchPagedAttention<Tindex, half>(queue, out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, info, core_per_cluster, cluster_count);
case INFINI_DTYPE_BF16:
return launchPagedAttention<Tindex, bfloat16_t>(queue, out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, info, core_per_cluster, cluster_count);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}

} // namespace

namespace op::paged_attention::bang {

struct Descriptor::Opaque {
std::shared_ptr<device::bang::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
delete _opaque;
}

infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {

auto handle = reinterpret_cast<device::bang::cambricon::Handle *>(handle_);
auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale);
CHECK_RESULT(info);

*desc_ptr = new Descriptor(
new Opaque{static_cast<device::bang::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);

return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *alibi_slopes,
void *stream) const {

(void)workspace;
(void)workspace_size;

auto queue = reinterpret_cast<cnrtQueue_t>(stream);
const int core_per_cluster = _opaque->internal->getCorePerCluster();
const int cluster_count = _opaque->internal->getClusterCount();

if (_info.index_dtype == INFINI_DTYPE_I64) {
return dispatchPagedAttentionDtype<int64_t>(_info.dtype, queue, out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, _info, core_per_cluster, cluster_count);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
return dispatchPagedAttentionDtype<int32_t>(_info.dtype, queue, out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, _info, core_per_cluster, cluster_count);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
return dispatchPagedAttentionDtype<uint32_t>(_info.dtype, queue, out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, _info, core_per_cluster, cluster_count);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

} // namespace op::paged_attention::bang
Loading
Loading