From af897ddc958a9b709879cbe63c3e6c0688774758 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 4 Jun 2026 12:49:09 +0000 Subject: [PATCH 1/2] prefill cudagraph capture size --- lightllm/common/basemodel/prefill_cuda_graph.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/lightllm/common/basemodel/prefill_cuda_graph.py b/lightllm/common/basemodel/prefill_cuda_graph.py index 3c53a1b81..6aa5a71cd 100644 --- a/lightllm/common/basemodel/prefill_cuda_graph.py +++ b/lightllm/common/basemodel/prefill_cuda_graph.py @@ -31,14 +31,18 @@ def __init__(self, decode_cuda_graph: CudaGraph, tp_world_size: int): self.enable_prefill_microbatch_overlap = self.args.enable_prefill_microbatch_overlap self.max_handle_token_num = self.args.prefill_cudagraph_max_handle_token - graph_handle_token_nums = [] - for i in range(2048): - token_num = int(2 ** (2 * i)) - if 1 < token_num < self.max_handle_token_num: - graph_handle_token_nums.append(token_num) + graph_handle_token_nums = ( + list(range(4, 33, 4)) + + list(range(48, 257, 16)) + + list(range(288, 513, 32)) + + list(range(576, 1024 + 1, 64)) + + list(range(1280, 4096 + 1, 256)) + + list(range(4608, self.max_handle_token_num + 1, 512)) + ) + graph_handle_token_nums = [e for e in graph_handle_token_nums if e <= self.max_handle_token_num] graph_handle_token_nums.append(self.max_handle_token_num) graph_handle_token_nums = list(set(graph_handle_token_nums)) graph_handle_token_nums.sort() if self.args.enable_tpsp_mix_mode: graph_handle_token_nums = [ From 76ee67b8447c5b819b25f25b14f2cd227c8e086c Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 4 Jun 2026 12:58:55 +0000 Subject: [PATCH 2/2] fix --- lightllm/common/basemodel/prefill_cuda_graph.py | 2 ++ lightllm/server/api_cli.py | 5 ++++- lightllm/server/core/objs/start_args_type.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff 
--git a/lightllm/common/basemodel/prefill_cuda_graph.py b/lightllm/common/basemodel/prefill_cuda_graph.py index 6aa5a71cd..a70c5ce63 100644 --- a/lightllm/common/basemodel/prefill_cuda_graph.py +++ b/lightllm/common/basemodel/prefill_cuda_graph.py @@ -30,6 +30,8 @@ def __init__(self, decode_cuda_graph: CudaGraph, tp_world_size: int): self.args = get_env_start_args() self.enable_prefill_microbatch_overlap = self.args.enable_prefill_microbatch_overlap self.max_handle_token_num = self.args.prefill_cudagraph_max_handle_token + if self.args.batch_max_tokens is not None: + self.max_handle_token_num = min(self.max_handle_token_num, self.args.batch_max_tokens) graph_handle_token_nums = ( list(range(4, 33, 4)) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 2db6c67e7..f4b90b1d6 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -567,7 +567,10 @@ def make_argument_parser() -> argparse.ArgumentParser: " currently only for llama and qwen model, not support ep moe model", ) parser.add_argument( - "--prefill_cudagraph_max_handle_token", type=int, default=512, help="max handle token num for prefill cudagraph" + "--prefill_cudagraph_max_handle_token", + type=int, + default=8192, + help="max handle token num for prefill cudagraph", ) parser.add_argument( diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 6d0ee0746..13d00c0a6 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -126,7 +126,7 @@ class StartArgs: enable_monitor_auth: bool = field(default=False) disable_cudagraph: bool = field(default=False) enable_prefill_cudagraph: bool = field(default=False) - prefill_cudagraph_max_handle_token: int = field(default=512) + prefill_cudagraph_max_handle_token: int = field(default=8192) graph_max_batch_size: int = field(default=256) graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = 
field(default=16)