Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
6cf58b4
fix
hiworldwzj Apr 9, 2026
f1251a3
add gitignore
flyinglandlord Mar 12, 2026
d9b1fdd
finish usable mtp kernel
flyinglandlord Mar 17, 2026
315366a
end-to-end finish
flyinglandlord Mar 19, 2026
73ea125
fix cudagraph support
flyinglandlord Mar 19, 2026
dc91e59
save runnable version of dynamic mtp
flyinglandlord Mar 26, 2026
aefe67e
save runnable version of dynamic mtp
flyinglandlord Mar 26, 2026
3c28fb0
fix
hiworldwzj Apr 9, 2026
2dc933e
save fixed dynamic mtp
flyinglandlord Mar 27, 2026
0b08de8
save
flyinglandlord Mar 30, 2026
952ec15
save
flyinglandlord Mar 30, 2026
3750118
add experiment script
flyinglandlord Apr 1, 2026
4bf4287
update mtp kernel support BLOCK_BATCH < max_verify_group_size
flyinglandlord Apr 1, 2026
05e0dfd
fix implementation issues
flyinglandlord Apr 3, 2026
c2b7569
save
flyinglandlord Apr 4, 2026
80219af
save
flyinglandlord Apr 8, 2026
41180d3
fix
hiworldwzj Apr 9, 2026
8afd7a8
fix
hiworldwzj Apr 9, 2026
2b277fa
fix
hiworldwzj Apr 9, 2026
93c2ada
fix
hiworldwzj Apr 9, 2026
6395447
fix
hiworldwzj Apr 9, 2026
fc20624
fix
hiworldwzj Apr 9, 2026
1b08d15
fix
hiworldwzj Apr 9, 2026
775adbd
fix
hiworldwzj Apr 9, 2026
d4830ff
fix
hiworldwzj Apr 9, 2026
da944f8
fix
hiworldwzj Apr 9, 2026
22c5996
fix
hiworldwzj Apr 9, 2026
6fbe8d8
fix
hiworldwzj Apr 9, 2026
c4a9f74
fix lightllm/server/router/model_infer/mode_backend/generic_pre_proce…
flyinglandlord Apr 9, 2026
5eb4889
update generic_pre_process.py
flyinglandlord Apr 9, 2026
379f256
fix
flyinglandlord Apr 9, 2026
e943a43
fix
flyinglandlord Apr 9, 2026
e121b9d
refactor qwen3_eagle3
flyinglandlord Apr 9, 2026
1723230
add stage1
hiworldwzj Apr 9, 2026
6890bc0
fix
hiworldwzj Apr 9, 2026
1e1fb98
fix
hiworldwzj Apr 9, 2026
29535b2
fix
hiworldwzj Apr 9, 2026
5b925a6
fix
hiworldwzj Apr 9, 2026
538200d
fix
hiworldwzj Apr 9, 2026
2831c70
fix
hiworldwzj Apr 9, 2026
158c7a3
fix
hiworldwzj Apr 9, 2026
4c08120
fix
hiworldwzj Apr 10, 2026
a837cbb
fix
hiworldwzj Apr 10, 2026
17ed333
fix
hiworldwzj Apr 10, 2026
7927e15
fix
hiworldwzj Apr 10, 2026
53a7077
fix base_backend.py
flyinglandlord Apr 10, 2026
322f713
fix
hiworldwzj Apr 10, 2026
c3e46c9
fix
hiworldwzj Apr 10, 2026
2f0c250
fix
hiworldwzj Apr 10, 2026
8ab2ab4
fix
hiworldwzj Apr 10, 2026
76ca4ce
fix
hiworldwzj Apr 10, 2026
1d0f18e
fix
hiworldwzj Apr 10, 2026
6e69701
fix
hiworldwzj Apr 10, 2026
1565698
fix
hiworldwzj Apr 10, 2026
5e857c8
fix
hiworldwzj Apr 10, 2026
3272073
fix
hiworldwzj Apr 11, 2026
1126013
fix
hiworldwzj Apr 11, 2026
f591158
fix
flyinglandlord Apr 13, 2026
74189c2
add vllm test script
flyinglandlord Apr 20, 2026
df54935
remove 200000 token limit in test script
flyinglandlord May 7, 2026
261f5b4
Mtp optimization ema overlap (#1310)
flyinglandlord May 27, 2026
03ac700
add sample_dynamic_mtp_req_mask
hiworldwzj May 27, 2026
7bc26a7
fix
hiworldwzj May 27, 2026
4b93c8c
fix
hiworldwzj May 27, 2026
d7bdef0
fix
hiworldwzj May 27, 2026
a731153
Merge branch 'mtp_optimization' of https://github.com/ModelTC/LightLL…
flyinglandlord May 27, 2026
97a0123
fix
flyinglandlord May 29, 2026
62226d5
fix
hiworldwzj May 29, 2026
0783c1e
fix
hiworldwzj May 29, 2026
4777936
fix
hiworldwzj May 29, 2026
9dcfc05
fix
hiworldwzj May 29, 2026
9229d28
add dynamic fa3 (#1334)
shihaobai Jun 9, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@ dist
.vscode
tmp/
requirements-musa.txt

hf_datasets_cache/
wandb/
datasets/
trace/
experiment_results/
12,002 changes: 12,002 additions & 0 deletions datasets/gsm8k.json

Large diffs are not rendered by default.

71 changes: 69 additions & 2 deletions lightllm/common/basemodel/attention/triton/fp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import dataclasses
import torch

from lightllm.utils.envs_utils import enable_dynamic_mtp_verify, get_diverse_max_batch_shared_group_size, get_env_start_args
from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
from typing import Optional

Expand Down Expand Up @@ -80,8 +82,20 @@ def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,

@dataclasses.dataclass
class TritonDecodeAttState(BaseDecodeAttState):
# MTP related state variables
b_mark_shared_group: torch.Tensor = None
mtp_size: int = 1

def init_state(self):
    """Populate the MTP-related fields of this decode attention state.

    Reads ``mtp_step`` from the start args. When it is positive, ``mtp_size``
    becomes ``mtp_step + 1`` and ``b_mark_shared_group`` is taken from
    ``self.infer_state``; otherwise the fields are reset to ``1`` and ``None``.
    """
    mtp_step = get_env_start_args().mtp_step
    mtp_enabled = mtp_step > 0

    # mtp_size is the configured number of MTP steps plus one.
    self.mtp_size = mtp_step + 1 if mtp_enabled else 1
    # The group marker is only looked up on infer_state when MTP is enabled.
    self.b_mark_shared_group = self.infer_state.b_mark_shared_group if mtp_enabled else None

def copy_for_decode_cuda_graph(self, new_state: "TritonDecodeAttState"):
super().copy_for_decode_cuda_graph(new_state)
Expand All @@ -99,9 +113,17 @@ def decode_att(
assert att_control.tp_alibi is not None
return self._alibi_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)
else:
from lightllm.utils.envs_utils import get_env_start_args
args_mtp_step = get_env_start_args().mtp_step

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Avoid using absolute imports inside a function. Move this import to the top of the file to improve readability and maintainability.


q_head_num = q.shape[1]
k_head_num = k.shape[1]
if q_head_num == k_head_num:

if args_mtp_step > 0:
# MTP mode: use mtp diverse attention
assert q_head_num >= k_head_num, "MTP diverse attention requires q_head_num >= k_head_num"
return self._mtp_diverse_decode_gqa_att(q=q, k=k, v=v, alloc_func=alloc_func)
elif q_head_num == k_head_num:
return self._normal_decode_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func)
elif q_head_num > k_head_num:
return self._normal_decode_gqa_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func)
Expand Down Expand Up @@ -182,6 +204,51 @@ def _normal_decode_gqa_flash_decoding_att(

return out

def _mtp_diverse_decode_gqa_att(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alloc_func=torch.empty,
):
    """
    MTP Diverse GQA Attention for static and dynamic MTP mode.

    In static MTP mode, each request has 1 Q token, but the i-th request in a group
    can only see the first i+1 KV tokens.
    In dynamic MTP mode, each request has a dynamic mtp_size, and b_mark_shared_group
    is built dynamically based on the actual mtp_size of each request.

    Input/Output shape: [batch_size, num_heads, head_dim]

    Args:
        q: Query tensor forwarded unchanged to the kernel.
        k: Key tensor forwarded unchanged to the kernel.
        v: Value tensor forwarded unchanged to the kernel.
        alloc_func: Allocator handed to the kernel as ``alloc_tensor_func``.

    Returns:
        The tensor returned by ``token_decode_attention_mtp_diverse_single_token``.
    """
    # Imported lazily, matching the original function-level kernel import.
    from ...triton_kernel.att.decode_att.gqa.mtp_diverse import (
        token_decode_attention_mtp_diverse_single_token,
    )

    infer_state = self.infer_state

    # Both static and dynamic MTP read the group marker from infer_state here;
    # the copy stored on this att-state object is not consulted by this method.
    return token_decode_attention_mtp_diverse_single_token(
        q=q,
        k=k,
        v=v,
        Req_to_tokens=infer_state.req_manager.req_to_token_indexs,
        B_req_idx=infer_state.b_req_idx,
        b_seq_len=infer_state.b_seq_len,
        b_mark_shared_group=infer_state.b_mark_shared_group,
        block_seq=256,
        alloc_tensor_func=alloc_func,
    )

def _normal_decode_gqa_flash_decoding_att_vsm(
self,
q: torch.Tensor,
Expand Down
111 changes: 98 additions & 13 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_dp_world_size
from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num
from lightllm.utils.envs_utils import enable_triton_mtp_kernel, get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num
from lightllm.distributed.communication_op import dist_group_manager
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
from lightllm.common.triton_utils.autotuner import AutotuneLevel
from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel
from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel, enable_dynamic_mtp_verify
from lightllm.common.triton_utils.autotuner import Autotuner
from lightllm.utils.infer_utils import post_empty_cache
from lightllm.utils.infer_utils import calculate_time, post_empty_cache
from .attention import get_prefill_att_backend_class, get_decode_att_backend_class
from .attention import BaseAttBackend

Expand Down Expand Up @@ -93,16 +93,11 @@ def __init__(self, kvargs):
self.mem_fraction = kvargs.get("mem_fraction", 0.9)
self.tp_world_size_ = get_dp_world_size()
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

self.is_mtp_mode = self.args.mtp_mode in [
"vanilla_with_att",
"eagle_with_att",
"vanilla_no_att",
"eagle_no_att",
]
self.prefill_graph: PrefillCudaGraph = None

self._init_config()
self._init_speculative_algo(kvargs)

self._verify_must()
self._verify_params()
self._init_quant()
Expand Down Expand Up @@ -137,6 +132,44 @@ def __init__(self, kvargs):
set_model_init_status(True)
return

def _init_speculative_algo(self, kvargs):
self.is_mtp_mode = self.args.mtp_mode in [
"vanilla_with_att",
"eagle_with_att",
"vanilla_no_att",
"eagle_no_att",
"eagle3"
]
self.is_mtp_draft_model = kvargs.get("is_mtp_draft_model", False)
self.is_eagle3_mode = "eagle3" in self.args.mtp_mode
if self.is_eagle3_mode and not self.is_mtp_draft_model:
self.eagle_hidden_layers = [1, self.config["n_layer"]//2-1, self.config["n_layer"]-4]
elif self.is_mtp_mode:
self.eagle_hidden_layers = [self.config["n_layer"]-1]
else:
self.eagle_hidden_layers = []

if self.is_eagle3_mode:
# load the hidden_proj weight from the draft model path
draft_model_path = kvargs.get("mtp_draft_model_dir", None)
assert draft_model_path is not None, "mtp_draft_model_dir must be provided when eagle3 mode is enabled"
if os.path.exists(os.path.join(draft_model_path[0], "pytorch_model.bin")):
self.draft_model_weight_dict = torch.load(os.path.join(draft_model_path[0], "pytorch_model.bin"))
self.hidden_proj_weight = self.draft_model_weight_dict["fc.weight"].to(torch.bfloat16).to("cuda")
del self.draft_model_weight_dict
gc.collect()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of torch.load with pytorch_model.bin can be slow and memory-intensive. Consider using safetensors for faster and safer model weight loading if possible.

else:
try:
from safetensors import safe_open
with safe_open(os.path.join(draft_model_path[0], "model.safetensors"), framework="pt", device="cuda") as f:
# Check if the key exists to avoid KeyError
if "fc.weight" in f.keys():
self.hidden_proj_weight = f.get_tensor("fc.weight").to(torch.bfloat16).to("cuda")
except Exception as e:
logger.warning(f"Failed to load hidden_proj_weight from safetensors with error: {e}")
self.hidden_proj_weight = None


def _wait_other_modules_ready(self):
for event in self.wait_events:
event.wait()
Expand All @@ -151,6 +184,9 @@ def _init_config(self):
repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
if self.finetune_config:
self.config["vocab_size"] = self.finetune_config.vocab_size
if "draft_vocab_size" in self.config.keys():
self.config["target_vocab_size"] = self.config["vocab_size"]
self.config["vocab_size"] = self.config["draft_vocab_size"]
return

@final
Expand Down Expand Up @@ -314,6 +350,16 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
if enable_diverse_mode_gqa_decode_fast_kernel():
infer_state.b_shared_seq_len = model_input.b_shared_seq_len
infer_state.b_mark_shared_group = model_input.b_mark_shared_group
elif enable_dynamic_mtp_verify() or enable_triton_mtp_kernel():
# 动态 MTP 验证模式下,也需要传递 b_mark_shared_group
infer_state.b_mark_shared_group = model_input.b_mark_shared_group
# 将b_mark_shared_group pad到跟input_ids一样的长度,避免后续使用时出现形状不匹配的问题
# infer_state.b_mark_shared_group = F.pad(
# infer_state.b_mark_shared_group,
# (0, infer_state.input_ids.shape[0] - infer_state.b_mark_shared_group.shape[0]),
# mode="constant",
# value=0,
# )

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The commented-out code block should be removed if it is no longer needed, or uncommented if it is intended to be part of the logic. Leaving dead code reduces maintainability.


infer_state.multimodal_params = model_input.multimodal_params

Expand Down Expand Up @@ -377,6 +423,11 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
new_model_input.b_mark_shared_group = F.pad(
new_model_input.b_mark_shared_group, (0, padded_batch_size), mode="constant", value=1
)
elif enable_dynamic_mtp_verify() or enable_triton_mtp_kernel():
if new_model_input.b_mark_shared_group is not None:
new_model_input.b_mark_shared_group = F.pad(
new_model_input.b_mark_shared_group, (0, padded_batch_size), mode="constant", value=0
)

# 特殊模型,特殊模式的特殊变量的特殊 padding
if new_model_input.mtp_draft_input_hiddens is not None:
Expand Down Expand Up @@ -561,12 +612,19 @@ def _context_forward(self, infer_state: InferStateInfo):
input_tensors = [input_embs]

def prefill_func(input_tensors, infer_state):
capture_hiddens = []
_input_embs = input_tensors[0]
for i in range(self.layers_num):
layer = self.layers_infer[i]
layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
_input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i])
return [_input_embs]
if i in self.eagle_hidden_layers:
capture_hiddens.append(_input_embs.clone())
capture_hidden = torch.cat(capture_hiddens, dim=-1) if len(capture_hiddens) > 0 else None
if self.is_eagle3_mode and not self.is_mtp_draft_model and capture_hidden is not None:
capture_hidden = self.hidden_proj_weight @ capture_hidden.transpose(0, 1)
capture_hidden = capture_hidden.transpose(0, 1)
return [_input_embs, capture_hidden]

handle_token_num = input_ids.shape[0]

Expand Down Expand Up @@ -597,7 +655,7 @@ def prefill_func(input_tensors, infer_state):

# 特殊模型特殊模式的额外输出
if self.is_mtp_mode:
model_output.mtp_main_output_hiddens = input_embs
model_output.mtp_main_output_hiddens = output_tensors[1]

# 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候
# 该调用没有实际意义
Expand All @@ -611,16 +669,23 @@ def _token_forward(self, infer_state: InferStateInfo):
cuda_input_ids = input_ids
pre_method = (self.pre_infer.token_forward, self.pre_infer.tpsp_token_forward)[run_mode_index]
input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
capture_hiddens = []
for i in range(self.layers_num):
layer = self.layers_infer[i]
layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index]
input_embs: torch.Tensor = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
if i in self.eagle_hidden_layers:
capture_hiddens.append(input_embs.clone())

