From f99afe6aa35c41e1bcacfae72df5987923898fde Mon Sep 17 00:00:00 2001 From: Siyu Wu Date: Thu, 18 Dec 2025 14:34:53 +0000 Subject: [PATCH] feat(misc): Profiler support use --enable_profiling=MODE to enable, currently supports torch_profiler and nvtx (use with NVIDIA Nsight Systems) modes --- lightllm/server/api_cli.py | 15 ++ lightllm/server/api_http.py | 18 ++ lightllm/server/api_start.py | 8 +- lightllm/server/core/objs/start_args_type.py | 5 + lightllm/server/httpserver/manager.py | 19 +- lightllm/server/router/manager.py | 22 +- .../model_infer/mode_backend/base_backend.py | 12 + lightllm/server/router/profiler_service.py | 52 ++++ lightllm/utils/profiler.py | 227 ++++++++++++++++++ skills/lightllm-profiler-control/SKILL.md | 46 ++++ 10 files changed, 419 insertions(+), 5 deletions(-) create mode 100644 lightllm/server/router/profiler_service.py create mode 100644 lightllm/utils/profiler.py create mode 100644 skills/lightllm-profiler-control/SKILL.md diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 7e40421140..1bdf8f3427 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -842,4 +842,19 @@ def make_argument_parser() -> argparse.ArgumentParser: If the op is not implemented for the platform and the hardware support triton, it will use triton implementation.""", ) + parser.add_argument( + "--enable_profiling", + type=str, + choices=["torch_profiler", "nvtx"], + default=None, + help="""Enable profiler support. + This will expose '/profiler_start' and '/profiler_stop' API, + below profiling features will only be enabled in this range. + Options: + 'torch_profiler': will setup torch.profiler.profile(), trace files will be saved to './trace', + or set by 'LIGHTLLM_TRACE_DIR' env; + 'nvtx': will add NVTX marks for external profiler like NVIDIA Nsight System + (you should set it up by yourself). 
+ A NVTX range named 'LIGHTLLM_PROFILE' will be added within the profiling range.""", + ) return parser diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index c6809fd2ad..270e2a8cfd 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -455,6 +455,24 @@ async def kv_move_status(websocket: WebSocket): return +@app.get("/profiler_start") +async def profiler_start() -> Response: + if g_objs.args.enable_profiling: + await g_objs.httpserver_manager.profiler_cmd("start") + return JSONResponse({"status": "ok"}) + else: + return JSONResponse({"message": "Profiling support not enabled"}, status_code=400) + + +@app.get("/profiler_stop") +async def profiler_stop() -> Response: + if g_objs.args.enable_profiling: + await g_objs.httpserver_manager.profiler_cmd("stop") + return JSONResponse({"status": "ok"}) + else: + return JSONResponse({"message": "Profiling support not enabled"}, status_code=400) + + @app.on_event("shutdown") async def shutdown(): logger.info("Received signal to shutdown. 
Performing graceful shutdown...") diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index aaaefb930d..3cf431d650 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -350,13 +350,14 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=9 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp, + num=10 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp, used_ports=already_uesd_ports, ) logger.info(f"alloced ports: {can_use_ports}") ( nccl_port, router_port, + router_profiler_port, detokenization_port, http_server_port, visual_port, @@ -364,8 +365,8 @@ def normal_or_p_d_start(args): cache_port, metric_port, multi_level_kv_cache_port, - ) = can_use_ports[0:9] - can_use_ports = can_use_ports[9:] + ) = can_use_ports[0:10] + can_use_ports = can_use_ports[10:] if args.visual_nccl_ports is None: args.visual_nccl_ports = can_use_ports[: args.visual_dp] @@ -383,6 +384,7 @@ def normal_or_p_d_start(args): if args.nccl_port is None: args.nccl_port = nccl_port args.router_port = router_port + args.router_profiler_port = router_profiler_port args.detokenization_port = detokenization_port args.http_server_port = http_server_port args.visual_port = visual_port diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index e7f35780a4..40c8028158 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -103,6 +103,10 @@ class StartArgs: use_reward_model: bool = field(default=False) use_tgi_api: bool = field(default=False) health_monitor: bool = field(default=False) + enable_profiling: Optional[str] = field( + default=None, + metadata={"choices": ["torch_profiler", "nvtx"]}, + ) metric_gateway: Optional[str] = field(default=None) job_name: str = field(default="lightllm") grouping_key: List[str] = 
field(default_factory=list) @@ -182,6 +186,7 @@ class StartArgs: enable_dp_prompt_cache_fetch: bool = field(default=False) # zmp ports router_port: int = field(default=None) + router_profiler_port: int = field(default=None) detokenization_port: int = field(default=None) http_server_port: int = field(default=None) visual_port: int = field(default=None) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index e47692d1b0..0f1b873111 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -13,7 +13,7 @@ from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -from typing import Union, List, Tuple, Dict, Optional, AsyncGenerator +from typing import Literal, Union, List, Tuple, Dict, Optional, AsyncGenerator from websockets import ClientConnection from fastapi import Request from ..tokenizer import get_tokenizer @@ -800,6 +800,23 @@ async def abort(self, group_req_id: int) -> bool: logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}") return True + def _get_router_profiler_client(self): + router_profiler_client = getattr(self, "router_profiler_client", None) + if router_profiler_client is None or getattr(router_profiler_client, "closed", False): + from lightllm.utils.retry_utils import retry + + self.router_profiler_client = retry(max_attempts=20, wait_time=0.5)(rpyc.connect)( + "localhost", + self.args.router_profiler_port, + config={"allow_pickle": True}, + ) + self.router_profiler_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return self.router_profiler_client + + async def profiler_cmd(self, cmd: Literal["start", "stop"]): + client = self._get_router_profiler_client() + client.root.profiler_cmd(cmd) + async def recycle_resource_loop(self): pre_time_mark = time.time() diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 1de4238a5c..dfb8866601 100644 --- 
a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -12,7 +12,7 @@ import torch.multiprocessing as mp import torch.distributed as dist import multiprocessing -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional from .batch import Batch, Req from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue @@ -26,6 +26,7 @@ from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.utils.log_utils import init_logger, log_time_ready +from lightllm.utils.profiler import ProfilerCmd from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient from lightllm.common.kv_cache_mem_manager import ReadOnlyStaticsMemoryManager @@ -34,6 +35,7 @@ from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt from .stats import RouterStatics +from .profiler_service import RouterProfilerCmdQueue, start_router_profiler_server logger = init_logger(__name__) @@ -102,6 +104,8 @@ def __init__(self, args: StartArgs): else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False) ) self.router_statics = RouterStatics(self.args) + self.profiler_cmd_queue = RouterProfilerCmdQueue() + return async def wait_to_model_ready(self): @@ -275,6 +279,7 @@ async def _step(self): """ # 接受新请求,并尝试调度 await self._recv_new_reqs_and_schedule() + await self._write_profiler_cmds() # 判断是否有新请求加入推理 # 激进调度满足,有新的推理batch就需要进行加入。 # 或者延迟step的步数满足了当前条件,也需要进行新的推理batch的加入。 @@ -303,6 +308,17 @@ async def _add_batch(self, batch: Batch): logger.debug(f"Prefill Batch: {batch.simple_log()} \n") return + async def _write_profiler_cmds(self): + cmd = self.profiler_cmd_queue.pop() + if cmd is None: + return + + while not self.shm_reqs_io_buffer.is_empty(): + await asyncio.sleep(0.001) + 
self.shm_reqs_io_buffer.write_obj([ProfilerCmd(cmd)]) + self.shm_reqs_io_buffer.set_ready() + return + async def _aborted_reqs(self, aborted_reqs: List[Req]): cmds = [AbortedReqCmd(req_id=r.request_id) for r in aborted_reqs] while not self.shm_reqs_io_buffer.is_empty(): @@ -537,6 +553,10 @@ def handle_exception(loop, context): ) loop.run_until_complete(router.wait_to_model_ready()) + router.profiler_rpyc_server, router.profiler_rpyc_thread = start_router_profiler_server( + args, + router.profiler_cmd_queue, + ) except: import traceback import sys diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 4323a62d1c..a65dfb1bbb 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -49,6 +49,7 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import PDChunckedTransTaskRet from .multi_level_kv_cache import MultiLevelKvCacheModule +from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd class ModeBackend: @@ -240,6 +241,10 @@ def init_model(self, kvargs): if self.args.enable_cpu_cache: self.multi_level_cache_module = MultiLevelKvCacheModule(self) + prof_name = f"lightllm-model_backend-node{self.node_rank}_dev{get_current_device_id()}" + prof_mode = self.args.enable_profiling + self.profiler = ProcessProfiler(mode=prof_mode, name=prof_name, use_multi_thread=True) if prof_mode else None + # 启动infer_loop_thread, 启动两个线程进行推理,对于具备双batch推理折叠得场景 # 可以降低 cpu overhead,大幅提升gpu得使用率。 self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True) @@ -363,6 +368,10 @@ def _try_read_new_reqs(self): self._try_read_new_reqs_multinode_tp() else: self._try_read_new_reqs_normal() + + # on each loop thread + if self.profiler is not None: + self.profiler.multi_thread_helper() return def 
_try_read_new_reqs_normal(self): @@ -428,6 +437,9 @@ def _read_reqs_buffer_and_init_reqs(self): if obj.req_id in g_infer_context.requests_mapping: req: InferReq = g_infer_context.requests_mapping[obj.req_id] req.infer_aborted = True + elif isinstance(obj, ProfilerCmd): + if self.profiler is not None: + self.profiler.cmd(obj) else: assert False, f"error type {type(obj)}" if init_reqs: diff --git a/lightllm/server/router/profiler_service.py b/lightllm/server/router/profiler_service.py new file mode 100644 index 0000000000..dd27d8d399 --- /dev/null +++ b/lightllm/server/router/profiler_service.py @@ -0,0 +1,52 @@ +import threading + +import rpyc + +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class RouterProfilerCmdQueue: + def __init__(self): + self.cmds = [] + self.lock = threading.Lock() + + def append(self, cmd: str): + with self.lock: + self.cmds.append(cmd) + return + + def pop(self): + with self.lock: + if not self.cmds: + return None + return self.cmds.pop(0) + + +class RouterProfilerService(rpyc.Service): + def __init__(self, profiler_cmd_queue: RouterProfilerCmdQueue): + super().__init__() + self.profiler_cmd_queue = profiler_cmd_queue + + def exposed_profiler_cmd(self, cmd: str): + self.profiler_cmd_queue.append(cmd) + return + + +def start_router_profiler_server(args, profiler_cmd_queue: RouterProfilerCmdQueue): + if not args.enable_profiling: + return None, None + + from rpyc.utils.server import ThreadedServer + import lightllm.utils.rpyc_fix_utils as _ + + server = ThreadedServer( + RouterProfilerService(profiler_cmd_queue), + port=args.router_profiler_port, + protocol_config={"allow_pickle": True}, + ) + thread = threading.Thread(target=server.start, daemon=True) + thread.start() + logger.info(f"router profiler rpyc server started on port {args.router_profiler_port}") + return server, thread diff --git a/lightllm/utils/profiler.py b/lightllm/utils/profiler.py new file mode 100644 index 0000000000..6ed23dcedf 
--- /dev/null +++ b/lightllm/utils/profiler.py @@ -0,0 +1,227 @@ +from dataclasses import dataclass +import os +import threading +import traceback +from typing import Any, Literal, Optional +import torch + +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +@dataclass +class ProfilerCmd: + cmd: Literal["start", "stop"] + + +def _get_thread_id() -> int: + # Get native thread ID (LWP) for correlation with system tools like htop/nsys + if hasattr(threading, "get_native_id"): + return threading.get_native_id() + return threading.get_ident() + + +class ProcessProfiler: + def __init__( + self, + mode: Literal["torch_profiler", "nvtx"], + name: Optional[str] = None, + use_multi_thread: bool = False, + torch_profiler_with_stack: bool = True, + ) -> None: + """ + Process Level Profiler Manager. + For multi-threading, set `use_multi_thread=True` + and call `.multi_thread_helper()` regularly in each worker thread. + """ + self.mode = mode + self.name = name or "unnamed" + self.use_multi_thread = use_multi_thread + self.torch_profiler_with_stack = torch_profiler_with_stack + + self.is_active: bool = False # Process-level logical state + self._threadlocal = threading.local() + + # make sure only one active torch.profiler per process + self._lock = threading.Lock() + self._process_torch_profiler_active_tid: int | None = None + + if self.mode == "torch_profiler": + self._trace_dir = os.getenv("LIGHTLLM_TRACE_DIR", "./trace") + os.makedirs(self._trace_dir, exist_ok=True) + elif self.mode == "nvtx": + self._nvtx_toplevel_mark = "LIGHTLLM_PROFILE" + else: + raise ValueError("invalid profiler mode") + + self._log_init_info() + + @property + def _local(self): + """Lazy initialization of thread-local storage.""" + if not hasattr(self._threadlocal, "initialized"): + self._threadlocal.initialized = True + self._threadlocal.is_active = False + self._threadlocal.profiler_obj = None + self._threadlocal.nvtx_range_id = None + return self._threadlocal + + 
def _log_init_info(self): + logger.warning("-" * 50) + logger.warning( + f"[pid={os.getpid()} tid={_get_thread_id()}] Profiler <{self.name}> initialized with mode: {self.mode}" + ) + if self.mode == "torch_profiler": + logger.warning( + "Profiler support for torch.profiler enabled (--enable_profiling=torch_profiler), " + "trace files will be saved to %s (change it with LIGHTLLM_TRACE_DIR env var)", + self._trace_dir, + ) + elif self.mode == "nvtx": + logger.warning( + "Profiler support for NVTX enabled (--enable_profiling=nvtx), toplevel NVTX mark is '%s'\n" + "you can use it with external profiling tools like NVIDIA Nsight Systems.", + self._nvtx_toplevel_mark, + ) + logger.warning( + "e.g. nsys profile --capture-range=nvtx --nvtx-capture=%s --trace=cuda,nvtx " + "-e NSYS_NVTX_PROFILER_REGISTER_ONLY=0 [other nsys options] " + "python -m lightllm.server.api_server --enable_profiling=nvtx [other lightllm options]", + self._nvtx_toplevel_mark, + ) + logger.warning("Use /profiler_start and /profiler_stop HTTP GET APIs to start/stop profiling") + logger.warning("DO NOT enable this feature in production environment") + logger.warning("-" * 50) + + def _torch_profiler_start(self) -> None: + with self._lock: + if self._process_torch_profiler_active_tid is not None: + return + self._process_torch_profiler_active_tid = _get_thread_id() + + torch.cuda.synchronize() + worker_name = f"{self.name}_tid{_get_thread_id()}" if self.use_multi_thread else self.name + + trace_handler = torch.profiler.tensorboard_trace_handler( + self._trace_dir, + worker_name=worker_name, + use_gzip=True, + ) + + p = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=None, + with_stack=self.torch_profiler_with_stack, + record_shapes=True, + on_trace_ready=trace_handler, + ) + + self._local.profiler_obj = p + p.start() + torch.cuda.synchronize() + + def _nvtx_start(self) -> None: + torch.cuda.synchronize() + 
self._local.nvtx_range_id = torch.cuda.nvtx.range_start(self._nvtx_toplevel_mark) + torch.cuda.synchronize() + + def _thread_start(self) -> None: + if self._local.is_active: + return + + try: + logger.info(f"[{self.name} @ tid={_get_thread_id()}] Start Profiler.") + if self.mode == "torch_profiler": + self._torch_profiler_start() + elif self.mode == "nvtx": + self._nvtx_start() + + self._local.is_active = True + except Exception as e: + logger.error( + f"[{self.name} @ tid={_get_thread_id()}] Failed to start profiler in thread {_get_thread_id()}: {e}" + ) + traceback.print_exc() + # Reset state on failure to prevent infinite retry loops + self._local.is_active = False + + def _torch_profiler_stop(self) -> None: + if self._process_torch_profiler_active_tid != _get_thread_id(): + return + + torch.cuda.synchronize() + logger.info(f"[{self.name} @ tid={_get_thread_id()}] Saving trace (blocking)...") + try: + if self._local.profiler_obj: + self._local.profiler_obj.stop() + except Exception as e: + logger.error(f"[{self.name} @ tid={_get_thread_id()}] Error stopping torch profiler: {e}") + finally: + self._local.profiler_obj = None # Explicitly release reference to allow GC + self._process_torch_profiler_active_tid = None + + torch.cuda.synchronize() + + def _nvtx_stop(self) -> None: + torch.cuda.synchronize() + if self._local.nvtx_range_id is not None: + torch.cuda.nvtx.range_end(self._local.nvtx_range_id) + self._local.nvtx_range_id = None + torch.cuda.synchronize() + + def _thread_stop(self) -> None: + if not self._local.is_active: + return + + try: + if self.mode == "torch_profiler": + self._torch_profiler_stop() + elif self.mode == "nvtx": + self._nvtx_stop() + logger.info(f"[{self.name} @ tid={_get_thread_id()}] Profiler stopped.") + except Exception as e: + logger.error(f"[{self.name} @ tid={_get_thread_id()}] Failed to stop profiler: {e}") + finally: + # Mark inactive regardless of success to avoid repeated errors + self._local.is_active = False + + def 
start(self) -> None: + self.is_active = True + if not self.use_multi_thread: + self._thread_start() + + def stop(self) -> None: + self.is_active = False + if not self.use_multi_thread: + self._thread_stop() + + def multi_thread_helper(self) -> None: + """ + **only for multi-threading use cases** + Worker polling method. Must be called within the inference loop. + """ + if not self.use_multi_thread: + return + + # Catch-all to prevent profiler errors from crashing inference logic + try: + local_active = self._local.is_active + + if self.is_active and not local_active: + self._thread_start() + elif not self.is_active and local_active: + self._thread_stop() + except Exception: + pass + + def cmd(self, cmd_obj: ProfilerCmd) -> None: + if cmd_obj.cmd == "start": + self.start() + elif cmd_obj.cmd == "stop": + self.stop() + else: + raise ValueError(f"Invalid profiler cmd: {cmd_obj.cmd}") diff --git a/skills/lightllm-profiler-control/SKILL.md b/skills/lightllm-profiler-control/SKILL.md new file mode 100644 index 0000000000..2832a628bd --- /dev/null +++ b/skills/lightllm-profiler-control/SKILL.md @@ -0,0 +1,46 @@ +--- +name: lightllm-profiler-control +description: LightLLM profiler 使用说明。用于需要启动或停止 LightLLM 的 torch_profiler / nvtx profiling 功能时,尤其是查看 --enable_profiling、/profiler_start、/profiler_stop 的使用方法。 +--- + +# LightLLM Profiler 使用说明 + +## 使用场景 + +当用户需要使用 LightLLM profiler 功能时使用本 skill,包括: + +- 启动服务时打开 profiler 能力。 +- 通过 HTTP 接口控制 profiler start / stop。 + +## 启动方式 + +服务启动时增加 `--enable_profiling`: + +```bash +python -m lightllm.server.api_server \ + --model_dir /path/to/model \ + --enable_profiling torch_profiler +``` + +支持值: + +- `torch_profiler`:启用 PyTorch profiler,trace 默认写入 `./trace`,也可通过 `LIGHTLLM_TRACE_DIR` 指定目录。 +- `nvtx`:启用 NVTX range,配合 NVIDIA Nsight Systems 等外部工具采集。 + +未设置 `--enable_profiling` 时,`/profiler_start` 和 `/profiler_stop` 会返回未启用提示。 + +## HTTP 控制接口 + +启动 profiler: + +```bash +curl http://127.0.0.1:8000/profiler_start +``` + +停止 profiler: + +```bash 
+curl http://127.0.0.1:8000/profiler_stop +``` + +请将示例中的端口 `8000` 替换为服务启动时通过 `--port` 指定的端口。