Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 100 additions & 27 deletions src/infiniop/ops/random_sample/ascend/randomsample_aclnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,37 +72,110 @@ Descriptor::calculate(
bool dosample = topk_ > 1 && temperature != 0.0f && topp != 0.0f && random_val != 0.0f;
auto topk_shape = std::vector<int64_t>{dosample ? topk_ : 1};
auto topk_stride = std::vector<int64_t>{1};
auto topk_idx = new aclnnTensorDescriptor(toAclDataType(_info.dt_i), topk_shape, topk_stride);
auto topk_val = new aclnnTensorDescriptor(toAclDataType(_info.dt_p), topk_shape, topk_stride);
auto topk_val_addr = workspace;
auto topk_idx_addr = (void *)((uint8_t *)workspace + topk_ * infiniSizeOf(_info.dt_p));

bool is_bf16 = (_info.dt_p == INFINI_DTYPE_BF16);

void *probs_for_topk = const_cast<void *>(probs);
void *topk_val_addr = workspace;
void *topk_idx_addr = (void *)((uint8_t *)workspace + topk_ * infiniSizeOf(is_bf16 ? INFINI_DTYPE_F32 : _info.dt_p));

uint64_t topk_workspace_size = 0;
aclOpExecutor *topk_executor = nullptr;
CHECK_ACL(aclnnTopkGetWorkspaceSize(_opaque->probs->tensor,
topk_shape[0],
0,
true,
true,
topk_val->tensor,
dosample ? topk_idx->tensor : _opaque->result->tensor,
&topk_workspace_size,
&topk_executor));
CHECK_ACL(aclSetAclOpExecutorRepeatable(topk_executor));
void *topk_workspace;
CHECK_ACL(aclrtMalloc(&topk_workspace, topk_workspace_size, ACL_MEM_MALLOC_HUGE_FIRST));
AclSetTensorAddr(topk_executor, 0, _opaque->probs->tensor, (void *)probs);
AclSetTensorAddr(topk_executor, 1, topk_val->tensor, topk_val_addr);
if (!dosample) {
AclSetTensorAddr(topk_executor, 2, _opaque->result->tensor, result);

if (is_bf16) {
void *probs_fp32;
CHECK_ACL(aclrtMalloc(&probs_fp32, _info.n * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));

void *probs_bf16_host;
CHECK_ACL(aclrtMallocHost(&probs_bf16_host, _info.n * sizeof(bf16_t)));
void *probs_fp32_host;
CHECK_ACL(aclrtMallocHost(&probs_fp32_host, _info.n * sizeof(float)));

CHECK_ACL(aclrtMemcpy(probs_bf16_host, _info.n * sizeof(bf16_t), probs, _info.n * sizeof(bf16_t), ACL_MEMCPY_DEVICE_TO_HOST));

auto bf16_ptr = static_cast<bf16_t *>(probs_bf16_host);
auto fp32_ptr = static_cast<float *>(probs_fp32_host);
for (uint64_t i = 0; i < _info.n; i++) {
fp32_ptr[i] = _bf16_to_f32(bf16_ptr[i]);
}

CHECK_ACL(aclrtMemcpy(probs_fp32, _info.n * sizeof(float), probs_fp32_host, _info.n * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE));

CHECK_ACL(aclrtFreeHost(probs_bf16_host));
CHECK_ACL(aclrtFreeHost(probs_fp32_host));

int64_t shape = _info.n;
int64_t stride = 1;
auto probs_fp32_desc = new aclnnTensorDescriptor(toAclDataType(INFINI_DTYPE_F32), {shape}, {stride});

auto topk_val_fp32_desc = new aclnnTensorDescriptor(toAclDataType(INFINI_DTYPE_F32), topk_shape, topk_stride);
auto topk_idx_desc = new aclnnTensorDescriptor(toAclDataType(_info.dt_i), topk_shape, topk_stride);

CHECK_ACL(aclnnTopkGetWorkspaceSize(probs_fp32_desc->tensor,
topk_shape[0],
0,
true,
true,
topk_val_fp32_desc->tensor,
dosample ? topk_idx_desc->tensor : _opaque->result->tensor,
&topk_workspace_size,
&topk_executor));
CHECK_ACL(aclSetAclOpExecutorRepeatable(topk_executor));
void *topk_workspace;
CHECK_ACL(aclrtMalloc(&topk_workspace, topk_workspace_size, ACL_MEM_MALLOC_HUGE_FIRST));
AclSetTensorAddr(topk_executor, 0, probs_fp32_desc->tensor, probs_fp32);
AclSetTensorAddr(topk_executor, 1, topk_val_fp32_desc->tensor, topk_val_addr);
if (!dosample) {
AclSetTensorAddr(topk_executor, 2, _opaque->result->tensor, result);
} else {
AclSetTensorAddr(topk_executor, 2, topk_idx_desc->tensor, topk_idx_addr);
}
CHECK_ACL(aclnnTopk(topk_workspace, topk_workspace_size, topk_executor, stream));
CHECK_ACL(aclrtFree(topk_workspace));

delete topk_val_fp32_desc;
delete topk_idx_desc;
delete probs_fp32_desc;

if (dosample) {
auto status = random_sample_kernel_launch(probs_fp32, result, topk_val_addr, topk_idx_addr, random_val, topp, topk_, temperature, _info.n, INFINI_DTYPE_F32, stream);
CHECK_STATUS(status);
}

CHECK_ACL(aclrtFree(probs_fp32));
} else {
AclSetTensorAddr(topk_executor, 2, topk_idx->tensor, topk_idx_addr);
}
CHECK_ACL(aclnnTopk(topk_workspace, topk_workspace_size, topk_executor, stream));
CHECK_ACL(aclrtFree(topk_workspace));
auto topk_val = new aclnnTensorDescriptor(toAclDataType(_info.dt_p), topk_shape, topk_stride);
auto topk_idx = new aclnnTensorDescriptor(toAclDataType(_info.dt_i), topk_shape, topk_stride);

CHECK_ACL(aclnnTopkGetWorkspaceSize(_opaque->probs->tensor,
topk_shape[0],
0,
true,
true,
topk_val->tensor,
dosample ? topk_idx->tensor : _opaque->result->tensor,
&topk_workspace_size,
&topk_executor));
CHECK_ACL(aclSetAclOpExecutorRepeatable(topk_executor));
void *topk_workspace;
CHECK_ACL(aclrtMalloc(&topk_workspace, topk_workspace_size, ACL_MEM_MALLOC_HUGE_FIRST));
AclSetTensorAddr(topk_executor, 0, _opaque->probs->tensor, (void *)probs);
AclSetTensorAddr(topk_executor, 1, topk_val->tensor, topk_val_addr);
if (!dosample) {
AclSetTensorAddr(topk_executor, 2, _opaque->result->tensor, result);
} else {
AclSetTensorAddr(topk_executor, 2, topk_idx->tensor, topk_idx_addr);
}
CHECK_ACL(aclnnTopk(topk_workspace, topk_workspace_size, topk_executor, stream));
CHECK_ACL(aclrtFree(topk_workspace));

if (dosample) {
auto status = random_sample_kernel_launch(probs_for_topk, result, topk_val_addr, topk_idx_addr, random_val, topp, topk_, temperature, _info.n, _info.dt_p, stream);
CHECK_STATUS(status);
}

if (dosample) {
auto status = random_sample_kernel_launch((void *)probs, result, topk_val_addr, topk_idx_addr, random_val, topp, topk_, temperature, _info.n, _info.dt_p, stream);
CHECK_STATUS(status);
delete topk_val;
delete topk_idx;
}

return INFINI_STATUS_SUCCESS;
Expand Down
Loading