capture_hidden = torch.cat(capture_hiddens, dim=-1) if len(capture_hiddens) > 0 else None
if self.is_eagle3_mode and not self.is_mtp_draft_model and capture_hidden is not None:
capture_hidden = self.hidden_proj_weight @ capture_hidden.transpose(0, 1)
capture_hidden = capture_hidden.transpose(0, 1)
post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
predict_logits: torch.Tensor = post_method(input_embs, infer_state, self.pre_post_weight)

if self.is_mtp_mode:
graph_out_hiddens = input_embs.contiguous()
graph_out_hiddens = capture_hidden.contiguous()

model_output = ModelOutput(logits=predict_logits.contiguous())

Expand Down Expand Up @@ -1027,6 +1092,7 @@ def _gen_special_model_input(self, token_num: int):
or "Qwen3MOEMTPModel" in str(self.__class__)
or "MistralMTPModel" in str(self.__class__)
or "Glm4MoeLiteMTPModel" in str(self.__class__)
or "Qwen3EagleModel" in str(self.__class__)
)
if is_mtp_draft_model:
special_model_input["mtp_draft_input_hiddens"] = torch.randn(
Expand All @@ -1036,3 +1102,22 @@ def _gen_special_model_input(self, token_num: int):
special_model_input["mtp_draft_input_hiddens"] = None

return special_model_input

"""
tensor([[ 1.9100e+02, 4.0400e+02, -4.2400e+02, ..., -3.3200e+02,
-2.9250e+01, -5.5000e+01],
[ 5.1875e+00, 9.5625e+00, 2.8516e-01, ..., 3.8906e+00,
4.0625e+00, -3.1562e+00],
[-4.7461e-01, -9.1250e+00, 9.0000e+00, ..., 6.0312e+00,
9.3750e+00, 6.4375e+00],
[-4.2500e+00, -3.8906e+00, -5.4297e-01, ..., 3.4844e+00,
3.1719e+00, 9.5625e+00]], device='cuda:0', dtype=torch.bfloat16)
tensor([[ 1.9100e+02, 4.0400e+02, -4.2400e+02, ..., -3.3200e+02,
-2.9250e+01, -5.5000e+01],
[ 5.1875e+00, 9.5625e+00, 2.8516e-01, ..., 3.8906e+00,
4.0625e+00, -3.1562e+00],
[-4.7461e-01, -9.1250e+00, 9.0000e+00, ..., 6.0312e+00,
9.3750e+00, 6.4375e+00],
[-4.2500e+00, -3.8906e+00, -5.4297e-01, ..., 3.4844e+00,
3.1719e+00, 9.5625e+00]], device='cuda:1', dtype=torch.bfloat16)
"""
21 changes: 14 additions & 7 deletions lightllm/common/basemodel/batch_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass, field
from typing import Optional
from typing import List
from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel
from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel, enable_dynamic_mtp_verify, enable_triton_mtp_kernel
from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor


Expand All @@ -14,6 +14,8 @@ class ModelInput:
# 在 decode 阶段, max_q_seq_len 必定是 1,
max_q_seq_len: int
max_kv_seq_len: int
# 用于记录原始请求数量,主要是为了动态MTP mode下保留原始请求数
original_num_reqs: int = None
max_cache_len: int = None
prefix_total_token_num: int = None
input_ids: torch.Tensor = None
Expand Down Expand Up @@ -59,19 +61,24 @@ def to_cuda(self):
self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
if self.b_prefill_start_loc is not None:
self.b_prefill_start_loc = self.b_prefill_start_loc.cuda(non_blocking=True)
if not self.is_prefill and enable_diverse_mode_gqa_decode_fast_kernel():
if not self.is_prefill and \
(enable_diverse_mode_gqa_decode_fast_kernel() or enable_dynamic_mtp_verify() or enable_triton_mtp_kernel()):
batch_size = len(self.b_req_idx)
if self.b_mark_shared_group is None:
self.b_mark_shared_group = torch.ones(size=(batch_size,), dtype=torch.int32, device="cuda")
self.b_mark_shared_group = torch.zeros(size=(batch_size,), dtype=torch.int32, device="cuda")
else:
self.b_mark_shared_group = self.b_mark_shared_group.cuda(non_blocking=True)
if self.b_shared_seq_len is None:
self.b_shared_seq_len = torch.zeros(size=(batch_size,), dtype=torch.int32, device="cuda")
else:
self.b_shared_seq_len = self.b_shared_seq_len.cuda(non_blocking=True)
# b_shared_seq_len 只在 diverse mode 下使用,动态 MTP mode 不需要
if enable_diverse_mode_gqa_decode_fast_kernel():
if self.b_shared_seq_len is None:
self.b_shared_seq_len = torch.zeros(size=(batch_size,), dtype=torch.int32, device="cuda")
else:
self.b_shared_seq_len = self.b_shared_seq_len.cuda(non_blocking=True)

def __post_init__(self):
self.check_input()
if self.original_num_reqs is None:
self.original_num_reqs = self.batch_size

def check_input(self):
assert len(self.multimodal_params) == self.batch_size
Expand Down
6 changes: 5 additions & 1 deletion lightllm/common/basemodel/infer_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(self):
# 在开启 mtp_mode 时,mtp draft model
# 的输入会用到,其他模型和场景都不会用到
self.mtp_draft_input_hiddens: Optional[torch.Tensor] = None
self.is_mtp_draft_model: bool = False

# 在单节点多dp的运行模式下,在进行prefill的阶段,如果出现了dp之间数据不平衡的现象,
# 可以将推理的数据,进行重新分配到各个dp,在做 att 之前,重新 all to all 到各自的
Expand Down Expand Up @@ -137,7 +138,10 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
if isinstance(attr_value, torch.Tensor):
attr_ = getattr(self, attr_name, None)
if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr():
attr_.copy_(attr_value, non_blocking=True)
try:
attr_.copy_(attr_value, non_blocking=True)
except Exception as e:
print(f"Warning: copy tensor {attr_name} failed during cuda graph copy, error: {e}")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using print for error logging is not recommended in production code. Use the project's logger to ensure errors are captured in the standard logging infrastructure.


self.decode_att_state.copy_for_decode_cuda_graph(new_infer_state.decode_att_state)
if self.decode_att_state1 is not None:
Expand Down
Loading
Loading