Skip to content
Merged
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04
ARG PYTHON_VERSION=3.10
ARG MAMBA_VERSION=24.7.1-0
ARG VLLM_VERSION=0.21.0
ARG NIXL_REF=v1.1.0
ARG NIXL_REF=v1.2.0
ARG FLASH_MLA_REF=47c35a7
ARG DEEPGEMM_REF=891d57b4db1071624b5c8fa0d1e51cb317fa709f
ARG TARGETPLATFORM
Expand Down
268 changes: 142 additions & 126 deletions lightllm/common/kv_trans_kernel/nixl_kv_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
@triton.jit
def _page_io(
mem_index_ptr,
token_num,
page_write_head_num,
k_page_ptr,
k_page_stride_size,
k_page_stride_layer_num,
Expand Down Expand Up @@ -45,88 +47,91 @@ def _page_io(
k_stride_size = tl.cast(k_stride_size, dtype=tl.int64)
v_stride_size = tl.cast(v_stride_size, dtype=tl.int64)

tid = tl.program_id(0)
kv_head_id = tl.program_id(1)
page_head_id = page_head_start + kv_head_id
start_index = tl.program_id(0)
grid_num = tl.num_programs(0)

mem_index = tl.load(mem_index_ptr + tid)
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
if NEED_MASK:
mask = off_dim < head_dim
else:
mask = None
for tid in tl.range(start_index, token_num, step=grid_num):
for kv_head_id in tl.range(page_write_head_num):

for layer_index in tl.range(layer_num, num_stages=3):
if IS_WRITE:
k_tensor = tl.load(
k_ptr
+ layer_index * k_stride_layer_num
+ mem_index * k_stride_size
+ kv_head_id * k_stride_head
+ off_dim * k_stride_dim,
mask=mask,
)
v_tensor = tl.load(
v_ptr
+ layer_index * v_stride_layer_num
+ mem_index * v_stride_size
+ kv_head_id * v_stride_head
+ off_dim * v_stride_dim,
mask=mask,
)
tl.store(
k_page_ptr
+ tid * k_page_stride_size
+ layer_index * k_page_stride_layer_num
+ page_head_id * k_page_stride_head
+ off_dim * k_page_stride_dim,
k_tensor,
mask=mask,
)
tl.store(
v_page_ptr
+ tid * v_page_stride_size
+ layer_index * v_page_stride_layer_num
+ page_head_id * v_page_stride_head
+ off_dim * v_page_stride_dim,
v_tensor,
mask=mask,
)
else:
k_page_tensor = tl.load(
k_page_ptr
+ tid * k_page_stride_size
+ layer_index * k_page_stride_layer_num
+ page_head_id * k_page_stride_head
+ off_dim * k_page_stride_dim,
mask=mask,
)
v_page_tensor = tl.load(
v_page_ptr
+ tid * v_page_stride_size
+ layer_index * v_page_stride_layer_num
+ page_head_id * v_page_stride_head
+ off_dim * v_page_stride_dim,
mask=mask,
)
tl.store(
k_ptr
+ layer_index * k_stride_layer_num
+ mem_index * k_stride_size
+ kv_head_id * k_stride_head
+ off_dim * k_stride_dim,
k_page_tensor,
mask=mask,
)
tl.store(
v_ptr
+ layer_index * v_stride_layer_num
+ mem_index * v_stride_size
+ kv_head_id * v_stride_head
+ off_dim * v_stride_dim,
v_page_tensor,
mask=mask,
)
page_head_id = page_head_start + kv_head_id
mem_index = tl.load(mem_index_ptr + tid)
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
if NEED_MASK:
mask = off_dim < head_dim
else:
mask = None

for layer_index in tl.range(layer_num, num_stages=3):
if IS_WRITE:
k_tensor = tl.load(
k_ptr
+ layer_index * k_stride_layer_num
+ mem_index * k_stride_size
+ kv_head_id * k_stride_head
+ off_dim,
mask=mask,
)
v_tensor = tl.load(
v_ptr
+ layer_index * v_stride_layer_num
+ mem_index * v_stride_size
+ kv_head_id * v_stride_head
+ off_dim,
mask=mask,
)
tl.store(
k_page_ptr
+ tid * k_page_stride_size
+ layer_index * k_page_stride_layer_num
+ page_head_id * k_page_stride_head
+ off_dim,
k_tensor,
mask=mask,
)
tl.store(
v_page_ptr
+ tid * v_page_stride_size
+ layer_index * v_page_stride_layer_num
+ page_head_id * v_page_stride_head
+ off_dim,
v_tensor,
mask=mask,
)
else:
k_page_tensor = tl.load(
k_page_ptr
+ tid * k_page_stride_size
+ layer_index * k_page_stride_layer_num
+ page_head_id * k_page_stride_head
+ off_dim,
mask=mask,
)
v_page_tensor = tl.load(
v_page_ptr
+ tid * v_page_stride_size
+ layer_index * v_page_stride_layer_num
+ page_head_id * v_page_stride_head
+ off_dim,
mask=mask,
)
tl.store(
k_ptr
+ layer_index * k_stride_layer_num
+ mem_index * k_stride_size
+ kv_head_id * k_stride_head
+ off_dim,
k_page_tensor,
mask=mask,
)
tl.store(
v_ptr
+ layer_index * v_stride_layer_num
+ mem_index * v_stride_size
+ kv_head_id * v_stride_head
+ off_dim,
v_page_tensor,
mask=mask,
)
return


Expand Down Expand Up @@ -169,10 +174,17 @@ def page_io(
page_head_start = tp_index * (page_write_head_num)

token_num = len(mem_indexes)
grid = (token_num, page_write_head_num)
grid = (128,)

assert k_page_tensor.stride(3) == 1
assert v_page_tensor.stride(3) == 1
assert k_buffer.stride(3) == 1
assert v_buffer.stride(3) == 1

_page_io[grid](
mem_index_ptr=mem_indexes,
token_num=token_num,
page_write_head_num=page_write_head_num,
k_page_ptr=k_page_tensor,
k_page_stride_size=k_page_tensor.stride(0),
k_page_stride_layer_num=k_page_tensor.stride(1),
Expand Down Expand Up @@ -207,6 +219,7 @@ def page_io(
@triton.jit
def _mla_page_io(
mem_index_ptr,
token_num,
page_ptr,
page_stride_size,
page_stride_layer_num,
Expand All @@ -227,52 +240,54 @@ def _mla_page_io(
kv_stride_layer_num = tl.cast(kv_stride_layer_num, dtype=tl.int64)
kv_stride_size = tl.cast(kv_stride_size, dtype=tl.int64)

tid = tl.program_id(0)
start_index = tl.program_id(0)
grid_num = tl.num_programs(0)

mem_index = tl.load(mem_index_ptr + tid)
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
if NEED_MASK:
mask = off_dim < head_dim
else:
mask = None

for layer_index in tl.range(layer_num, num_stages=3):
if IS_WRITE:
kv_tensor = tl.load(
kv_ptr
+ layer_index * kv_stride_layer_num
+ mem_index * kv_stride_size
+ 0 * kv_stride_head
+ off_dim * kv_stride_dim,
mask=mask,
)
tl.store(
page_ptr
+ tid * page_stride_size
+ layer_index * page_stride_layer_num
+ 0 * page_stride_head
+ off_dim * page_stride_dim,
kv_tensor,
mask=mask,
)
for tid in tl.range(start_index, token_num, step=grid_num):
mem_index = tl.load(mem_index_ptr + tid)
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
if NEED_MASK:
mask = off_dim < head_dim
else:
page_tensor = tl.load(
page_ptr
+ tid * page_stride_size
+ layer_index * page_stride_layer_num
+ 0 * page_stride_head
+ off_dim * page_stride_dim,
mask=mask,
)
tl.store(
kv_ptr
+ layer_index * kv_stride_layer_num
+ mem_index * kv_stride_size
+ 0 * kv_stride_head
+ off_dim * kv_stride_dim,
page_tensor,
mask=mask,
)
mask = None

for layer_index in tl.range(layer_num, num_stages=3):
if IS_WRITE:
kv_tensor = tl.load(
kv_ptr
+ layer_index * kv_stride_layer_num
+ mem_index * kv_stride_size
+ 0 * kv_stride_head
+ off_dim * kv_stride_dim,
mask=mask,
)
tl.store(
page_ptr
+ tid * page_stride_size
+ layer_index * page_stride_layer_num
+ 0 * page_stride_head
+ off_dim * page_stride_dim,
kv_tensor,
mask=mask,
)
else:
page_tensor = tl.load(
page_ptr
+ tid * page_stride_size
+ layer_index * page_stride_layer_num
+ 0 * page_stride_head
+ off_dim * page_stride_dim,
mask=mask,
)
tl.store(
kv_ptr
+ layer_index * kv_stride_layer_num
+ mem_index * kv_stride_size
+ 0 * kv_stride_head
+ off_dim * kv_stride_dim,
page_tensor,
mask=mask,
)
return


Expand All @@ -290,10 +305,11 @@ def mla_page_io(mem_indexes: torch.Tensor, page_tensor: torch.Tensor, kv_buffer:
assert page_head_num == kv_head_num == 1

token_num = len(mem_indexes)
grid = (token_num,)
grid = (64,)

_mla_page_io[grid](
mem_index_ptr=mem_indexes,
token_num=token_num,
page_ptr=page_tensor,
page_stride_size=page_tensor.stride(0),
page_stride_layer_num=page_tensor.stride(1),
Expand Down
4 changes: 2 additions & 2 deletions lightllm/server/httpserver/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,9 @@ async def generate(
pickle.dumps((ObjType.NIXL_UPLOAD_NP_PROMPT_IDS, group_request_id, prompt_ids))
)
try:
await asyncio.wait_for(nixl_pd_event.wait(), timeout=80)
await asyncio.wait_for(nixl_pd_event.wait(), timeout=180)
except asyncio.TimeoutError:
logger.error(f"nixl np node wait nixl_pd_event 36s time out, group_req_id {group_request_id}")
logger.error(f"nixl np node wait nixl_pd_event 180s time out, group_req_id {group_request_id}")
raise Exception(f"group_req_id {group_request_id} wait nixl_pd_event time out")

decode_node_info: NIXLDecodeNodeInfo = nixl_pd_event.decode_node_info
Expand Down
2 changes: 1 addition & 1 deletion lightllm/server/httpserver_for_pd_master/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ async def fetch_nixl_stream(
)

try:
await asyncio.wait_for(up_status_event.wait(), timeout=60)
await asyncio.wait_for(up_status_event.wait(), timeout=180)
except asyncio.TimeoutError:
logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.")
raise ServerBusyError()
Expand Down
5 changes: 4 additions & 1 deletion lightllm/server/pd_io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ class NIXLChunckedTransTask:
first_gen_token_id: Optional[int]
first_gen_token_logprob: Optional[float]

nixl_write_stage: Optional[str] = None

# transfer params
nixl_src_page_index: Optional[int] = None
nixl_dst_page_index: Optional[int] = None
Expand All @@ -284,6 +286,7 @@ class NIXLChunckedTransTask:
start_trans_time: float = None # 用于标记传输开始的时间。同时标记是否正在传输中

error_info: Optional[str] = None
transfer_time_out_secs: int = 66

def __post_init__(self):
if self.start_kv_index < 0 or self.end_kv_index < self.start_kv_index:
Expand All @@ -300,7 +303,7 @@ def time_out(self) -> bool:
return True
return False
else:
if time.time() - self.start_trans_time > self.time_out_secs + 88:
if time.time() - self.start_trans_time > self.transfer_time_out_secs:
return True
else:
return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,17 @@ def _create_nixl_trans_task(
):
# 确定传输设备
if req_obj.nixl_trans_device_id == -1:
if not hasattr(self, "nixl_iter_device_id"):
self.nixl_iter_device_id = 0
req_obj.nixl_trans_device_id = self.nixl_iter_device_id
# only self.is_master_in_dp will be used.
req_obj.nixl_trans_device_id = random.randint(0, self.node_world_size - 1)
self.nixl_iter_device_id = (self.nixl_iter_device_id + 1) % self.node_world_size

trans_task = NIXLChunckedTransTask(
request_id=req_obj.req_id,
start_kv_index=kv_start_index,
end_kv_index=kv_end_index,
time_out_secs=80,
time_out_secs=180,
pd_master_node_id=req_obj.sampling_param.pd_master_node_id,
prefill_dp_index=None,
decode_dp_index=self.dp_rank_in_node,
Expand Down
Loading
Loading