-
Notifications
You must be signed in to change notification settings - Fork 288
Expand file tree
/
Copy pathfused_moe.hpp
More file actions
62 lines (53 loc) · 2.85 KB
/
fused_moe.hpp
File metadata and controls
62 lines (53 loc) · 2.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "fused_moesorting.hpp"
#include "fused_moegemm.hpp"
// Argument bundle passed by value to fused_moe(): untyped pointers to the input, weight,
// scale, output and sorting buffers, followed by the problem sizes.
// Shape legend used in the field comments: k = hidden_size, n = intermediate_size;
// m presumably = number of tokens and e presumably = num_experts -- TODO confirm.
// NOTE(review): the element type behind each void* is not visible in this header; presumably
// it is selected by the prec_* strings in fused_moe_traits -- confirm against the dispatcher.
// NOTE(review): no member has a default initializer, so callers must set every field.
struct fused_moe_args
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // gate(up) weight: [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // down weight: [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
const void* local_expert_mask_ptr; // [e], local expert mask for EP
                                   // (EP presumably = expert parallelism -- confirm)
const void* local_tokens; // [1] if not nullptr, the token count is read from here
                          // NOTE(review): presumably overrides num_tokens -- confirm
void* o_ptr; // [m, k], output token (no need to do zeroing)
void* ws_ptr; // workspace; size is moe_sorting_get_workspace_size()
// if that size is zero, ws_ptr may be nullptr
// the workspace must be cleared before use
const void* topk_ids_ptr; // [tokens, topk]
const void* topk_weight_ptr; // [tokens, topk]
void* sorted_token_ids_ptr; // [max_num_tokens_padded]
void* sorted_weight_ptr; // [max_num_tokens_padded]
void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
                             // NOTE(review): block_size presumably equals block_m -- confirm
void* num_sorted_tiles_ptr; // [1]
ck_tile::index_t block_m; // block_m, used to divide the input
ck_tile::index_t hidden_size; // k
ck_tile::index_t intermediate_size; // n / TP, for Gate. and Up, Down is also this value
ck_tile::index_t num_tokens; // input number of tokens for current iteration
ck_tile::index_t num_experts; // number of groups
ck_tile::index_t topk; // second dimension of topk_ids/topk_weight; original author's note: need this?
ck_tile::index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
// This is the public API, will be generated by script
// Configuration passed by value to fused_moe(): data types are given as strings and the
// remaining fields are integer/bool switches.
// NOTE(review): the set of accepted prec_* strings is not visible in this header -- confirm
// against the generator script.
// NOTE(review): std::string is used without a direct #include <string> in the visible
// includes; presumably it arrives transitively through the fused_moe*.hpp headers -- confirm.
struct fused_moe_traits
{
std::string prec_i; // input precision
std::string prec_w; // weight precision
std::string prec_o; // output precision
std::string prec_st; // token scale data type
std::string prec_sw; // weight scale data type
std::string prec_sq; // smooth quant scale data type
std::string prec_kw; // topk-weight data type
int block_m; // NOTE(review): presumably must agree with fused_moe_args::block_m -- confirm
int activation; // 0:gelu, 1:silu
int gate_only; // 0:g1u0, 1:g1u1
               // NOTE(review): the field name suggests 1 should mean gate-only (g1u0);
               // this mapping may be inverted -- confirm against the implementation
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
bool local_expert_masking; // whether to mask experts as local experts
                           // (presumably enables fused_moe_args::local_expert_mask_ptr -- confirm)
};
// Workspace size query for the given token count, expert count and topk.
// if return zero, no ws needed
// NOTE(review): the fused_moe_args::ws_ptr comment refers to moe_sorting_get_workspace_size();
// presumably this function reports the same size (unit presumably bytes) -- confirm.
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk);
// Fused MoE entry point. Takes the traits and the argument bundle by value, plus a
// ck_tile::stream_config by const reference.
// NOTE(review): the meaning of the float return is not stated in this header; presumably it is
// the measured kernel time, with a negative value when no matching kernel exists -- confirm.
float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);