22 changes: 18 additions & 4 deletions ep/bench/buffer.py
@@ -1,4 +1,5 @@
import os
import datetime
import torch
import torch.distributed as dist
from typing import Callable, Tuple, Optional, Union, List
@@ -97,6 +98,16 @@ def __init__(
rdma_buffer_is_host_allocated = bool(torch.version.cuda)

rdma_buffer_ptr = self.scratch.data_ptr()
obj_timeout_secs = int(
os.getenv(
"UCCL_OBJ_PG_TIMEOUT_SECS", os.getenv("UCCL_PG_TIMEOUT_SECS", "120")
)
)
self.object_group = dist.new_group(
list(range(dist.get_world_size(group))),
backend="gloo",
timeout=datetime.timedelta(seconds=obj_timeout_secs),
)
self.proxies, self.workers = initialize_uccl(
rdma_buffer_ptr,
num_rdma_bytes,
@@ -106,8 +117,9 @@ def __init__(
use_normal_mode=not low_latency_mode,
is_intranode=is_intranode,
rdma_buffer_is_host_allocated=rdma_buffer_is_host_allocated,
object_group=self.object_group,
)
check_nvlink_connections(group)
check_nvlink_connections(group, object_group=self.object_group)

# Initialize the CPP runtime
self.rank = group.rank()
@@ -135,14 +147,14 @@ def __init__(
] * self.group_size
local_device_id = self.runtime.get_local_device_id()
# print("Before all_gather_object device_ids", local_device_id, flush=True)
dist.all_gather_object(device_ids, local_device_id, group)
dist.all_gather_object(device_ids, local_device_id, self.object_group)
# Synchronize IPC handles
ipc_handles = [
None,
] * self.group_size
local_ipc_handle = self.runtime.get_local_ipc_handle()
# print("Before all_gather_object ipc_handles", local_ipc_handle, flush=True)
dist.all_gather_object(ipc_handles, local_ipc_handle, group)
dist.all_gather_object(ipc_handles, local_ipc_handle, self.object_group)

rdma_ipc_handles = [None] * self.group_size
# CUDA IPC only works with device memory; skip when using cudaMallocHost.
if self.num_rdma_bytes > 0 and not rdma_buffer_is_host_allocated
else None
)
dist.all_gather_object(rdma_ipc_handles, local_rdma_ipc_handle, group)
dist.all_gather_object(
rdma_ipc_handles, local_rdma_ipc_handle, self.object_group
)
root_unique_id = None
# Make CPP runtime available
self.runtime.sync(
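
Note on ep/bench/buffer.py: object collectives (dist.all_gather_object) now run on a dedicated gloo group with an explicit timeout rather than on the NCCL group. A minimal sketch of the pattern, assuming an already-initialized default process group; the helper name is illustrative, not part of the diff:

import os
import datetime
import torch.distributed as dist

def make_object_group(group):
    # UCCL_OBJ_PG_TIMEOUT_SECS overrides UCCL_PG_TIMEOUT_SECS; default 120 s.
    secs = int(
        os.getenv("UCCL_OBJ_PG_TIMEOUT_SECS", os.getenv("UCCL_PG_TIMEOUT_SECS", "120"))
    )
    # gloo carries CPU/Python-object collectives; NCCL stays on the GPU data path.
    return dist.new_group(
        list(range(dist.get_world_size(group))),
        backend="gloo",
        timeout=datetime.timedelta(seconds=secs),
    )

# Usage: object_group = make_object_group(group)
#        dist.all_gather_object(out, obj, group=object_group)
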
35 changes: 27 additions & 8 deletions ep/bench/utils.py
@@ -1,6 +1,7 @@
import inspect
from typing import Any, Optional, Tuple, Union
import os
import datetime
import torch
import torch.distributed as dist
from typing import Optional
@@ -74,13 +75,16 @@ def init_dist(local_rank: int, num_local_ranks: int):

def init_dist_under_torchrun(local_rank: int, num_local_ranks: int):
# torchrun already sets RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT
torch.cuda.set_device(local_rank)
timeout_secs = int(os.getenv("UCCL_PG_TIMEOUT_SECS", "120"))
dist.init_process_group(
backend="nccl", device_id=torch.device(f"cuda:{local_rank}")
backend="nccl",
device_id=torch.device(f"cuda:{local_rank}"),
timeout=datetime.timedelta(seconds=timeout_secs),
)

torch.set_default_dtype(torch.bfloat16)
torch.set_default_device(f"cuda:{local_rank}")
torch.cuda.set_device(local_rank)

return (
dist.get_rank(),
@@ -110,7 +114,9 @@ def get_peer_ip(rank: int, num_ranks: int, group: dist.ProcessGroup):
return peer_ip if peer_ip else ""


def get_cpu_proxies_meta(proxies, rank, scratch_ptr, scratch_bytes, num_ranks, group):
def get_cpu_proxies_meta(
proxies, rank, scratch_ptr, scratch_bytes, num_ranks, group, object_group=None
):
my_ip = ep.get_oob_ip()
meta = {
"rank": rank,
@@ -125,8 +131,9 @@ def get_cpu_proxies_meta(proxies, rank, scratch_ptr, scratch_bytes, num_ranks, g
device_index = int(os.environ["LOCAL_RANK"])
else:
device_index = torch.cuda.current_device()
torch.cuda.set_device(device_index)
dist.all_gather_object(all_meta, meta, group=group)
# torch.cuda.set_device(device_index)
collect_group = object_group if object_group is not None else group
dist.all_gather_object(all_meta, meta, group=collect_group)
rank2meta = {m["rank"]: m for m in all_meta}

# Debug: print IP distribution
@@ -142,7 +149,9 @@ def get_cpu_proxies_meta(proxies, rank, scratch_ptr, scratch_bytes, num_ranks, g
return rank2meta


def check_nvlink_connections(group: dist.ProcessGroup):
def check_nvlink_connections(
group: dist.ProcessGroup, object_group: Optional[dist.ProcessGroup] = None
):
"""
Check NVLink connection between every pair of GPUs.

@@ -170,7 +179,10 @@ def check_nvlink_connections(group: dist.ProcessGroup):
physical_device_indices = [
0,
] * group.size()
dist.all_gather_object(physical_device_indices, physical_device_idx, group)
collect_group = object_group if object_group is not None else group
dist.all_gather_object(
physical_device_indices, physical_device_idx, collect_group
)

# Check whether they are all connected via NVLink
# Reference: https://github.com/vllm-project/vllm/blob/b8e809a057765c574726a6077fd124db5077ce1f/vllm/platforms/cuda.py#L438
@@ -514,6 +526,7 @@ def initialize_uccl(
is_intranode=False,
use_normal_mode=False,
rdma_buffer_is_host_allocated=False,
object_group=None,
):
try:
for shm_file in glob.glob("/dev/shm/uccl_barrier_*"):
@@ -576,7 +589,13 @@
proxies.append(proxy)

rank2meta = get_cpu_proxies_meta(
proxies, rank, scratch_ptr, scratch_nbytes, num_ranks, group
proxies,
rank,
scratch_ptr,
scratch_nbytes,
num_ranks,
group,
object_group=object_group,
)
peers_meta_list = [rank2meta[r] for r in range(num_ranks)]

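
The recurring pattern in ep/bench/utils.py: every all_gather_object picks the gloo object group when one is supplied and falls back to the caller's group otherwise. A condensed sketch of that pattern; the helper name is illustrative:

import torch.distributed as dist

def gather_objects(obj, group, object_group=None):
    # Prefer the dedicated gloo group for Python-object collectives;
    # fall back to the main group when none was created.
    collect_group = object_group if object_group is not None else group
    out = [None] * dist.get_world_size(group)
    dist.all_gather_object(out, obj, group=collect_group)
    return out
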
35 changes: 29 additions & 6 deletions ep/include/ep_config.hpp
@@ -196,8 +196,21 @@ struct LowLatencyLayout {
size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16);

// Send buffer
size_t dispatch_send_buffer_bytes =
num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
// Buffer layout for RDMA sends, used by the batched RDMA-send path in the
// dispatch-LL kernel.
// clang-format off
// ┌──────────────────────────────────────────┬──────────────────────────────────────────────────────────┐
// │ Temp buffer (offset 0) │ Per-expert RDMA batch buffer (offset num_max_token) │
// │ rdma_x[token_idx] │ rdma_x[num_max_token + expert * num_max_token + slot] │
// │ Size: num_max_token * msg_size │ Size: num_experts * num_max_token * msg_size │
// └──────────────────────────────────────────┴──────────────────────────────────────────────────────────┘
// clang-format on
// Flow: (optional FP8 cast) -> temp buffer -> copy to per-expert batch
// buffer -> batched RDMA send
// TODO: Support per-GPU destination batching in this path.
size_t dispatch_send_buffer_bytes = (num_experts + 1) *
num_max_dispatch_tokens_per_rank *
num_bytes_per_dispatch_msg;
size_t combine_send_buffer_bytes = num_experts *
num_max_dispatch_tokens_per_rank *
num_bytes_per_combine_msg;
total_bytes += recv_buffer_bytes * 2;

// Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
// Dispatch-LL uses one count per (dst_rank, src_rank); combine uses one
// flag per expert. Both share the same signaling region, so size by max.
size_t dispatch_recv_count_buffer_bytes =
static_cast<size_t>(num_ranks * num_ranks) * sizeof(int);
size_t combine_recv_flag_buffer_bytes = num_experts * sizeof(int);
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes,
combine_recv_flag_buffer_bytes);
size_t signaling_buffer_bytes_aligned =
align<size_t>(signaling_buffer_bytes, 128);
total_bytes += signaling_buffer_bytes_aligned * 2;

// Internode signaling buffers (for RDMA atomics): use 64-bit slots.
size_t signaling_buffer_bytes_internode = num_experts * sizeof(int64_t);
// Dispatch count and combine flag internode buffers share this region.
size_t dispatch_recv_count_buffer_bytes_internode =
static_cast<size_t>(num_ranks * num_ranks) * sizeof(int64_t);
size_t combine_recv_flag_buffer_bytes_internode =
num_experts * sizeof(int64_t);
size_t signaling_buffer_bytes_internode = std::max(
dispatch_recv_count_buffer_bytes_internode,
combine_recv_flag_buffer_bytes_internode);
size_t signaling_buffer_bytes_internode_aligned =
align<size_t>(signaling_buffer_bytes_internode, 128);
// These internode signaling buffers live inside `atomic_buffer_ptr` (not
@@ -286,4 +309,4 @@ size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank,
NUM_BUFFER_ALIGNMENT_BYTES;
}

}  // namespace uccl
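
A worked example of the new sizing in LowLatencyLayout, with illustrative parameters; the formulas mirror the diff, the concrete numbers are only for intuition:

hidden = 7168
num_max_tokens = 128            # num_max_dispatch_tokens_per_rank
num_experts, num_ranks = 8, 4
msg_bytes = hidden * 2          # nv_bfloat16 payload, ignoring any headers/scales

# Temp buffer plus per-expert batch buffers => (num_experts + 1) slabs.
dispatch_send = (num_experts + 1) * num_max_tokens * msg_bytes
print(dispatch_send)            # 9 * 128 * 14336 = 16,515,072 bytes

# Shared signaling region is sized by the larger consumer: dispatch counts,
# one int per (dst_rank, src_rank) pair; combine flags, one per expert.
signaling = max(num_ranks * num_ranks, num_experts) * 4
aligned = (signaling + 127) // 128 * 128
print(signaling, aligned)       # 64 -> 128
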
32 changes: 31 additions & 1 deletion ep/include/ring_buffer.cuh
@@ -91,6 +91,36 @@ struct TransferCmd {
static_assert(sizeof(TransferCmd) * 8 == 128, "TransferCmd must be 128 bits");
#endif

// TransferCmd::bytes is 24 bits wide. For dispatch WRITE commands
// (non-combine), we borrow the top 2 bits of expert_idx to extend bytes to 26 bits.
constexpr uint32_t kTransferCmdBytesMask = (1u << 24) - 1;
constexpr uint16_t kTransferCmdBytesExtShift = 14;
constexpr uint16_t kTransferCmdBytesExtMask = (1u << 2) - 1;
constexpr uint16_t kTransferCmdExpertIdxMask = (1u << 14) - 1;

__host__ __device__ inline bool is_dispatch_write_cmd(
TransferCmd const& cmd) {
return get_base_cmd(cmd.cmd_type) == CmdType::WRITE && !get_is_combine(cmd.cmd_type);
}

__host__ __device__ inline uint32_t get_transfer_cmd_bytes(
TransferCmd const& cmd) {
uint32_t bytes = cmd.bytes;
if (is_dispatch_write_cmd(cmd)) {
bytes |=
(static_cast<uint32_t>(cmd.expert_idx >> kTransferCmdBytesExtShift)
<< 24);
}
return bytes;
}

__host__ __device__ inline uint16_t get_transfer_cmd_expert_idx(
TransferCmd const& cmd) {
if (is_dispatch_write_cmd(cmd))
return static_cast<uint16_t>(cmd.expert_idx & kTransferCmdExpertIdxMask);
return cmd.expert_idx;
}

struct CopyTask {
uint64_t wr_id;
int dst_dev;
@@ -461,4 +491,4 @@ static inline void free_cmd_ring(uintptr_t addr) {
}
}

#endif  // RING_BUFFER_CUH
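
A quick round-trip check of the bit-borrowing scheme, written as a Python stand-in for the device helpers above (constants copied from the header; the test values are arbitrary):

BYTES_MASK = (1 << 24) - 1      # kTransferCmdBytesMask
EXT_SHIFT = 14                  # kTransferCmdBytesExtShift
EXPERT_MASK = (1 << 14) - 1     # kTransferCmdExpertIdxMask

def pack(bytes_val, expert_idx):
    # Dispatch WRITE: bits 25:24 of bytes ride in expert_idx bits 15:14.
    assert bytes_val < (1 << 26) and expert_idx <= EXPERT_MASK
    hi2 = bytes_val >> 24
    return bytes_val & BYTES_MASK, (expert_idx & EXPERT_MASK) | (hi2 << EXT_SHIFT)

def unpack(cmd_bytes, cmd_expert_idx):
    bytes_val = cmd_bytes | ((cmd_expert_idx >> EXT_SHIFT) << 24)
    return bytes_val, cmd_expert_idx & EXPERT_MASK

b, e = pack(50_000_000, 1234)   # 50_000_000 needs 26 bits (> 2**24)
assert unpack(b, e) == (50_000_000, 1234)
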
58 changes: 48 additions & 10 deletions ep/include/uccl_ibgda.cuh
@@ -29,7 +29,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp(
int expert_idx, int lane_id, int message_idx,
uint64_t const* d2h_channel_addrs, int num_d2h_channel_addrs,
bool is_combine, int low_latency_buffer_idx = 0, uint64_t atomic_offset = 0,
uint64_t atomic_val = 0) {
uint64_t atomic_val = 0, int num_tokens = 1) {
// NOTE(MaoZiming): unlike nvshmemi_ibgda_put_nbi_warp in
// ibgda_device.cuh, this version does not use warp cooperation.
if (lane_id != 0) return;
@@ -60,13 +60,31 @@
make_cmd_type(CmdType::WRITE, is_combine, low_latency_buffer_idx);
cmd.req_rptr = rptr_val;
cmd.req_lptr = lptr_val;
cmd.bytes = bytes_val;
uint32_t cmd_bytes = static_cast<uint32_t>(bytes_val);
uint16_t cmd_expert_idx = static_cast<uint16_t>(expert_idx);
if constexpr (!use_normal_mode) {
if (!is_combine) {
EP_DEVICE_ASSERT((expert_idx & ~kTransferCmdExpertIdxMask) == 0);
EP_DEVICE_ASSERT((cmd_bytes >> 26) == 0);
auto bytes_hi2 = static_cast<uint16_t>(cmd_bytes >> 24);
cmd_expert_idx = static_cast<uint16_t>(
(expert_idx & kTransferCmdExpertIdxMask) |
(bytes_hi2 << kTransferCmdBytesExtShift));
cmd_bytes &= kTransferCmdBytesMask;
} else {
EP_DEVICE_ASSERT((cmd_bytes >> 24) == 0);
}
}
cmd.bytes = cmd_bytes;
cmd.dst_rank = dst_rank;
if constexpr (use_normal_mode) {
cmd.atomic_offset = atomic_offset;
cmd.atomic_val = atomic_val;
} else {
cmd.expert_idx = expert_idx;
cmd.expert_idx = cmd_expert_idx;
// Low-latency WRITE: use atomic_val byte for num_tokens (1..255).
EP_DEVICE_ASSERT(num_tokens > 0 && num_tokens <= 255);
cmd.atomic_val = static_cast<uint8_t>(num_tokens);
}
h->atomic_set_and_commit(cmd, &slot);
}
make_cmd_type(CmdType::WRITE, is_combine, low_latency_buffer_idx);
cmd.req_rptr = rptr_val;
cmd.req_lptr = lptr_val;
cmd.bytes = bytes_val;
uint32_t cmd_bytes = static_cast<uint32_t>(bytes_val);
uint16_t cmd_expert_idx = static_cast<uint16_t>(expert_idx);
if constexpr (!use_normal_mode) {
if (!is_combine) {
EP_DEVICE_ASSERT((expert_idx & ~kTransferCmdExpertIdxMask) == 0);
EP_DEVICE_ASSERT((cmd_bytes >> 26) == 0);
auto bytes_hi2 = static_cast<uint16_t>(cmd_bytes >> 24);
cmd_expert_idx = static_cast<uint16_t>(
(expert_idx & kTransferCmdExpertIdxMask) |
(bytes_hi2 << kTransferCmdBytesExtShift));
cmd_bytes &= kTransferCmdBytesMask;
} else {
EP_DEVICE_ASSERT((cmd_bytes >> 24) == 0);
}
}
cmd.bytes = cmd_bytes;
cmd.dst_rank = dst_rank;
if (bytes_val >> 24) {
printf("[nvshmemi_ibgda_put_nbi_warp] bytes too large: %llu\n",
(unsigned long long)bytes_val);
trap();
if constexpr (use_normal_mode) {
if (bytes_val >> 24) {
printf("[nvshmemi_ibgda_put_nbi_warp] bytes too large: %llu\n",
(unsigned long long)bytes_val);
trap();
}
}

if constexpr (use_normal_mode) {
cmd.atomic_offset = atomic_offset;
cmd.atomic_val = atomic_val;
} else {
cmd.expert_idx = expert_idx;
cmd.expert_idx = cmd_expert_idx;
// Low-latency WRITE: use atomic_val byte for num_tokens (1..255).
EP_DEVICE_ASSERT(num_tokens > 0 && num_tokens <= 255);
cmd.atomic_val = static_cast<uint8_t>(num_tokens);
}
h->atomic_set_and_commit(cmd, &slot);
break;
@@ -380,4 +418,4 @@ __forceinline__ __device__ void nvshmem_sync_with_same_gpu_idx(
}
}

}  // namespace uccl
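
In low-latency mode the WRITE command also reuses the atomic_val byte to carry the batched token count. A sketch of the resulting field encoding, building on pack() from the ring_buffer.cuh sketch above (the dict stands in for TransferCmd fields):

def encode_ll_dispatch_write(bytes_val, expert_idx, num_tokens):
    # Non-combine low-latency WRITE: 26-bit bytes, 14-bit expert_idx,
    # and 1..255 tokens per batched RDMA send.
    assert 0 < num_tokens <= 255
    cmd_bytes, cmd_expert = pack(bytes_val, expert_idx)
    return {"bytes": cmd_bytes, "expert_idx": cmd_expert, "atomic_val": num_tokens}
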
6 changes: 5 additions & 1 deletion ep/include/uccl_proxy.hpp
@@ -62,7 +62,11 @@ class UcclProxy {
num_experts * num_tokens * hidden * 2; // sizeof(bfloat16)
size_t send_buffer_bytes =
std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
size_t dispatch_recv_count_buffer_bytes = num_experts * 4;
size_t const signaling_slots = std::max(
static_cast<size_t>(num_experts),
static_cast<size_t>(proxy_->cfg_.num_ranks) *
static_cast<size_t>(proxy_->cfg_.num_ranks));
size_t dispatch_recv_count_buffer_bytes = signaling_slots * 4;
size_t signaling_buffer_bytes_aligned =
((dispatch_recv_count_buffer_bytes + 127) / 128) * 128;
uintptr_t dispatch_recv_data_offset =