diff --git a/src/infiniop/ops/random_sample/ascend/randomsample_aclnn.cc b/src/infiniop/ops/random_sample/ascend/randomsample_aclnn.cc index 62dfb2f0e..317bac517 100644 --- a/src/infiniop/ops/random_sample/ascend/randomsample_aclnn.cc +++ b/src/infiniop/ops/random_sample/ascend/randomsample_aclnn.cc @@ -72,37 +72,110 @@ Descriptor::calculate( bool dosample = topk_ > 1 && temperature != 0.0f && topp != 0.0f && random_val != 0.0f; auto topk_shape = std::vector{dosample ? topk_ : 1}; auto topk_stride = std::vector{1}; - auto topk_idx = new aclnnTensorDescriptor(toAclDataType(_info.dt_i), topk_shape, topk_stride); - auto topk_val = new aclnnTensorDescriptor(toAclDataType(_info.dt_p), topk_shape, topk_stride); - auto topk_val_addr = workspace; - auto topk_idx_addr = (void *)((uint8_t *)workspace + topk_ * infiniSizeOf(_info.dt_p)); + + bool is_bf16 = (_info.dt_p == INFINI_DTYPE_BF16); + + void *probs_for_topk = const_cast(probs); + void *topk_val_addr = workspace; + void *topk_idx_addr = (void *)((uint8_t *)workspace + topk_ * infiniSizeOf(is_bf16 ? INFINI_DTYPE_F32 : _info.dt_p)); + uint64_t topk_workspace_size = 0; aclOpExecutor *topk_executor = nullptr; - CHECK_ACL(aclnnTopkGetWorkspaceSize(_opaque->probs->tensor, - topk_shape[0], - 0, - true, - true, - topk_val->tensor, - dosample ? topk_idx->tensor : _opaque->result->tensor, - &topk_workspace_size, - &topk_executor)); - CHECK_ACL(aclSetAclOpExecutorRepeatable(topk_executor)); - void *topk_workspace; - CHECK_ACL(aclrtMalloc(&topk_workspace, topk_workspace_size, ACL_MEM_MALLOC_HUGE_FIRST)); - AclSetTensorAddr(topk_executor, 0, _opaque->probs->tensor, (void *)probs); - AclSetTensorAddr(topk_executor, 1, topk_val->tensor, topk_val_addr); - if (!dosample) { - AclSetTensorAddr(topk_executor, 2, _opaque->result->tensor, result); + + if (is_bf16) { + void *probs_fp32; + CHECK_ACL(aclrtMalloc(&probs_fp32, _info.n * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + + void *probs_bf16_host; + CHECK_ACL(aclrtMallocHost(&probs_bf16_host, _info.n * sizeof(bf16_t))); + void *probs_fp32_host; + CHECK_ACL(aclrtMallocHost(&probs_fp32_host, _info.n * sizeof(float))); + + CHECK_ACL(aclrtMemcpy(probs_bf16_host, _info.n * sizeof(bf16_t), probs, _info.n * sizeof(bf16_t), ACL_MEMCPY_DEVICE_TO_HOST)); + + auto bf16_ptr = static_cast(probs_bf16_host); + auto fp32_ptr = static_cast(probs_fp32_host); + for (uint64_t i = 0; i < _info.n; i++) { + fp32_ptr[i] = _bf16_to_f32(bf16_ptr[i]); + } + + CHECK_ACL(aclrtMemcpy(probs_fp32, _info.n * sizeof(float), probs_fp32_host, _info.n * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE)); + + CHECK_ACL(aclrtFreeHost(probs_bf16_host)); + CHECK_ACL(aclrtFreeHost(probs_fp32_host)); + + int64_t shape = _info.n; + int64_t stride = 1; + auto probs_fp32_desc = new aclnnTensorDescriptor(toAclDataType(INFINI_DTYPE_F32), {shape}, {stride}); + + auto topk_val_fp32_desc = new aclnnTensorDescriptor(toAclDataType(INFINI_DTYPE_F32), topk_shape, topk_stride); + auto topk_idx_desc = new aclnnTensorDescriptor(toAclDataType(_info.dt_i), topk_shape, topk_stride); + + CHECK_ACL(aclnnTopkGetWorkspaceSize(probs_fp32_desc->tensor, + topk_shape[0], + 0, + true, + true, + topk_val_fp32_desc->tensor, + dosample ? topk_idx_desc->tensor : _opaque->result->tensor, + &topk_workspace_size, + &topk_executor)); + CHECK_ACL(aclSetAclOpExecutorRepeatable(topk_executor)); + void *topk_workspace; + CHECK_ACL(aclrtMalloc(&topk_workspace, topk_workspace_size, ACL_MEM_MALLOC_HUGE_FIRST)); + AclSetTensorAddr(topk_executor, 0, probs_fp32_desc->tensor, probs_fp32); + AclSetTensorAddr(topk_executor, 1, topk_val_fp32_desc->tensor, topk_val_addr); + if (!dosample) { + AclSetTensorAddr(topk_executor, 2, _opaque->result->tensor, result); + } else { + AclSetTensorAddr(topk_executor, 2, topk_idx_desc->tensor, topk_idx_addr); + } + CHECK_ACL(aclnnTopk(topk_workspace, topk_workspace_size, topk_executor, stream)); + CHECK_ACL(aclrtFree(topk_workspace)); + + delete topk_val_fp32_desc; + delete topk_idx_desc; + delete probs_fp32_desc; + + if (dosample) { + auto status = random_sample_kernel_launch(probs_fp32, result, topk_val_addr, topk_idx_addr, random_val, topp, topk_, temperature, _info.n, INFINI_DTYPE_F32, stream); + CHECK_STATUS(status); + } + + CHECK_ACL(aclrtFree(probs_fp32)); } else { - AclSetTensorAddr(topk_executor, 2, topk_idx->tensor, topk_idx_addr); - } - CHECK_ACL(aclnnTopk(topk_workspace, topk_workspace_size, topk_executor, stream)); - CHECK_ACL(aclrtFree(topk_workspace)); + auto topk_val = new aclnnTensorDescriptor(toAclDataType(_info.dt_p), topk_shape, topk_stride); + auto topk_idx = new aclnnTensorDescriptor(toAclDataType(_info.dt_i), topk_shape, topk_stride); + + CHECK_ACL(aclnnTopkGetWorkspaceSize(_opaque->probs->tensor, + topk_shape[0], + 0, + true, + true, + topk_val->tensor, + dosample ? topk_idx->tensor : _opaque->result->tensor, + &topk_workspace_size, + &topk_executor)); + CHECK_ACL(aclSetAclOpExecutorRepeatable(topk_executor)); + void *topk_workspace; + CHECK_ACL(aclrtMalloc(&topk_workspace, topk_workspace_size, ACL_MEM_MALLOC_HUGE_FIRST)); + AclSetTensorAddr(topk_executor, 0, _opaque->probs->tensor, (void *)probs); + AclSetTensorAddr(topk_executor, 1, topk_val->tensor, topk_val_addr); + if (!dosample) { + AclSetTensorAddr(topk_executor, 2, _opaque->result->tensor, result); + } else { + AclSetTensorAddr(topk_executor, 2, topk_idx->tensor, topk_idx_addr); + } + CHECK_ACL(aclnnTopk(topk_workspace, topk_workspace_size, topk_executor, stream)); + CHECK_ACL(aclrtFree(topk_workspace)); + + if (dosample) { + auto status = random_sample_kernel_launch(probs_for_topk, result, topk_val_addr, topk_idx_addr, random_val, topp, topk_, temperature, _info.n, _info.dt_p, stream); + CHECK_STATUS(status); + } - if (dosample) { - auto status = random_sample_kernel_launch((void *)probs, result, topk_val_addr, topk_idx_addr, random_val, topp, topk_, temperature, _info.n, _info.dt_p, stream); - CHECK_STATUS(status); + delete topk_val; + delete topk_idx; } return INFINI_STATUS_SUCCESS;