-
Notifications
You must be signed in to change notification settings - Fork 333
【draft】Mtp optimization #1266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
【draft】Mtp optimization #1266
Changes from 16 commits
6cf58b4
f1251a3
d9b1fdd
315366a
73ea125
dc91e59
aefe67e
3c28fb0
2dc933e
0b08de8
952ec15
3750118
4bf4287
05e0dfd
c2b7569
80219af
41180d3
8afd7a8
2b277fa
93c2ada
6395447
fc20624
1b08d15
775adbd
d4830ff
da944f8
22c5996
6fbe8d8
c4a9f74
5eb4889
379f256
e943a43
e121b9d
1723230
6890bc0
1e1fb98
29535b2
5b925a6
538200d
2831c70
158c7a3
4c08120
a837cbb
17ed333
7927e15
53a7077
322f713
c3e46c9
2f0c250
8ab2ab4
76ca4ce
1d0f18e
6e69701
1565698
5e857c8
3272073
1126013
f591158
74189c2
df54935
261f5b4
03ac700
7bc26a7
4b93c8c
d7bdef0
a731153
97a0123
62226d5
0783c1e
4777936
9dcfc05
9229d28
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,3 +7,9 @@ dist | |
| .vscode | ||
| tmp/ | ||
| requirements-musa.txt | ||
|
|
||
| hf_datasets_cache/ | ||
| wandb/ | ||
| datasets/ | ||
| trace/ | ||
| experiment_results/ | ||
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,14 +25,14 @@ | |
| from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token | ||
| from lightllm.utils.log_utils import init_logger | ||
| from lightllm.utils.dist_utils import get_dp_world_size | ||
| from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num | ||
| from lightllm.utils.envs_utils import enable_triton_mtp_kernel, get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num | ||
| from lightllm.distributed.communication_op import dist_group_manager | ||
| from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput | ||
| from lightllm.common.triton_utils.autotuner import AutotuneLevel | ||
| from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch | ||
| from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel | ||
| from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel, enable_dynamic_mtp_verify | ||
| from lightllm.common.triton_utils.autotuner import Autotuner | ||
| from lightllm.utils.infer_utils import post_empty_cache | ||
| from lightllm.utils.infer_utils import calculate_time, post_empty_cache | ||
| from .attention import get_prefill_att_backend_class, get_decode_att_backend_class | ||
| from .attention import BaseAttBackend | ||
|
|
||
|
|
@@ -93,16 +93,11 @@ def __init__(self, kvargs): | |
| self.mem_fraction = kvargs.get("mem_fraction", 0.9) | ||
| self.tp_world_size_ = get_dp_world_size() | ||
| self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode | ||
|
|
||
| self.is_mtp_mode = self.args.mtp_mode in [ | ||
| "vanilla_with_att", | ||
| "eagle_with_att", | ||
| "vanilla_no_att", | ||
| "eagle_no_att", | ||
| ] | ||
| self.prefill_graph: PrefillCudaGraph = None | ||
|
|
||
| self._init_config() | ||
| self._init_speculative_algo(kvargs) | ||
|
|
||
| self._verify_must() | ||
| self._verify_params() | ||
| self._init_quant() | ||
|
|
@@ -137,6 +132,44 @@ def __init__(self, kvargs): | |
| set_model_init_status(True) | ||
| return | ||
|
|
||
| def _init_speculative_algo(self, kvargs): | ||
| self.is_mtp_mode = self.args.mtp_mode in [ | ||
| "vanilla_with_att", | ||
| "eagle_with_att", | ||
| "vanilla_no_att", | ||
| "eagle_no_att", | ||
| "eagle3" | ||
| ] | ||
| self.is_mtp_draft_model = kvargs.get("is_mtp_draft_model", False) | ||
| self.is_eagle3_mode = "eagle3" in self.args.mtp_mode | ||
| if self.is_eagle3_mode and not self.is_mtp_draft_model: | ||
| self.eagle_hidden_layers = [1, self.config["n_layer"]//2-1, self.config["n_layer"]-4] | ||
| elif self.is_mtp_mode: | ||
| self.eagle_hidden_layers = [self.config["n_layer"]-1] | ||
| else: | ||
| self.eagle_hidden_layers = [] | ||
|
|
||
| if self.is_eagle3_mode: | ||
| # load the hidden_proj weight from the draft model path | ||
| draft_model_path = kvargs.get("mtp_draft_model_dir", None) | ||
| assert draft_model_path is not None, "mtp_draft_model_dir must be provided when eagle3 mode is enabled" | ||
| if os.path.exists(os.path.join(draft_model_path[0], "pytorch_model.bin")): | ||
| self.draft_model_weight_dict = torch.load(os.path.join(draft_model_path[0], "pytorch_model.bin")) | ||
| self.hidden_proj_weight = self.draft_model_weight_dict["fc.weight"].to(torch.bfloat16).to("cuda") | ||
| del self.draft_model_weight_dict | ||
| gc.collect() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| else: | ||
| try: | ||
| from safetensors import safe_open | ||
| with safe_open(os.path.join(draft_model_path[0], "model.safetensors"), framework="pt", device="cuda") as f: | ||
| # Check if the key exists to avoid KeyError | ||
| if "fc.weight" in f.keys(): | ||
| self.hidden_proj_weight = f.get_tensor("fc.weight").to(torch.bfloat16).to("cuda") | ||
| except Exception as e: | ||
| logger.warning(f"Failed to load hidden_proj_weight from safetensors with error: {e}") | ||
| self.hidden_proj_weight = None | ||
|
|
||
|
|
||
| def _wait_other_modules_ready(self): | ||
| for event in self.wait_events: | ||
| event.wait() | ||
|
|
@@ -151,6 +184,9 @@ def _init_config(self): | |
| repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) | ||
| if self.finetune_config: | ||
| self.config["vocab_size"] = self.finetune_config.vocab_size | ||
| if "draft_vocab_size" in self.config.keys(): | ||
| self.config["target_vocab_size"] = self.config["vocab_size"] | ||
| self.config["vocab_size"] = self.config["draft_vocab_size"] | ||
| return | ||
|
|
||
| @final | ||
|
|
@@ -314,6 +350,16 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) | |
| if enable_diverse_mode_gqa_decode_fast_kernel(): | ||
| infer_state.b_shared_seq_len = model_input.b_shared_seq_len | ||
| infer_state.b_mark_shared_group = model_input.b_mark_shared_group | ||
| elif enable_dynamic_mtp_verify() or enable_triton_mtp_kernel(): | ||
| # 动态 MTP 验证模式下,也需要传递 b_mark_shared_group | ||
| infer_state.b_mark_shared_group = model_input.b_mark_shared_group | ||
| # 将b_mark_shared_group pad到跟input_ids一样的长度,避免后续使用时出现形状不匹配的问题 | ||
| # infer_state.b_mark_shared_group = F.pad( | ||
| # infer_state.b_mark_shared_group, | ||
| # (0, infer_state.input_ids.shape[0] - infer_state.b_mark_shared_group.shape[0]), | ||
| # mode="constant", | ||
| # value=0, | ||
| # ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| infer_state.multimodal_params = model_input.multimodal_params | ||
|
|
||
|
|
@@ -377,6 +423,11 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s | |
| new_model_input.b_mark_shared_group = F.pad( | ||
| new_model_input.b_mark_shared_group, (0, padded_batch_size), mode="constant", value=1 | ||
| ) | ||
| elif enable_dynamic_mtp_verify() or enable_triton_mtp_kernel(): | ||
| if new_model_input.b_mark_shared_group is not None: | ||
| new_model_input.b_mark_shared_group = F.pad( | ||
| new_model_input.b_mark_shared_group, (0, padded_batch_size), mode="constant", value=0 | ||
| ) | ||
|
|
||
| # 特殊模型,特殊模式的特殊变量的特殊 padding | ||
| if new_model_input.mtp_draft_input_hiddens is not None: | ||
|
|
@@ -561,12 +612,19 @@ def _context_forward(self, infer_state: InferStateInfo): | |
| input_tensors = [input_embs] | ||
|
|
||
| def prefill_func(input_tensors, infer_state): | ||
| capture_hiddens = [] | ||
| _input_embs = input_tensors[0] | ||
| for i in range(self.layers_num): | ||
| layer = self.layers_infer[i] | ||
| layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index] | ||
| _input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i]) | ||
| return [_input_embs] | ||
| if i in self.eagle_hidden_layers: | ||
| capture_hiddens.append(_input_embs.clone()) | ||
| capture_hidden = torch.cat(capture_hiddens, dim=-1) if len(capture_hiddens) > 0 else None | ||
| if self.is_eagle3_mode and not self.is_mtp_draft_model and capture_hidden is not None: | ||
| capture_hidden = self.hidden_proj_weight @ capture_hidden.transpose(0, 1) | ||
| capture_hidden = capture_hidden.transpose(0, 1) | ||
| return [_input_embs, capture_hidden] | ||
|
|
||
| handle_token_num = input_ids.shape[0] | ||
|
|
||
|
|
@@ -597,7 +655,7 @@ def prefill_func(input_tensors, infer_state): | |
|
|
||
| # 特殊模型特殊模式的额外输出 | ||
| if self.is_mtp_mode: | ||
| model_output.mtp_main_output_hiddens = input_embs | ||
| model_output.mtp_main_output_hiddens = output_tensors[1] | ||
|
|
||
| # 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候 | ||
| # 该调用没有实际意义 | ||
|
|
@@ -611,16 +669,23 @@ def _token_forward(self, infer_state: InferStateInfo): | |
| cuda_input_ids = input_ids | ||
| pre_method = (self.pre_infer.token_forward, self.pre_infer.tpsp_token_forward)[run_mode_index] | ||
| input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight) | ||
| capture_hiddens = [] | ||
| for i in range(self.layers_num): | ||
| layer = self.layers_infer[i] | ||
| layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index] | ||
| input_embs: torch.Tensor = layer_method(input_embs, infer_state, self.trans_layers_weight[i]) | ||
| if i in self.eagle_hidden_layers: | ||
| capture_hiddens.append(input_embs.clone()) | ||
|
|
||
| capture_hidden = torch.cat(capture_hiddens, dim=-1) if len(capture_hiddens) > 0 else None | ||
| if self.is_eagle3_mode and not self.is_mtp_draft_model and capture_hidden is not None: | ||
| capture_hidden = self.hidden_proj_weight @ capture_hidden.transpose(0, 1) | ||
| capture_hidden = capture_hidden.transpose(0, 1) | ||
| post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index] | ||
| predict_logits: torch.Tensor = post_method(input_embs, infer_state, self.pre_post_weight) | ||
|
|
||
| if self.is_mtp_mode: | ||
| graph_out_hiddens = input_embs.contiguous() | ||
| graph_out_hiddens = capture_hidden.contiguous() | ||
|
|
||
| model_output = ModelOutput(logits=predict_logits.contiguous()) | ||
|
|
||
|
|
@@ -1027,6 +1092,7 @@ def _gen_special_model_input(self, token_num: int): | |
| or "Qwen3MOEMTPModel" in str(self.__class__) | ||
| or "MistralMTPModel" in str(self.__class__) | ||
| or "Glm4MoeLiteMTPModel" in str(self.__class__) | ||
| or "Qwen3EagleModel" in str(self.__class__) | ||
| ) | ||
| if is_mtp_draft_model: | ||
| special_model_input["mtp_draft_input_hiddens"] = torch.randn( | ||
|
|
@@ -1036,3 +1102,22 @@ def _gen_special_model_input(self, token_num: int): | |
| special_model_input["mtp_draft_input_hiddens"] = None | ||
|
|
||
| return special_model_input | ||
|
|
||
| """ | ||
| tensor([[ 1.9100e+02, 4.0400e+02, -4.2400e+02, ..., -3.3200e+02, | ||
| -2.9250e+01, -5.5000e+01], | ||
| [ 5.1875e+00, 9.5625e+00, 2.8516e-01, ..., 3.8906e+00, | ||
| 4.0625e+00, -3.1562e+00], | ||
| [-4.7461e-01, -9.1250e+00, 9.0000e+00, ..., 6.0312e+00, | ||
| 9.3750e+00, 6.4375e+00], | ||
| [-4.2500e+00, -3.8906e+00, -5.4297e-01, ..., 3.4844e+00, | ||
| 3.1719e+00, 9.5625e+00]], device='cuda:0', dtype=torch.bfloat16) | ||
| tensor([[ 1.9100e+02, 4.0400e+02, -4.2400e+02, ..., -3.3200e+02, | ||
| -2.9250e+01, -5.5000e+01], | ||
| [ 5.1875e+00, 9.5625e+00, 2.8516e-01, ..., 3.8906e+00, | ||
| 4.0625e+00, -3.1562e+00], | ||
| [-4.7461e-01, -9.1250e+00, 9.0000e+00, ..., 6.0312e+00, | ||
| 9.3750e+00, 6.4375e+00], | ||
| [-4.2500e+00, -3.8906e+00, -5.4297e-01, ..., 3.4844e+00, | ||
| 3.1719e+00, 9.5625e+00]], device='cuda:1', dtype=torch.bfloat16) | ||
| """ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -85,6 +85,7 @@ def __init__(self): | |
| # 在开启 mtp_mode 时,mtp draft model | ||
| # 的输入会用到,其他模型和场景都不会用到 | ||
| self.mtp_draft_input_hiddens: Optional[torch.Tensor] = None | ||
| self.is_mtp_draft_model: bool = False | ||
|
|
||
| # 在单节点多dp的运行模式下,在进行prefill的阶段,如果出现了dp之间数据不平衡的现象, | ||
| # 可以将推理的数据,进行重新分配到各个dp,在做 att 之前,重新 all to all 到各自的 | ||
|
|
@@ -137,7 +138,10 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"): | |
| if isinstance(attr_value, torch.Tensor): | ||
| attr_ = getattr(self, attr_name, None) | ||
| if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr(): | ||
| attr_.copy_(attr_value, non_blocking=True) | ||
| try: | ||
| attr_.copy_(attr_value, non_blocking=True) | ||
| except Exception as e: | ||
| print(f"Warning: copy tensor {attr_name} failed during cuda graph copy, error: {e}") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| self.decode_att_state.copy_for_decode_cuda_graph(new_infer_state.decode_att_state) | ||
| if self.decode_att_state1 is not None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid using absolute imports inside a function. Move this import to the top of the file to improve readability and maintainability.