Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lightllm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from lightllm.utils.device_utils import is_musa

if is_musa():
import torchada # noqa: F401
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def _fwd_kernel_token_att1(
).to(tl.int64)
off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
att_value = tl.sum(q[None, :] * k, 1, dtype=tl.float32)
att_value = tl.sum(q[None, :] * k, 1)
att_value = att_value.to(tl.float32)
att_value *= sm_scale
off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
Expand Down
16 changes: 13 additions & 3 deletions lightllm/server/api_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
import uuid

from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing import Any, Dict, List, Optional, Union, Literal, ClassVar
from transformers import GenerationConfig

Expand All @@ -21,6 +21,14 @@ class Message(BaseModel):
content: Union[str, List[MessageContent]]


class CharacterMessage(BaseModel):
    """Message format for character-based chat, where role is inferred from name."""

    # Name of the speaker; required.
    name: str
    # Message body: either plain text or a list of MessageContent parts.
    content: Union[str, List[MessageContent]]
    # Defaults to None when the client omits it.
    # NOTE(review): presumably resolved from the request's role_setting mapping
    # when absent — confirm against the prompt-building code.
    role: Optional[str] = None


class Function(BaseModel):
"""Function descriptions."""

Expand Down Expand Up @@ -105,7 +113,7 @@ def _normalize_role(cls, v):
raise ValueError("'role' must be a string")


ChatCompletionMessageParam = Union[ChatCompletionMessageGenericParam, Message]
ChatCompletionMessageParam = Union[ChatCompletionMessageGenericParam, Message, CharacterMessage]


class CompletionRequest(BaseModel):
Expand Down Expand Up @@ -176,6 +184,8 @@ def apply_loaded_defaults(cls, data: Any):


class ChatCompletionRequest(BaseModel):
model_config = ConfigDict(populate_by_name=True)

model: str
messages: List[ChatCompletionMessageParam]
function_call: Optional[str] = "none"
Expand Down Expand Up @@ -216,7 +226,7 @@ class ChatCompletionRequest(BaseModel):
top_k: Optional[int] = -1
repetition_penalty: Optional[float] = 1.0
ignore_eos: Optional[bool] = False
role_settings: Optional[Dict[str, str]] = None
role_settings: Optional[Dict[str, str]] = Field(default=None, alias="role_setting")
character_settings: Optional[List[Dict[str, str]]] = None

# Class variables to store loaded default values
Expand Down
3 changes: 2 additions & 1 deletion lightllm/server/api_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def _get_history_tool_calls_cnt(request: ChatCompletionRequest) -> int:
messages = getattr(request, "messages", [])
idx = 0
for msg in messages:
if msg.role == "assistant":
role = getattr(msg, "role", None)
if role == "assistant":
tool_calls = getattr(msg, "tool_calls", None)
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
return idx
Expand Down
6 changes: 5 additions & 1 deletion lightllm/server/build_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ async def build_prompt(request, tools) -> str:
global tokenizer
# pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别
messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages]
kwargs = {"conversation": messages}
kwargs = {
"conversation": messages,
# 假设 request 对象里有这个字段,或者你想传空
"system_instruction": getattr(request, "system_instruction", ""),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The code reads `request.system_instruction`, but the `ChatCompletionRequest` model in api_models.py does not define that field. As a result, `getattr` always falls back to its default `""`, and the new `system_instruction` parameter has no effect.

To properly implement this feature, you should add system_instruction as an optional field to the ChatCompletionRequest model in lightllm/server/api_models.py.

For example:

# In lightllm/server/api_models.py
class ChatCompletionRequest(BaseModel):
    # ...
    messages: List[ChatCompletionMessageParam]
    system_instruction: Optional[str] = None
    # ...

Additionally, the Chinese comment # 假设 request 对象里有这个字段,或者你想传空 is informal. It would be better to remove it once the feature is fully implemented, or replace it with a formal English comment explaining the purpose of system_instruction.

}
if request.character_settings:
kwargs["character_settings"] = request.character_settings
if request.role_settings:
Expand Down
31 changes: 25 additions & 6 deletions lightllm/utils/device_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,14 @@ def calcu_kernel_best_vsm_count(kernel, num_warps):
return num_sm


@lru_cache(maxsize=1)
def is_musa():
    """Return True when the installed torch build reports a MUSA version.

    The check looks at ``torch.version.musa``: the result is True only when
    that attribute exists and is not None. The answer is cached after the
    first call, so later changes to ``torch.version`` are not observed unless
    ``is_musa.cache_clear()`` is called.

    Returns:
        bool: whether ``torch.version.musa`` is present and not None.
    """
    # A single getattr with a default covers both "attribute missing" and
    # "attribute is None" without looking the attribute up twice.
    return getattr(torch.version, "musa", None) is not None


@lru_cache(maxsize=None)
def get_current_device_name():
import torch

if torch.cuda.is_available():
if torch.cuda.is_available() or is_musa():
device = torch.cuda.current_device()
gpu_name = torch.cuda.get_device_name(device)
# 4090 trans to 4090 D
Expand All @@ -103,8 +106,6 @@ def init_p2p(device_index):
"""
torch 调用跨卡的to操作后,triton编译的算子便能自动操作跨卡tensor。
"""
import torch

num_gpus = torch.cuda.device_count()
tensor = torch.zeros((1,))
tensor = tensor.to(f"cuda:{device_index}")
Expand All @@ -127,8 +128,26 @@ def has_nvlink():
result = result.decode("utf-8")
# Check if the output contains 'NVLink'
return any(f"NV{i}" in result for i in range(1, 8))
except FileNotFoundError:
# nvidia-smi is not installed, assume no NVLink
return False
except subprocess.CalledProcessError:
# If there's an error while executing nvidia-smi, assume no NVLink
return False


def has_mtlink():
    """Report whether ``mthreads-gmi`` shows an MTLink marker in the GPU topology.

    Runs ``mthreads-gmi topo --matrix`` and searches its UTF-8 decoded output
    for any of the substrings ``MT1`` .. ``MT7``.
    NOTE(review): these markers are presumed to denote MTLink connections,
    by analogy with the ``NV1`` .. ``NV7`` check used for NVLink — confirm
    against real ``mthreads-gmi`` output.

    Returns:
        bool: True when one of the markers is found. False when none is
        found, when ``mthreads-gmi`` is not installed, or when the command
        exits with a non-zero status.
    """
    try:
        # Only the subprocess call can raise the two handled exceptions.
        raw_topology = subprocess.check_output(["mthreads-gmi", "topo", "--matrix"])
    except (FileNotFoundError, subprocess.CalledProcessError):
        # Tool missing or failed to run: assume no MTLink.
        return False

    topology = raw_topology.decode("utf-8")
    return any(f"MT{i}" in topology for i in range(1, 8))


Expand Down