diff --git a/lightllm/common/basemodel/prefill_cuda_graph.py b/lightllm/common/basemodel/prefill_cuda_graph.py index 3c53a1b81..a70c5ce63 100644 --- a/lightllm/common/basemodel/prefill_cuda_graph.py +++ b/lightllm/common/basemodel/prefill_cuda_graph.py @@ -30,15 +30,21 @@ def __init__(self, decode_cuda_graph: CudaGraph, tp_world_size: int): self.args = get_env_start_args() self.enable_prefill_microbatch_overlap = self.args.enable_prefill_microbatch_overlap self.max_handle_token_num = self.args.prefill_cudagraph_max_handle_token - - graph_handle_token_nums = [] - for i in range(2048): - token_num = int(2 ** (2 * i)) - if 1 < token_num < self.max_handle_token_num: - graph_handle_token_nums.append(token_num) + if self.args.batch_max_tokens is not None: + self.max_handle_token_num = min(self.max_handle_token_num, self.args.batch_max_tokens) + + graph_handle_token_nums = ( + list(range(4, 33, 4)) + + list(range(48, 257, 16)) + + list(range(288, 513, 32)) + + list(range(576, 1024 + 1, 64)) + + list(range(1280, 4096 + 1, 256)) + + list(range(4608, self.max_handle_token_num + 1, 512)) + ) + graph_handle_token_nums = [e for e in graph_handle_token_nums if e <= self.max_handle_token_num] graph_handle_token_nums.append(self.max_handle_token_num) - graph_handle_token_nums = list(set(graph_handle_token_nums)) + graph_handle_token_nums = list(set[int](graph_handle_token_nums)) graph_handle_token_nums.sort() if self.args.enable_tpsp_mix_mode: graph_handle_token_nums = [ diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 2db6c67e7..f4b90b1d6 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -567,7 +567,10 @@ def make_argument_parser() -> argparse.ArgumentParser: " currently only for llama and qwen model, not support ep moe model", ) parser.add_argument( - "--prefill_cudagraph_max_handle_token", type=int, default=512, help="max handle token num for prefill cudagraph" + "--prefill_cudagraph_max_handle_token", + type=int, + default=8192, + help="max handle token num for prefill cudagraph", ) parser.add_argument( diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 6d0ee0746..13d00c0a6 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -126,7 +126,7 @@ class StartArgs: enable_monitor_auth: bool = field(default=False) disable_cudagraph: bool = field(default=False) enable_prefill_cudagraph: bool = field(default=False) - prefill_cudagraph_max_handle_token: int = field(default=512) + prefill_cudagraph_max_handle_token: int = field(default=8192) graph_max_batch_size: int = field(default=256) graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16)