diff --git a/ep/bench/buffer.py b/ep/bench/buffer.py index 14ef8cc34..8f4303482 100644 --- a/ep/bench/buffer.py +++ b/ep/bench/buffer.py @@ -1,4 +1,5 @@ import os +import datetime import torch import torch.distributed as dist from typing import Callable, Tuple, Optional, Union, List @@ -97,6 +98,16 @@ def __init__( rdma_buffer_is_host_allocated = bool(torch.version.cuda) rdma_buffer_ptr = self.scratch.data_ptr() + obj_timeout_secs = int( + os.getenv( + "UCCL_OBJ_PG_TIMEOUT_SECS", os.getenv("UCCL_PG_TIMEOUT_SECS", "120") + ) + ) + self.object_group = dist.new_group( + list(range(dist.get_world_size(group))), + backend="gloo", + timeout=datetime.timedelta(seconds=obj_timeout_secs), + ) self.proxies, self.workers = initialize_uccl( rdma_buffer_ptr, num_rdma_bytes, @@ -106,8 +117,9 @@ def __init__( use_normal_mode=not low_latency_mode, is_intranode=is_intranode, rdma_buffer_is_host_allocated=rdma_buffer_is_host_allocated, + object_group=self.object_group, ) - check_nvlink_connections(group) + check_nvlink_connections(group, object_group=self.object_group) # Initialize the CPP runtime self.rank = group.rank() @@ -135,14 +147,14 @@ def __init__( ] * self.group_size local_device_id = self.runtime.get_local_device_id() # print("Before all_gather_object device_ids", local_device_id, flush=True) - dist.all_gather_object(device_ids, local_device_id, group) + dist.all_gather_object(device_ids, local_device_id, self.object_group) # Synchronize IPC handles ipc_handles = [ None, ] * self.group_size local_ipc_handle = self.runtime.get_local_ipc_handle() # print("Before all_gather_object ipc_handles", local_ipc_handle, flush=True) - dist.all_gather_object(ipc_handles, local_ipc_handle, group) + dist.all_gather_object(ipc_handles, local_ipc_handle, self.object_group) rdma_ipc_handles = [None] * self.group_size # CUDA IPC only works with device memory; skip when using cudaMallocHost. 
@@ -151,7 +163,9 @@ def __init__( if self.num_rdma_bytes > 0 and not rdma_buffer_is_host_allocated else None ) - dist.all_gather_object(rdma_ipc_handles, local_rdma_ipc_handle, group) + dist.all_gather_object( + rdma_ipc_handles, local_rdma_ipc_handle, self.object_group + ) root_unique_id = None # Make CPP runtime available self.runtime.sync( diff --git a/ep/bench/utils.py b/ep/bench/utils.py index 1bcc7fee3..46ae81c1d 100644 --- a/ep/bench/utils.py +++ b/ep/bench/utils.py @@ -1,6 +1,7 @@ import inspect from typing import Any, Optional, Tuple, Union import os +import datetime import torch import torch.distributed as dist from typing import Optional @@ -74,13 +75,16 @@ def init_dist(local_rank: int, num_local_ranks: int): def init_dist_under_torchrun(local_rank: int, num_local_ranks: int): # torchrun already sets RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT + torch.cuda.set_device(local_rank) + timeout_secs = int(os.getenv("UCCL_PG_TIMEOUT_SECS", "120")) dist.init_process_group( - backend="nccl", device_id=torch.device(f"cuda:{local_rank}") + backend="nccl", + device_id=torch.device(f"cuda:{local_rank}"), + timeout=datetime.timedelta(seconds=timeout_secs), ) torch.set_default_dtype(torch.bfloat16) torch.set_default_device(f"cuda:{local_rank}") - torch.cuda.set_device(local_rank) return ( dist.get_rank(), @@ -110,7 +114,9 @@ def get_peer_ip(rank: int, num_ranks: int, group: dist.ProcessGroup): return peer_ip if peer_ip else "" -def get_cpu_proxies_meta(proxies, rank, scratch_ptr, scratch_bytes, num_ranks, group): +def get_cpu_proxies_meta( + proxies, rank, scratch_ptr, scratch_bytes, num_ranks, group, object_group=None +): my_ip = ep.get_oob_ip() meta = { "rank": rank, @@ -125,8 +131,9 @@ def get_cpu_proxies_meta(proxies, rank, scratch_ptr, scratch_bytes, num_ranks, g device_index = int(os.environ["LOCAL_RANK"]) else: device_index = torch.cuda.current_device() - torch.cuda.set_device(device_index) - dist.all_gather_object(all_meta, meta, group=group) + # torch.cuda.set_device(device_index) + collect_group = object_group if object_group is not None else group + dist.all_gather_object(all_meta, meta, group=collect_group) rank2meta = {m["rank"]: m for m in all_meta} # Debug: print IP distribution @@ -142,7 +149,9 @@ def get_cpu_proxies_meta(proxies, rank, scratch_ptr, scratch_bytes, num_ranks, g return rank2meta -def check_nvlink_connections(group: dist.ProcessGroup): +def check_nvlink_connections( + group: dist.ProcessGroup, object_group: Optional[dist.ProcessGroup] = None +): """ Check NVLink connection between every pair of GPUs. 
@@ -170,7 +179,10 @@ def check_nvlink_connections(group: dist.ProcessGroup):
     physical_device_indices = [
         0,
     ] * group.size()
-    dist.all_gather_object(physical_device_indices, physical_device_idx, group)
+    collect_group = object_group if object_group is not None else group
+    dist.all_gather_object(
+        physical_device_indices, physical_device_idx, collect_group
+    )
 
     # Check whether they are all connected via NVLink
     # Reference: https://github.com/vllm-project/vllm/blob/b8e809a057765c574726a6077fd124db5077ce1f/vllm/platforms/cuda.py#L438
@@ -514,6 +526,7 @@ def initialize_uccl(
     is_intranode=False,
     use_normal_mode=False,
     rdma_buffer_is_host_allocated=False,
+    object_group=None,
 ):
     try:
         for shm_file in glob.glob("/dev/shm/uccl_barrier_*"):
@@ -576,7 +589,13 @@ def initialize_uccl(
         proxies.append(proxy)
 
     rank2meta = get_cpu_proxies_meta(
-        proxies, rank, scratch_ptr, scratch_nbytes, num_ranks, group
+        proxies,
+        rank,
+        scratch_ptr,
+        scratch_nbytes,
+        num_ranks,
+        group,
+        object_group=object_group,
     )
     peers_meta_list = [rank2meta[r] for r in range(num_ranks)]
diff --git a/ep/include/ep_config.hpp b/ep/include/ep_config.hpp
index 451b3d27b..393ec5409 100644
--- a/ep/include/ep_config.hpp
+++ b/ep/include/ep_config.hpp
@@ -196,8 +196,21 @@ struct LowLatencyLayout {
   size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16);
 
   // Send buffer
-  size_t dispatch_send_buffer_bytes =
-      num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
+  // Buffer layout for RDMA sends, used by the batched RDMA-send path in the
+  // dispatch-LL kernel.
+  // clang-format off
+  // ┌──────────────────────────────────┬────────────────────────────────────────────────────────┐
+  // │ Temp buffer (offset 0)           │ Per-expert RDMA batch buffer (offset num_max_token)    │
+  // │ rdma_x[token_idx]                │ rdma_x[num_max_token + expert * num_max_token + slot]  │
+  // │ Size: num_max_token * msg_size   │ Size: num_experts * num_max_token * msg_size           │
+  // └──────────────────────────────────┴────────────────────────────────────────────────────────┘
+  // clang-format on
+  // Flow: (optional FP8 cast) -> temp buffer -> copy to per-expert batch
+  // buffer -> batched RDMA send
+  // TODO: Support per-GPU destination batching in this path.
+  size_t dispatch_send_buffer_bytes = (num_experts + 1) *
+                                      num_max_dispatch_tokens_per_rank *
+                                      num_bytes_per_dispatch_msg;
   size_t combine_send_buffer_bytes = num_experts *
                                      num_max_dispatch_tokens_per_rank *
                                      num_bytes_per_combine_msg;
@@ -220,8 +233,11 @@ struct LowLatencyLayout {
   total_bytes += recv_buffer_bytes * 2;
 
   // Symmetric signaling buffers
-  size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
-  size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
+  // Dispatch-LL uses one count per (dst_rank, src_rank); combine uses one
+  // flag per expert. Both share the same signaling region, so size by max.
+  size_t dispatch_recv_count_buffer_bytes =
+      static_cast<size_t>(num_ranks * num_ranks) * sizeof(int);
+  size_t combine_recv_flag_buffer_bytes = num_experts * sizeof(int);
   size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes,
                                            combine_recv_flag_buffer_bytes);
   size_t signaling_buffer_bytes_aligned =
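Editor's note, not part of the patch: the new sizing math above can be sanity-checked with a small standalone host-side sketch. All parameter values below (num_experts = 32, num_ranks = 8, 128 tokens per rank, hidden = 7168 with FP8 dispatch payloads) are illustrative assumptions, not values taken from this repo.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  std::size_t const num_experts = 32, num_ranks = 8;
  std::size_t const num_max_dispatch_tokens_per_rank = 128;
  std::size_t const hidden = 7168, num_scales = hidden / 128;
  // int4 metadata + FP8 payload + one float scale per 128 elements (assumed).
  std::size_t const num_bytes_per_dispatch_msg = 16 + hidden + num_scales * 4;

  // (num_experts + 1) slots: one temp region plus one batch region per expert.
  std::size_t const dispatch_send_buffer_bytes =
      (num_experts + 1) * num_max_dispatch_tokens_per_rank *
      num_bytes_per_dispatch_msg;

  // Dispatch counts are per (dst_rank, src_rank); combine flags per expert.
  std::size_t const dispatch_recv_count_bytes =
      num_ranks * num_ranks * sizeof(int);
  std::size_t const combine_recv_flag_bytes = num_experts * sizeof(int);
  std::size_t const signaling_bytes =
      std::max(dispatch_recv_count_bytes, combine_recv_flag_bytes);

  std::printf("send buffer: %zu B, signaling: %zu B\n",
              dispatch_send_buffer_bytes, signaling_bytes);
}
```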
@@ -229,7 +245,14 @@ struct LowLatencyLayout {
   total_bytes += signaling_buffer_bytes_aligned * 2;
 
   // Internode signaling buffers (for RDMA atomics): use 64-bit slots.
-  size_t signaling_buffer_bytes_internode = num_experts * sizeof(int64_t);
+  // Dispatch count and combine flag internode buffers share this region.
+  size_t dispatch_recv_count_buffer_bytes_internode =
+      static_cast<size_t>(num_ranks * num_ranks) * sizeof(int64_t);
+  size_t combine_recv_flag_buffer_bytes_internode =
+      num_experts * sizeof(int64_t);
+  size_t signaling_buffer_bytes_internode = std::max(
+      dispatch_recv_count_buffer_bytes_internode,
+      combine_recv_flag_buffer_bytes_internode);
   size_t signaling_buffer_bytes_internode_aligned =
       align(signaling_buffer_bytes_internode, 128);
   // These internode signaling buffers live inside `atomic_buffer_ptr` (not
@@ -286,4 +309,4 @@ size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank,
          NUM_BUFFER_ALIGNMENT_BYTES;
 }
 
-}  // namespace uccl
\ No newline at end of file
+}  // namespace uccl
diff --git a/ep/include/ring_buffer.cuh b/ep/include/ring_buffer.cuh
index 86f5cba00..0243ba5a9 100644
--- a/ep/include/ring_buffer.cuh
+++ b/ep/include/ring_buffer.cuh
@@ -91,6 +91,36 @@ struct TransferCmd {
 static_assert(sizeof(TransferCmd) * 8 == 128, "TransferCmd must be 128 bits");
 #endif
 
+// TransferCmd::bytes is 24-bit. For dispatch WRITE commands (non-combine), we
+// borrow the top 2 bits from expert_idx to extend bytes to 26-bit.
+constexpr uint32_t kTransferCmdBytesMask = (1u << 24) - 1;
+constexpr uint16_t kTransferCmdBytesExtShift = 14;
+constexpr uint16_t kTransferCmdBytesExtMask = (1u << 2) - 1;
+constexpr uint16_t kTransferCmdExpertIdxMask = (1u << 14) - 1;
+
+__host__ __device__ inline bool is_dispatch_write_cmd(TransferCmd const& cmd) {
+  return get_base_cmd(cmd.cmd_type) == CmdType::WRITE &&
+         !get_is_combine(cmd.cmd_type);
+}
+
+__host__ __device__ inline uint32_t get_transfer_cmd_bytes(
+    TransferCmd const& cmd) {
+  uint32_t bytes = cmd.bytes;
+  if (is_dispatch_write_cmd(cmd)) {
+    bytes |=
+        (static_cast<uint32_t>(cmd.expert_idx >> kTransferCmdBytesExtShift)
+         << 24);
+  }
+  return bytes;
+}
+
+__host__ __device__ inline uint16_t get_transfer_cmd_expert_idx(
+    TransferCmd const& cmd) {
+  if (is_dispatch_write_cmd(cmd))
+    return static_cast<uint16_t>(cmd.expert_idx & kTransferCmdExpertIdxMask);
+  return cmd.expert_idx;
+}
+
 struct CopyTask {
   uint64_t wr_id;
   int dst_dev;
@@ -461,4 +491,4 @@ static inline void free_cmd_ring(uintptr_t addr) {
   }
 }
 
-#endif  // RING_BUFFER_CUH
\ No newline at end of file
+#endif  // RING_BUFFER_CUH
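Editor's note, not part of the patch: a minimal standalone sketch of the 26-bit length encoding that `get_transfer_cmd_bytes` / `get_transfer_cmd_expert_idx` undo. The constants are restated locally and the sample values are arbitrary.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

constexpr uint32_t kBytesMask = (1u << 24) - 1;   // low 24 bits stay in bytes
constexpr uint16_t kExtShift = 14;                // top 2 bits of expert_idx
constexpr uint16_t kExpertMask = (1u << 14) - 1;  // expert ids use 14 bits

int main() {
  uint32_t const bytes = (1u << 25) + 4096;  // needs 26 bits
  uint16_t const expert_idx = 37;
  assert((bytes >> 26) == 0 && (expert_idx & ~kExpertMask) == 0);

  // Encode: stash bytes[25:24] in expert_idx[15:14].
  uint16_t const packed_expert = static_cast<uint16_t>(
      expert_idx | (static_cast<uint16_t>(bytes >> 24) << kExtShift));
  uint32_t const packed_bytes = bytes & kBytesMask;

  // Decode mirrors the two helpers in ring_buffer.cuh.
  uint32_t const decoded_bytes =
      packed_bytes |
      (static_cast<uint32_t>(packed_expert >> kExtShift) << 24);
  uint16_t const decoded_expert = packed_expert & kExpertMask;
  assert(decoded_bytes == bytes && decoded_expert == expert_idx);
  std::printf("bytes=%u expert=%u round-trips\n", decoded_bytes,
              decoded_expert);
}
```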
diff --git a/ep/include/uccl_ibgda.cuh b/ep/include/uccl_ibgda.cuh
index 41b48d1cd..85ff20307 100644
--- a/ep/include/uccl_ibgda.cuh
+++ b/ep/include/uccl_ibgda.cuh
@@ -29,7 +29,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp(
     int expert_idx, int lane_id, int message_idx,
     uint64_t const* d2h_channel_addrs, int num_d2h_channel_addrs,
     bool is_combine, int low_latency_buffer_idx = 0, uint64_t atomic_offset = 0,
-    uint64_t atomic_val = 0) {
+    uint64_t atomic_val = 0, int num_tokens = 1) {
   // NOTE(MaoZiming): different from the nvshmemi_ibgda_put_nbi_warp in
   // ibgda_device.cuh, we don't do warp-cooperation.
   if (lane_id != 0) return;
@@ -60,13 +60,31 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp(
       make_cmd_type(CmdType::WRITE, is_combine, low_latency_buffer_idx);
   cmd.req_rptr = rptr_val;
   cmd.req_lptr = lptr_val;
-  cmd.bytes = bytes_val;
+  uint32_t cmd_bytes = static_cast<uint32_t>(bytes_val);
+  uint16_t cmd_expert_idx = static_cast<uint16_t>(expert_idx);
+  if constexpr (!use_normal_mode) {
+    if (!is_combine) {
+      EP_DEVICE_ASSERT((expert_idx & ~kTransferCmdExpertIdxMask) == 0);
+      EP_DEVICE_ASSERT((cmd_bytes >> 26) == 0);
+      auto bytes_hi2 = static_cast<uint16_t>(cmd_bytes >> 24);
+      cmd_expert_idx = static_cast<uint16_t>(
+          (expert_idx & kTransferCmdExpertIdxMask) |
+          (bytes_hi2 << kTransferCmdBytesExtShift));
+      cmd_bytes &= kTransferCmdBytesMask;
+    } else {
+      EP_DEVICE_ASSERT((cmd_bytes >> 24) == 0);
+    }
+  }
+  cmd.bytes = cmd_bytes;
   cmd.dst_rank = dst_rank;
   if constexpr (use_normal_mode) {
     cmd.atomic_offset = atomic_offset;
     cmd.atomic_val = atomic_val;
   } else {
-    cmd.expert_idx = expert_idx;
+    cmd.expert_idx = cmd_expert_idx;
+    // Low-latency WRITE: use atomic_val byte for num_tokens (1..255).
+    EP_DEVICE_ASSERT(num_tokens > 0 && num_tokens <= 255);
+    cmd.atomic_val = static_cast<uint8_t>(num_tokens);
   }
   h->atomic_set_and_commit(cmd, &slot);
 }
@@ -91,12 +109,29 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp(
       make_cmd_type(CmdType::WRITE, is_combine, low_latency_buffer_idx);
   cmd.req_rptr = rptr_val;
   cmd.req_lptr = lptr_val;
-  cmd.bytes = bytes_val;
+  uint32_t cmd_bytes = static_cast<uint32_t>(bytes_val);
+  uint16_t cmd_expert_idx = static_cast<uint16_t>(expert_idx);
+  if constexpr (!use_normal_mode) {
+    if (!is_combine) {
+      EP_DEVICE_ASSERT((expert_idx & ~kTransferCmdExpertIdxMask) == 0);
+      EP_DEVICE_ASSERT((cmd_bytes >> 26) == 0);
+      auto bytes_hi2 = static_cast<uint16_t>(cmd_bytes >> 24);
+      cmd_expert_idx = static_cast<uint16_t>(
+          (expert_idx & kTransferCmdExpertIdxMask) |
+          (bytes_hi2 << kTransferCmdBytesExtShift));
+      cmd_bytes &= kTransferCmdBytesMask;
+    } else {
+      EP_DEVICE_ASSERT((cmd_bytes >> 24) == 0);
+    }
+  }
+  cmd.bytes = cmd_bytes;
   cmd.dst_rank = dst_rank;
-  if (bytes_val >> 24) {
-    printf("[nvshmemi_ibgda_put_nbi_warp] bytes too large: %llu\n",
-           (unsigned long long)bytes_val);
-    trap();
+  if constexpr (use_normal_mode) {
+    if (bytes_val >> 24) {
+      printf("[nvshmemi_ibgda_put_nbi_warp] bytes too large: %llu\n",
+             (unsigned long long)bytes_val);
+      trap();
+    }
   }
 
   if constexpr (use_normal_mode) {
@@ -114,7 +149,10 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp(
     cmd.atomic_offset = atomic_offset;
     cmd.atomic_val = atomic_val;
   } else {
-    cmd.expert_idx = expert_idx;
+    cmd.expert_idx = cmd_expert_idx;
+    // Low-latency WRITE: use atomic_val byte for num_tokens (1..255).
+    EP_DEVICE_ASSERT(num_tokens > 0 && num_tokens <= 255);
+    cmd.atomic_val = static_cast<uint8_t>(num_tokens);
   }
   h->atomic_set_and_commit(cmd, &slot);
   break;
@@ -380,4 +418,4 @@ __forceinline__ __device__ void nvshmem_sync_with_same_gpu_idx(
   }
 }
 
-}  // namespace uccl
\ No newline at end of file
+}  // namespace uccl
diff --git a/ep/include/uccl_proxy.hpp b/ep/include/uccl_proxy.hpp
index 0a8749b74..999989f88 100644
--- a/ep/include/uccl_proxy.hpp
+++ b/ep/include/uccl_proxy.hpp
@@ -62,7 +62,11 @@ class UcclProxy {
         num_experts * num_tokens * hidden * 2;  // sizeof(bfloat16)
     size_t send_buffer_bytes =
         std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
-    size_t dispatch_recv_count_buffer_bytes = num_experts * 4;
+    size_t const signaling_slots = std::max(
+        static_cast<size_t>(num_experts),
+        static_cast<size_t>(proxy_->cfg_.num_ranks) *
+            static_cast<size_t>(proxy_->cfg_.num_ranks));
+    size_t dispatch_recv_count_buffer_bytes = signaling_slots * 4;
     size_t signaling_buffer_bytes_aligned =
         ((dispatch_recv_count_buffer_bytes + 127) / 128) * 128;
     uintptr_t dispatch_recv_data_offset =
diff --git a/ep/src/internode_ll.cu b/ep/src/internode_ll.cu
index 375f076f9..fb30d2d19 100644
--- a/ep/src/internode_ll.cu
+++ b/ep/src/internode_ll.cu
@@ -18,6 +18,11 @@ constexpr int kNumMaxWarpGroups = 16;
 constexpr int kNumMaxWarpGroups = 32;
 #endif
 
+struct PackedDispatchExpertHeader {
+  int token_offset;
+  int token_count;
+};
+
 template <int kNumThreads>
 __launch_bounds__(kNumThreads, 1) __global__ void clean_low_latency_buffer(
     int* clean_0, int num_clean_int_0,
@@ -53,12 +58,12 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
     int64_t* dispatch_wait_recv_cost_stats, void* rdma_recv_x,
     int* rdma_recv_count, void* rdma_x, void const* x, int64_t const* topk_idx,
     int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
-    int* next_clean, int64_t* next_clean_second, int num_next_clean_int,
-    int num_tokens, int num_max_dispatch_tokens_per_rank, int num_topk,
-    int num_experts, int rank, int num_ranks, int num_warp_groups,
-    int num_warps_per_group, bool round_scale, int phases,
-    uint64_t const* d2h_channel_addrs, int num_d2h_channel_addrs,
-    int max_nvl_peers, int low_latency_buffer_idx,
+    int* atomic_send_counter_per_expert, int* next_clean,
+    int64_t* next_clean_second, int num_next_clean_int, int num_tokens,
+    int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts,
+    int rank, int num_ranks, int num_warp_groups, int num_warps_per_group,
+    bool round_scale, int phases, uint64_t const* d2h_channel_addrs,
+    int num_d2h_channel_addrs, int max_nvl_peers, int low_latency_buffer_idx,
     void** ipc_rdma_base_ptrs = nullptr, void* rdma_buffer_ptr = nullptr,
     void* atomic_buffer_ptr = nullptr,
     int64_t* rdma_recv_count_internode = nullptr,
@@ -71,7 +76,9 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
   auto const num_local_experts = num_experts / num_ranks;
   auto const warp_group_id = warp_id / num_warps_per_group;
   auto const sub_warp_id = warp_id % num_warps_per_group;
+  auto const responsible_rank = sm_id * num_warp_groups + warp_group_id;
   auto const responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
+  auto* rank_unique_write_cursor = grid_sync_barrier_ptr + 1;
 
   // May extract UE8M0 from the scales
   using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
@@ -87,16 +94,30 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
   size_t const hidden_int4 = hidden_bytes / sizeof(int4);
 
   // Message package: hidden data, FP8 scales, index at source
-  // NOTES: currently we have 3 reserved int fields for future use
+  // NOTES: metadata int4 layout:
+  //   [0] src_token_idx
+  //   [1] packed local expert list (low 32 bits, 5 bits per entry)
+  //   [2] packed local expert list (high 32 bits)
+  //   [3] number of local expert entries (duplicates preserved)
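Editor's note, not part of the patch: a standalone sketch of how the metadata int4 above packs and unpacks the per-token local expert list (5 bits per entry, split across two 32-bit words, duplicates preserved). The expert ids are arbitrary sample values.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  int meta[4] = {/*src_token_idx=*/42, 0, 0, 0};

  // Pack three entries: local experts 3, 17, 3 (a duplicate is kept).
  int const experts[3] = {3, 17, 3};
  uint64_t list = 0;
  for (int e = 0; e < 3; ++e)
    list |= static_cast<uint64_t>(experts[e]) << (e * 5);
  meta[1] = static_cast<int>(list);        // low 32 bits
  meta[2] = static_cast<int>(list >> 32);  // high 32 bits
  meta[3] = 3;                             // entry count

  // Decode, mirroring the receive path.
  uint64_t const decoded =
      (static_cast<uint64_t>(static_cast<uint32_t>(meta[2])) << 32) |
      static_cast<uint32_t>(meta[1]);
  for (int j = 0; j < meta[3]; ++j)
    std::printf("entry %d -> local expert %d\n", j,
                static_cast<int>((decoded >> (j * 5)) & 0x1f));
}
```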
   using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
   size_t const num_bytes_per_msg =
       sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float))
                               : (kHidden * sizeof(nv_bfloat16)));
-  size_t const num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
   EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
-
-  // Expert counts
-  __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
+  size_t const rank_header_bytes = align(
+      static_cast<size_t>(num_local_experts) *
+          sizeof(PackedDispatchExpertHeader),
+      sizeof(int4));
+  size_t const rank_payload_bytes =
+      static_cast<size_t>(num_local_experts) *
+      static_cast<size_t>(num_max_dispatch_tokens_per_rank) * num_bytes_per_msg;
+  // Per-rank layout: [src_rank][header + packed payload] on receiver.
+  size_t const rank_region_bytes = rank_header_bytes + rank_payload_bytes;
+
+  __shared__ int shared_num_tokens_to_send_per_rank[kNumMaxWarpGroups];
+
+  // Global counter slots used for batching sends to each top-k destination.
+  constexpr int kNumMaxTopK = 9;
 
 #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
   // initialize barrier
@@ -105,143 +126,20 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
   // Sending phase
   if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_DISPATCH_RECV;
+  EP_DEVICE_ASSERT(num_local_experts <= 32);
 
   // There are 2 kinds of warps in this part:
   // 1. The first-kind warps for FP8 cast and sending top-k tokens
   // 2. The last warp for reading `topk_idx` and count for per-expert
   // information
   if (warp_id < num_warps - 1) {
-    constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
-    EP_STATIC_ASSERT(kHidden % (WARP_SIZE * kNumElemsPerRead) == 0,
-                     "Invalid hidden");
-    EP_STATIC_ASSERT(kNumElemsPerRead * WARP_SIZE % kNumPerChannels == 0,
-                     "Invalid vectorization");
-    auto const num_threads = (num_warps - 1) * WARP_SIZE;
-    size_t const hidden_bf16_int4 = kHidden / kNumElemsPerRead;
 
     for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
-      auto const x_int4 =
-          static_cast<int4 const*>(x) + token_idx * hidden_bf16_int4;
-      auto const rdma_x_src_idx = reinterpret_cast<int*>(
-          static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
-      auto const rdma_x_vec = reinterpret_cast<vec_t*>(
-          reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
-      auto const rdma_x_scales = reinterpret_cast<float*>(
-          reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
-
-      // Overlap top-k index read and source token index writes
       auto dst_expert_idx =
          warp_id < num_topk
             ? static_cast<int>(
                   __ldg(topk_idx + token_idx * num_topk + warp_id))
             : -1;
-      thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
-
-      // FP8 cast
-#pragma unroll
-      for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
-        // Read
-        auto int4_value = __ldg(x_int4 + i);
-
-        if constexpr (kUseFP8) {
-          // Calculate local amax
-          auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
-          float fp32_values[kNumElemsPerRead];
-          float amax = kFP8Margin, scale, scale_inv;
-#pragma unroll
-          for (int j = 0; j < kNumElemsPerRead; ++j) {
-            fp32_values[j] = static_cast<float>(bf16_values[j]);
-            amax = fmaxf(amax, fabsf(fp32_values[j]));
-          }
-
-          // Reduce amax and scale
-#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
-          EP_STATIC_ASSERT(kNumElemsPerRead * WARP_SIZE / kNumPerChannels == 4,
-                           "Invalid vectorization");
-          amax = warp_reduce_max<16>(amax);
-          calculate_fp8_scales(amax, scale, scale_inv, round_scale);
-          if (lane_id % 16 == 0)
-#else
-          EP_STATIC_ASSERT(kNumElemsPerRead * WARP_SIZE / kNumPerChannels == 2,
-                           "Invalid vectorization");
-          amax = warp_reduce_max<16>(amax);
-          calculate_fp8_scales(amax, scale, scale_inv, round_scale);
-          if (lane_id == 0 or lane_id == 16)
-#endif
-            rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
-
-          // Cast into send buffer
-          vec_t int2_value;
-          auto fp8x2_values =
-              reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
-#pragma unroll
-          for (int j = 0; j < kNumElemsPerRead; j += 2) {
-            float2 fp32x2 = {fp32_values[j] * scale,
-                             fp32_values[j + 1] * scale};
-            fp8x2_values[j / 2] =
-                __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE,
-#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
-#if defined(__gfx942__)
-                                         __HIP_E4M3_FNUZ
-#else
-                                         __HIP_E4M3
-#endif
-#else
-                                         __NV_E4M3
-#endif
-                );
-          }
-          rdma_x_vec[i] = int2_value;
-        } else {
-          // Reinterpret-cast is for C++14 compatibility
-          rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
-        }
-      }
-      sync_barrier_1(num_threads);
-
-      // Issue IBGDA sends
-      if (dst_expert_idx >= 0) {
-        int slot_idx =
-            lane_id == 0
-                ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1)
-                : 0;
-        slot_idx = __shfl_sync(WARP_MASK, slot_idx, 0);
-        auto const dst_rank = dst_expert_idx / num_local_experts;
-        auto const dst_expert_local_idx = dst_expert_idx % num_local_experts;
-        auto const src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
-        auto const dst_ptr =
-            reinterpret_cast<uint64_t>(rdma_recv_x) +
-            dst_expert_local_idx * num_ranks *
-                num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
-            rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
-            slot_idx * num_bytes_per_msg;
-        auto const dst_p2p_ptr =
-            ipc_rdma_base_ptrs
-                ? uccl::get_ipc_p2p_ptr(dst_ptr, ipc_rdma_base_ptrs, rank,
-                                        dst_rank, max_nvl_peers, 0)
-                : 0;
-        if (dst_p2p_ptr == 0) {
-          __threadfence_system();
-          uccl::nvshmemi_ibgda_put_nbi_warp<false>(
-              dst_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
-              src_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
-              num_bytes_per_msg, dst_rank,
-              /*warp_id=*/dst_expert_local_idx,  // NOTE(Yang): for selecting
-                                                 // rb.
-              lane_id, slot_idx, d2h_channel_addrs, num_d2h_channel_addrs,
-              false, low_latency_buffer_idx);
-        } else {
-          // Intra-node: use direct memory copy via IPC
-          auto const* src_int4_ptr = reinterpret_cast<int4 const*>(src_ptr);
-          auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
-          UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr,
-                             src_int4_ptr, ld_nc_global, st_na_global);
-        }
-        // Increase counter after finishing
-        __syncwarp();
-        lane_id == 0
-            ? atomic_add_release_global(
-                  atomic_finish_counter_per_expert + dst_expert_idx, 1)
-            : 0;
+      if (warp_id < num_topk && lane_id == 0 && dst_expert_idx >= 0) {
+        atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1);
       }
     }
   } else if (warp_id == num_warps - 1) {
@@ -256,6 +154,9 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
     }
     // Notify before executing `int_p`
     __syncwarp();
+#pragma unroll
+    for (int i = lane_id; i < num_ranks; i += WARP_SIZE)
+      rank_unique_write_cursor[i] = 0;
 #pragma unroll
     for (int i = lane_id; i < num_experts; i += WARP_SIZE)
       atomic_add_release_global(atomic_finish_counter_per_expert + i,
@@ -281,35 +182,324 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
     for (int i = expert_begin_idx; i < expert_end_idx; ++i) {
       auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
       if (lane_id == 0) {
-        shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
         atomic_add_release_global(atomic_finish_counter_per_expert + i,
                                   FINISHED_SUM_TAG - sum);
       }
     }
   }
   __syncthreads();
 
-  // Issue count sends
-  if (responsible_expert_idx < num_experts and sub_warp_id == 0 and
-      lane_id == 0) {
-    auto const dst_rank = responsible_expert_idx / num_local_experts;
-    auto const dst_expert_local_idx =
-        responsible_expert_idx % num_local_experts;
-    auto const num_tokens_sent =
-        shared_num_tokens_sent_per_expert[responsible_expert_idx -
-                                          sm_id * num_warp_groups];
-    // Wait local sends issued and send expert counts
-    while (ld_acquire_global(atomic_finish_counter_per_expert +
-                             responsible_expert_idx) != FINISHED_SUM_TAG * 2)
+
+  // Grid-wide sync before batch-send.
 #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
-      __builtin_amdgcn_s_sleep(1);
+  amd::grid_sync(grid_sync_barrier_ptr, num_sms);
 #else
-      ;
+  cg::this_grid().sync();
 #endif
-    auto dst_ptr = reinterpret_cast<uint64_t>(
-        rdma_recv_count + dst_expert_local_idx * num_ranks + rank);
 
+  // Build per-rank headers and initialize packed write cursors.
+  if (responsible_rank < num_ranks && sub_warp_id == 0 && lane_id == 0) {
+    auto const batch_buf_offset =
+        num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
+    auto const rank_buf_base =
+        static_cast<uint8_t*>(rdma_x) + batch_buf_offset +
+        responsible_rank * rank_region_bytes;
+    auto* packed_header =
+        reinterpret_cast<PackedDispatchExpertHeader*>(rank_buf_base);
+    for (int e = 0; e < num_local_experts; ++e) {
+      auto const expert_idx = responsible_rank * num_local_experts + e;
+      auto const expert_tokens = atomic_counter_per_expert[expert_idx];
+      packed_header[e] = {0, expert_tokens};
+    }
+    // Count unique source tokens for this destination rank.
+    int unique_tokens = 0;
+    for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
+      bool rank_hit = false;
+      for (int k = 0; k < num_topk; ++k) {
+        auto const dst_expert_idx =
+            static_cast<int>(__ldg(topk_idx + token_idx * num_topk + k));
+        if (dst_expert_idx < 0) continue;
+        if (dst_expert_idx / num_local_experts == responsible_rank) {
+          rank_hit = true;
+          break;
+        }
+      }
+      unique_tokens += static_cast<int>(rank_hit);
+    }
+    // Store per-rank unique-token count in header[0].token_offset.
+    packed_header[0].token_offset = unique_tokens;
+    rank_unique_write_cursor[responsible_rank] = 0;
+  }
+  __syncwarp();
+
+  // Grid-wide sync before direct app->packed transport writes.
+#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
+  amd::grid_sync(grid_sync_barrier_ptr, num_sms);
+#else
+  cg::this_grid().sync();
+#endif
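Editor's note, not part of the patch: the dedup above means a token crosses the wire once per unique destination rank rather than once per routed expert. A minimal host-side sketch of that counting logic, with assumed values (num_local_experts = 8, top-4 routing); names are local to the example.

```cpp
#include <cstdio>

int main() {
  int const num_topk = 4, num_local_experts = 8;
  int const topk_idx[4] = {5, 12, 7, 12};  // global expert ids; 12 repeats

  int unique_ranks[4], unique_count = 0;
  for (int k = 0; k < num_topk; ++k) {
    int const dst_rank = topk_idx[k] / num_local_experts;
    bool seen = false;
    for (int u = 0; u < unique_count; ++u)
      seen = seen || (unique_ranks[u] == dst_rank);
    if (!seen) unique_ranks[unique_count++] = dst_rank;
  }
  // Experts 5 and 7 live on rank 0; expert 12 on rank 1 -> 2 unique ranks,
  // so this token is written into exactly 2 per-rank batch buffers.
  std::printf("unique destination ranks: %d\n", unique_count);
}
```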
+  // Pass-2: write tokens directly to packed transport payload.
+  if (warp_id < num_warps - 1) {
+    constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
+    EP_STATIC_ASSERT(kHidden % (WARP_SIZE * kNumElemsPerRead) == 0,
+                     "Invalid hidden");
+    EP_STATIC_ASSERT(kNumElemsPerRead * WARP_SIZE % kNumPerChannels == 0,
+                     "Invalid vectorization");
+    size_t const hidden_bf16_int4 = kHidden / kNumElemsPerRead;
+
+    auto const num_token_strides = num_sms * (num_warps - 1);
+    for (int token_idx = sm_id * (num_warps - 1) + warp_id;
+         token_idx < num_tokens; token_idx += num_token_strides) {
+      int unique_ranks[kNumMaxTopK];
+      uint64_t unique_local_expert_lists[kNumMaxTopK];
+      int unique_local_expert_counts[kNumMaxTopK];
+      int unique_rank_count = 0;
+
+      if (lane_id == 0) {
+#pragma unroll
+        for (int i = 0; i < kNumMaxTopK; ++i) {
+          unique_ranks[i] = -1;
+          unique_local_expert_lists[i] = 0;
+          unique_local_expert_counts[i] = 0;
+        }
+        for (int k = 0; k < num_topk; ++k) {
+          auto const dst_expert_idx =
+              static_cast<int>(__ldg(topk_idx + token_idx * num_topk + k));
+          if (dst_expert_idx < 0) continue;
+          auto const dst_rank = dst_expert_idx / num_local_experts;
+          auto const dst_local_expert = dst_expert_idx % num_local_experts;
+
+          int idx = -1;
+#pragma unroll
+          for (int u = 0; u < kNumMaxTopK; ++u) {
+            if (u < unique_rank_count && unique_ranks[u] == dst_rank) {
+              idx = u;
+              break;
+            }
+          }
+          if (idx == -1) {
+            idx = unique_rank_count++;
+            unique_ranks[idx] = dst_rank;
+            unique_local_expert_lists[idx] = 0;
+            unique_local_expert_counts[idx] = 0;
+          }
+          auto const entry_idx = unique_local_expert_counts[idx]++;
+          EP_DEVICE_ASSERT(entry_idx < kNumMaxTopK);
+          unique_local_expert_lists[idx] |=
+              (static_cast<uint64_t>(dst_local_expert) << (entry_idx * 5));
+        }
+      }
+
+      unique_rank_count = __shfl_sync(WARP_MASK, unique_rank_count, 0);
+
+      for (int u = 0; u < unique_rank_count; ++u) {
+        auto const dst_rank = __shfl_sync(WARP_MASK, unique_ranks[u], 0);
+        auto const dst_local_expert_list_lo = __shfl_sync(
+            WARP_MASK, static_cast<int>(unique_local_expert_lists[u]), 0);
+        auto const dst_local_expert_list_hi = __shfl_sync(
+            WARP_MASK, static_cast<int>(unique_local_expert_lists[u] >> 32),
+            0);
+        auto const dst_local_expert_count =
+            __shfl_sync(WARP_MASK, unique_local_expert_counts[u], 0);
+
+        int slot_idx = 0;
+        if (lane_id == 0)
+          slot_idx = atomicAdd(rank_unique_write_cursor + dst_rank, 1);
+        slot_idx = __shfl_sync(WARP_MASK, slot_idx, 0);
+
+        auto const batch_buf_offset =
+            num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
+        auto* const dst_msg_ptr =
+            static_cast<uint8_t*>(rdma_x) + batch_buf_offset +
+            dst_rank * rank_region_bytes + rank_header_bytes +
+            static_cast<size_t>(slot_idx) * num_bytes_per_msg;
+
+        auto* const dst_meta = reinterpret_cast<int*>(dst_msg_ptr);
+        auto* const dst_vec =
+            reinterpret_cast<vec_t*>(dst_msg_ptr + sizeof(int4));
+        auto* const dst_scales = reinterpret_cast<float*>(
+            dst_msg_ptr + sizeof(int4) + hidden_bytes);
+        auto const* x_int4 =
+            static_cast<int4 const*>(x) + token_idx * hidden_bf16_int4;
+
+        if (lane_id == 0) {
+          dst_meta[0] = token_idx;
+          dst_meta[1] = dst_local_expert_list_lo;
+          dst_meta[2] = dst_local_expert_list_hi;
+          dst_meta[3] = dst_local_expert_count;
+        }
+
+#pragma unroll
+        for (int i = lane_id; i < hidden_bf16_int4; i += WARP_SIZE) {
+          auto int4_value = __ldg(x_int4 + i);
+          if constexpr (kUseFP8) {
+            auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
+            float fp32_values[kNumElemsPerRead];
+            float amax = kFP8Margin, scale, scale_inv;
+#pragma unroll
+            for (int j = 0; j < kNumElemsPerRead; ++j) {
+              fp32_values[j] = static_cast<float>(bf16_values[j]);
+              amax = fmaxf(amax, fabsf(fp32_values[j]));
+            }
+#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
+            EP_STATIC_ASSERT(kNumElemsPerRead * WARP_SIZE / kNumPerChannels == 4,
+                             "Invalid vectorization");
+            amax = warp_reduce_max<16>(amax);
+            calculate_fp8_scales(amax, scale, scale_inv, round_scale);
+            if (lane_id % 16 == 0)
+#else
+            EP_STATIC_ASSERT(kNumElemsPerRead * WARP_SIZE / kNumPerChannels == 2,
+                             "Invalid vectorization");
+            amax = warp_reduce_max<16>(amax);
+            calculate_fp8_scales(amax, scale, scale_inv, round_scale);
+            if (lane_id == 0 or lane_id == 16)
+#endif
+              dst_scales[i * kNumElemsPerRead / 128] = scale_inv;
+
+            vec_t int2_value;
+            auto fp8x2_values =
+                reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
+#pragma unroll
+            for (int j = 0; j < kNumElemsPerRead; j += 2) {
+              float2 fp32x2 = {fp32_values[j] * scale,
+                               fp32_values[j + 1] * scale};
+              fp8x2_values[j / 2] =
+                  __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE,
+#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
+#if defined(__gfx942__)
+                                           __HIP_E4M3_FNUZ
+#else
+                                           __HIP_E4M3
+#endif
+#else
+                                           __NV_E4M3
+#endif
+                  );
+            }
+            dst_vec[i] = int2_value;
+          } else {
+            dst_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
+          }
+        }
+        __syncwarp();
+      }
+
+      if (lane_id == 0) {
+        for (int k = 0; k < num_topk; ++k) {
+          auto const dst_expert_idx =
+              static_cast<int>(__ldg(topk_idx + token_idx * num_topk + k));
+          if (dst_expert_idx >= 0)
+            atomic_add_release_global(
+                atomic_finish_counter_per_expert + dst_expert_idx, 1);
+        }
+      }
+    }
+  }
+  __syncthreads();
+
+  // Grid-wide sync before batch-send.
+#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
+  amd::grid_sync(grid_sync_barrier_ptr, num_sms);
+#else
+  cg::this_grid().sync();
+#endif
+
+  // Batch RDMA send phase - one put per destination rank (contiguous slice)
+  // Each warp group handles one rank (only first sub_warp does the send)
+  if (responsible_rank < num_ranks && sub_warp_id == 0) {
+    // Wait for all experts on this rank to finish copying to batch buffer.
+    int num_tokens_to_send = 0;
+    auto const batch_buf_offset =
+        num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
+    auto const rank_buf_base =
+        static_cast<uint8_t*>(rdma_x) + batch_buf_offset +
+        responsible_rank * rank_region_bytes;
+    auto const* packed_header =
+        reinterpret_cast<PackedDispatchExpertHeader const*>(rank_buf_base);
+    if (lane_id == 0) {
+      for (int e = 0; e < num_local_experts; ++e) {
+        int expert_idx = responsible_rank * num_local_experts + e;
+        // if (rank == 0)
+        //   printf(
+        //       "[dispatch wait enter] rank=%d responsible_rank=%d expert_local=%d "
+        //       "expert_idx=%d finish=%d\n",
+        //       rank, responsible_rank, e, expert_idx,
+        //       ld_acquire_global(atomic_finish_counter_per_expert + expert_idx));
+        while (ld_acquire_global(atomic_finish_counter_per_expert +
+                                 expert_idx) != FINISHED_SUM_TAG * 2)
+#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
+          __builtin_amdgcn_s_sleep(1);
+#else
+          ;
+#endif
+        // if (rank == 0)
+        //   printf(
+        //       "[dispatch wait exit] rank=%d responsible_rank=%d expert_local=%d "
+        //       "expert_idx=%d finish=%d\n",
+        //       rank, responsible_rank, e, expert_idx,
+        //       ld_acquire_global(atomic_finish_counter_per_expert + expert_idx));
+        num_tokens_to_send += (e == 0 ? packed_header[0].token_offset : 0);
+      }
+    }
+    num_tokens_to_send = __shfl_sync(WARP_MASK, num_tokens_to_send, 0);
+    if (lane_id == 0)
+      shared_num_tokens_to_send_per_rank[warp_group_id] = num_tokens_to_send;
+    __syncwarp();
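Editor's note, not part of the patch: a host-side sketch of the per-source-rank slice arithmetic used by the batch send and by the receiver. `align_up` is a local stand-in for the kernel's `align` helper, and the message size assumes a bf16 payload with hidden = 7168; all values are illustrative.

```cpp
#include <cstddef>
#include <cstdio>

struct PackedDispatchExpertHeader {
  int token_offset;
  int token_count;
};

static std::size_t align_up(std::size_t x, std::size_t a) {
  return (x + a - 1) / a * a;
}

int main() {
  std::size_t const num_local_experts = 8, tokens_per_rank = 128;
  std::size_t const num_bytes_per_msg = 16 + 7168 * 2;  // int4 meta + bf16

  // Each slice: per-expert headers (int4-aligned) followed by packed tokens.
  std::size_t const rank_header_bytes =
      align_up(num_local_experts * sizeof(PackedDispatchExpertHeader), 16);
  std::size_t const rank_region_bytes =
      rank_header_bytes + num_local_experts * tokens_per_rank * num_bytes_per_msg;

  // The receiver partitions rdma_recv_x by source rank, so sender `src_rank`
  // always writes at a fixed src_rank-indexed offset.
  int const src_rank = 3;
  std::printf("header=%zu B, region=%zu B, slice offset for src_rank %d = %zu B\n",
              rank_header_bytes, rank_region_bytes, src_rank,
              src_rank * rank_region_bytes);
}
```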
+    if (num_tokens_to_send > 0) {
+      // Receiver partitions by src_rank; we write to our (sender) region.
+      auto const dst_base = reinterpret_cast<uint64_t>(rdma_recv_x) +
+                            rank * rank_region_bytes;
+      auto const dst_p2p_ptr =
+          ipc_rdma_base_ptrs
+              ? uccl::get_ipc_p2p_ptr(dst_base, ipc_rdma_base_ptrs, rank,
+                                      responsible_rank, max_nvl_peers, 0)
+              : 0;
+
+      auto const total_bytes =
+          rank_header_bytes +
+          static_cast<size_t>(num_tokens_to_send) * num_bytes_per_msg;
+      constexpr size_t kMaxCmdBytes = (1u << 26) - 1;
+      EP_DEVICE_ASSERT(
+          total_bytes <= kMaxCmdBytes &&
+          "Low-latency dispatch rank slice exceeds command byte encoding");
+
+      __threadfence_system();
+
+      if (dst_p2p_ptr != 0) {
+        auto const* src_int4_ptr =
+            reinterpret_cast<int4 const*>(rank_buf_base);
+        auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
+        UNROLLED_WARP_COPY(8, lane_id, total_bytes / sizeof(int4),
+                           dst_int4_ptr, src_int4_ptr, ld_nc_global,
+                           st_na_global);
+      } else {
+        EP_DEVICE_ASSERT(num_tokens_to_send <= 255 &&
+                         "IBGDA low-latency path requires <=255 tokens");
+        uccl::nvshmemi_ibgda_put_nbi_warp<false>(
+            dst_base - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
+            reinterpret_cast<uint64_t>(rank_buf_base) -
+                reinterpret_cast<uint64_t>(rdma_buffer_ptr),
+            total_bytes, responsible_rank,
+            /*warp_id=*/rank, lane_id, /*slot=*/0,
+            d2h_channel_addrs, num_d2h_channel_addrs, false,
+            low_latency_buffer_idx, 0, 0, num_tokens_to_send);
+      }
+    }
+  }
+
+  __threadfence_system();  // Ensure batch sends are visible before count sends
+
+  // Issue count sends - one atomic per (src rank, dst rank)
+  if (responsible_rank < num_ranks and sub_warp_id == 0 and lane_id == 0) {
+    auto const dst_rank = responsible_rank;
+    auto const num_tokens_sent =
+        shared_num_tokens_to_send_per_rank[warp_group_id];
+
+    auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_count +
+                                              dst_rank * num_ranks + rank);
     auto dst_ptr_internode = reinterpret_cast<uint64_t>(
-        rdma_recv_count_internode + dst_expert_local_idx * num_ranks + rank);
+        rdma_recv_count_internode + dst_rank * num_ranks + rank);
     // Try to use IPC for intra-node atomic operations
     auto const dst_p2p_ptr =
         ipc_rdma_base_ptrs
@@ -321,20 +511,26 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
       uccl::nvshmemi_ibgda_amo_nonfetch_add(
          dst_ptr_internode, reinterpret_cast<uint64_t>(atomic_buffer_ptr),
          -num_tokens_sent - 1, dst_rank,
-          /*warp_id=*/dst_expert_local_idx,  // NOTE(Yang): for selecting rb.
-          false, d2h_channel_addrs, num_d2h_channel_addrs, false,
-          low_latency_buffer_idx);
+          /*warp_id=*/rank, false, d2h_channel_addrs, num_d2h_channel_addrs,
+          false, low_latency_buffer_idx);
     } else {
+      EP_DEVICE_ASSERT(dst_rank / max_nvl_peers == rank / max_nvl_peers &&
+                       "IPC path should only be used for intra-node "
+                       "communication");
       // Intra-node: use direct atomic operation
       st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr),
                             -num_tokens_sent - 1);
     }
 
-    // Clean workspace for next use
-    atomic_counter_per_expert[responsible_expert_idx] = 0;
-    atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
-
-    // Clean `packed_recv_count`
-    if (dst_rank == 0) packed_recv_count[dst_expert_local_idx] = 0;
+    // Clean workspace for next use (all experts for this dst_rank)
+    for (int e = 0; e < num_local_experts; ++e) {
+      int expert_idx = dst_rank * num_local_experts + e;
+      atomic_counter_per_expert[expert_idx] = 0;
+      atomic_finish_counter_per_expert[expert_idx] = 0;
+      atomic_send_counter_per_expert[expert_idx] = 0;
+    }
+    rank_unique_write_cursor[dst_rank] = 0;
+    if (dst_rank == 0) {
+      for (int e = 0; e < num_local_experts; ++e) packed_recv_count[e] = 0;
+    }
   }
   __syncwarp();
@@ -357,11 +553,11 @@ LOW_LATENCY_DISPATCH_RECV:
   if (responsible_expert_idx < num_experts) {
     auto const src_rank = responsible_expert_idx / num_local_experts;
     auto const local_expert_idx = responsible_expert_idx % num_local_experts;
-    auto const rdma_recv_x_uint8 =
-        static_cast<uint8_t*>(rdma_recv_x) +
-        local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank *
-            num_bytes_per_msg +
-        src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
+    auto const rank_recv_base_uint8 =
+        static_cast<uint8_t*>(rdma_recv_x) + src_rank * rank_region_bytes;
+    auto const rank_recv_header =
+        reinterpret_cast<PackedDispatchExpertHeader*>(rank_recv_base_uint8);
+    auto const rank_recv_payload_uint8 =
+        rank_recv_base_uint8 + rank_header_bytes;
     auto const recv_x_int4 =
         static_cast<int4*>(packed_recv_x) +
         local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank *
             hidden_int4;
@@ -379,12 +575,15 @@ LOW_LATENCY_DISPATCH_RECV:
 
     // Shared between sub-warps in warp groups
     __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups],
-        shared_recv_token_begin_idx[kNumMaxWarpGroups];
+        shared_recv_token_begin_idx[kNumMaxWarpGroups],
+        shared_num_unique_recv_tokens[kNumMaxWarpGroups],
+        shared_recv_write_cursor[kNumMaxWarpGroups];
     // Wait tokens to arrive
     // NOTES: using sub-warp 1 to overlap with sub-warp 0
     int num_recv_tokens_internode = 0, num_recv_tokens_ipc = 0,
-        num_recv_tokens = 0, recv_token_begin_idx = 0;
+        num_recv_tokens = 0, recv_token_begin_idx = 0,
+        num_unique_recv_tokens = 0;
 #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
     EP_DEVICE_ASSERT(num_warps_per_group > 1);
 #else
@@ -392,31 +591,46 @@ LOW_LATENCY_DISPATCH_RECV:
 #endif
     if (sub_warp_id == 1 and lane_id == 0) {
       auto start_time = clock64();
+      // if (rank == 0)
+      //   printf(
+      //       "[dispatch recv wait ipc enter] rank=%d src_rank=%d same_node=%d\n",
+      //       rank, src_rank, src_rank / max_nvl_peers == rank / max_nvl_peers);
       while ((src_rank / max_nvl_peers == rank / max_nvl_peers) &&
              (num_recv_tokens_ipc = ld_acquire_sys_global(
-                  rdma_recv_count + local_expert_idx * num_ranks + src_rank)) ==
-             0)
+                  rdma_recv_count + rank * num_ranks + src_rank)) == 0)
 #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
         __builtin_amdgcn_s_sleep(1);
 #else
         ;
 #endif
-
+      // if (rank == 0)
+      //   printf(
+      //       "[dispatch recv wait ipc exit] rank=%d src_rank=%d ipc_raw=%d\n",
+      //       rank, src_rank, num_recv_tokens_ipc);
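Editor's note, not part of the patch: the wait loops here rely on counts being written as -(num_tokens + 1), so a zero read unambiguously means "not arrived yet" even when zero tokens were sent. A tiny standalone sketch of that encode/decode:

```cpp
#include <cassert>
#include <cstdio>

int main() {
  int const num_tokens_sent = 0;        // even zero tokens must be signaled
  int const wire = -num_tokens_sent - 1;  // sender stores -1, never 0
  assert(wire != 0);                      // 0 is reserved for "pending"

  int const decoded = (wire != 0) ? -wire - 1 : 0;  // receiver decode
  assert(decoded == num_tokens_sent);
  std::printf("wire=%d decodes to %d tokens\n", wire, decoded);
}
```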
+      // if (rank == 0)
+      //   printf(
+      //       "[dispatch recv wait internode enter] rank=%d src_rank=%d diff_node=%d\n",
+      //       rank, src_rank, src_rank / max_nvl_peers != rank / max_nvl_peers);
       while ((src_rank / max_nvl_peers != rank / max_nvl_peers) &&
              (num_recv_tokens_internode = static_cast<int>(
                   ld_acquire_sys_global(reinterpret_cast<int64_t*>(
-                      rdma_recv_count_internode + local_expert_idx * num_ranks +
+                      rdma_recv_count_internode + rank * num_ranks +
                       src_rank)))) == 0)
 #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
         __builtin_amdgcn_s_sleep(1);
 #else
         ;
 #endif
-
+      // if (rank == 0)
+      //   printf(
+      //       "[dispatch recv wait internode exit] rank=%d src_rank=%d "
+      //       "internode_raw=%d\n",
+      //       rank, src_rank, num_recv_tokens_internode);
       if (src_rank / max_nvl_peers == rank / max_nvl_peers) {
         if (ld_acquire_sys_global(reinterpret_cast<int64_t*>(
-                rdma_recv_count_internode + local_expert_idx * num_ranks +
-                src_rank)) != 0) {
+                rdma_recv_count_internode + rank * num_ranks + src_rank)) !=
+            0) {
           printf(
               "Same node but rdma_recv_count_internode is not zero! src_rank: "
               "%d, rank: %d, max_nvl_peers: %d\n",
@@ -425,9 +639,8 @@ LOW_LATENCY_DISPATCH_RECV:
         }
       }
       if (src_rank / max_nvl_peers != rank / max_nvl_peers) {
-        if (ld_acquire_sys_global(rdma_recv_count +
-                                  local_expert_idx * num_ranks + src_rank) !=
-            0) {
+        if (ld_acquire_sys_global(rdma_recv_count + rank * num_ranks +
+                                  src_rank) != 0) {
           printf(
               "Different node but rdma_recv_count is not zero! src_rank: %d, "
              "rank: %d, max_nvl_peers: %d\n",
@@ -440,25 +653,58 @@ LOW_LATENCY_DISPATCH_RECV:
       num_recv_tokens_internode =
           num_recv_tokens_internode != 0 ? -num_recv_tokens_internode - 1 : 0;
       num_recv_tokens_ipc =
           num_recv_tokens_ipc != 0 ? -num_recv_tokens_ipc - 1 : 0;
+      auto const num_recv_tokens_total =
+          num_recv_tokens_internode + num_recv_tokens_ipc;
+
       // printf(
-      //     "num_recv_tokens_internode: %d, num_recv_tokens_ipc: %d, src_rank:"
-      //     "%d, rank: %d, max_nvl_peers: %d, responsible_expert_idx: %d,"
-      //     "num_experts: %d, num_local_experts: %d\n",
-      //     num_recv_tokens_internode, num_recv_tokens_ipc, src_rank, rank,
-      //     max_nvl_peers, responsible_expert_idx, num_experts,
-      //     num_local_experts);
-      num_recv_tokens = num_recv_tokens_internode + num_recv_tokens_ipc;
+      //     "[dispatch recv] rank=%d src_rank=%d expected ipc=%d internode=%d\n",
+      //     rank, src_rank, num_recv_tokens_ipc, num_recv_tokens_internode);
+
+      // Read per-expert packed [count] and per-rank unique-token count.
+      if (num_recv_tokens_total == 0) {
+        num_recv_tokens = 0;
+        num_unique_recv_tokens = 0;
+      } else {
+        auto const rank_unique_token_count =
+            ld_acquire_sys_global(reinterpret_cast<int*>(rank_recv_header));
+        auto const header_int_ptr =
+            reinterpret_cast<int*>(rank_recv_header + local_expert_idx);
+        num_recv_tokens = ld_acquire_sys_global(header_int_ptr + 1);
+        num_unique_recv_tokens = rank_unique_token_count;
+        EP_DEVICE_ASSERT(num_recv_tokens >= 0);
+        EP_DEVICE_ASSERT(num_unique_recv_tokens >= 0);
+        EP_DEVICE_ASSERT(num_unique_recv_tokens <= num_recv_tokens_total);
+      }
       recv_token_begin_idx =
           atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
       shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
       shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
+      shared_num_unique_recv_tokens[warp_group_id] = num_unique_recv_tokens;
+      shared_recv_write_cursor[warp_group_id] = recv_token_begin_idx;
       recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx);
+      auto const src_slice_offset = static_cast<int64_t>(
+          reinterpret_cast<uint8_t const*>(rank_recv_payload_uint8) -
+          static_cast<uint8_t*>(rdma_recv_x));
+      auto const dst_slice_ptr =
+          reinterpret_cast<uint8_t*>(recv_x_int4) +
+          static_cast<size_t>(recv_token_begin_idx) * hidden_int4 *
+              sizeof(int4);
+      auto const dst_slice_offset = static_cast<int64_t>(
+          dst_slice_ptr - static_cast<uint8_t*>(packed_recv_x));
+      // printf(
+      //     "[dispatch recv slice] rank=%d src_rank=%d expert=%d src_off=%lld "
+      //     "dst_off=%lld recv_tokens=%d unique_tokens=%d begin=%d\n",
+      //     rank, src_rank, local_expert_idx, src_slice_offset, dst_slice_offset,
+      //     num_recv_tokens, num_unique_recv_tokens, recv_token_begin_idx);
+
       // Add stats for diagnosis
       if (cumulative_local_expert_recv_stats != nullptr)
         atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx,
                   num_recv_tokens);
-      if (dispatch_wait_recv_cost_stats != nullptr)
+      if (dispatch_wait_recv_cost_stats != nullptr && local_expert_idx == 0)
         atomicAdd(reinterpret_cast<unsigned long long*>(
                       dispatch_wait_recv_cost_stats + src_rank),
                   wait_recv_cost);
@@ -471,57 +717,104 @@ LOW_LATENCY_DISPATCH_RECV:
 #endif
     num_recv_tokens = shared_num_recv_tokens[warp_group_id];
     recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
+    num_unique_recv_tokens = shared_num_unique_recv_tokens[warp_group_id];
 
     // Copy tokens
     EP_DEVICE_ASSERT(num_scales <= 64);
-    for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
-      // Copy source info
-      auto const src_src_idx =
-          reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
-      if (lane_id == 0)
-        recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
-      __syncwarp();
+    for (int i = sub_warp_id; i < num_unique_recv_tokens;
+         i += num_warps_per_group) {
+      auto const src_msg_ptr =
+          rank_recv_payload_uint8 + static_cast<size_t>(i) * num_bytes_per_msg;
+      auto const src_meta_ptr = reinterpret_cast<int const*>(src_msg_ptr);
+
+      int src_token_idx = 0;
+      int dst_local_expert_list_lo = 0;
+      int dst_local_expert_list_hi = 0;
+      int dst_local_expert_count = 0;
+      if (lane_id == 0) {
+        src_token_idx = ld_nc_global(src_meta_ptr);
+        dst_local_expert_list_lo = ld_nc_global(src_meta_ptr + 1);
+        dst_local_expert_list_hi = ld_nc_global(src_meta_ptr + 2);
+        dst_local_expert_count = ld_nc_global(src_meta_ptr + 3);
+      }
+      src_token_idx = __shfl_sync(WARP_MASK, src_token_idx, 0);
+      dst_local_expert_list_lo =
+          __shfl_sync(WARP_MASK, dst_local_expert_list_lo, 0);
+      dst_local_expert_list_hi =
+          __shfl_sync(WARP_MASK, dst_local_expert_list_hi, 0);
+      dst_local_expert_count =
+          __shfl_sync(WARP_MASK, dst_local_expert_count, 0);
+      EP_DEVICE_ASSERT(dst_local_expert_count >= 0 &&
+                       dst_local_expert_count <= kNumMaxTopK);
+
+      auto const dst_local_expert_list =
+          (static_cast<uint64_t>(
+               static_cast<uint32_t>(dst_local_expert_list_hi))
+           << 32) |
+          static_cast<uint32_t>(dst_local_expert_list_lo);
+
+      int local_expert_multiplicity = 0;
+      for (int j = 0; j < dst_local_expert_count; ++j) {
+        auto const expert_j =
+            static_cast<int>((dst_local_expert_list >> (j * 5)) & 0x1f);
+        local_expert_multiplicity += (expert_j == local_expert_idx);
+      }
+      if (local_expert_multiplicity == 0) continue;
+
+      for (int rep = 0; rep < local_expert_multiplicity; ++rep) {
+        int dst_token_idx = 0;
+        if (lane_id == 0)
+          dst_token_idx =
+              atomicAdd(shared_recv_write_cursor + warp_group_id, 1);
+        dst_token_idx = __shfl_sync(WARP_MASK, dst_token_idx, 0);
+
+        // Copy source info
+        auto const src_src_idx = const_cast<int*>(src_meta_ptr);
+        if (lane_id == 0) recv_src_info[dst_token_idx] = src_token_idx;
+        __syncwarp();
 
-      // Copy data
-      // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
-      auto const src_data = reinterpret_cast<int4*>(
-          reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
-      auto const dst_data =
-          recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
-      UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data,
-                         ld_nc_global, st_na_global);
-
-      // Copy scales
-      if constexpr (kUseFP8) {
-        // Equivalent CuTe layout:
-        // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack,
-        // (num_tokens * num_elems_per_pack, 1))
-        auto const src_scales = reinterpret_cast<float*>(
-            reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
-        auto const num_elems_per_pack =
-            static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
-        auto const token_idx = recv_token_begin_idx + i;
-        auto const token_stride = num_elems_per_pack;
-        auto const pack_stride =
-            num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
-        if (lane_id < num_scales) {
-          auto const pack_idx = lane_id / num_elems_per_pack;
-          auto const elem_idx = lane_id % num_elems_per_pack;
-          auto scale = extract_required_scale_format<kUseUE8M0>(
-              ld_nc_global(src_scales + lane_id));
-          recv_x_scales[token_idx * token_stride + pack_idx * pack_stride +
-                        elem_idx] = scale;
-        }
-        if (lane_id + WARP_SIZE < num_scales) {
-          auto const pack_idx = (lane_id + WARP_SIZE) / num_elems_per_pack;
-          auto const elem_idx = (lane_id + WARP_SIZE) % num_elems_per_pack;
-          auto scale = extract_required_scale_format<kUseUE8M0>(
-              ld_nc_global(src_scales + lane_id + WARP_SIZE));
-          recv_x_scales[token_idx * token_stride + pack_idx * pack_stride +
-                        elem_idx] = scale;
+        // Copy data
+        // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
+        auto const src_data = reinterpret_cast<int4*>(
+            reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
+        auto const dst_data = recv_x_int4 + dst_token_idx * hidden_int4;
+        UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data,
+                           ld_nc_global, st_na_global);
+
+        // Copy scales
+        if constexpr (kUseFP8) {
+          // Equivalent CuTe layout:
+          // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack,
+          // (num_tokens * num_elems_per_pack, 1))
+          auto const src_scales = reinterpret_cast<float*>(
+              reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
+          auto const num_elems_per_pack =
+              static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
+          auto const token_idx = dst_token_idx;
+          auto const token_stride = num_elems_per_pack;
+          auto const pack_stride =
+              num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
+          if (lane_id < num_scales) {
+            auto const pack_idx = lane_id / num_elems_per_pack;
+            auto const elem_idx = lane_id % num_elems_per_pack;
+            auto scale = extract_required_scale_format<kUseUE8M0>(
+                ld_nc_global(src_scales + lane_id));
+            recv_x_scales[token_idx * token_stride + pack_idx * pack_stride +
+                          elem_idx] = scale;
+          }
+          if (lane_id + WARP_SIZE < num_scales) {
+            auto const pack_idx = (lane_id + WARP_SIZE) / num_elems_per_pack;
+            auto const elem_idx = (lane_id + WARP_SIZE) % num_elems_per_pack;
+            auto scale = extract_required_scale_format<kUseUE8M0>(
+                ld_nc_global(src_scales + lane_id + WARP_SIZE));
+            recv_x_scales[token_idx * token_stride + pack_idx * pack_stride +
+                          elem_idx] = scale;
+          }
+        }
+      }
     }
+    if (sub_warp_id == 1 && lane_id == 0)
+      EP_DEVICE_ASSERT(shared_recv_write_cursor[warp_group_id] ==
+                       recv_token_begin_idx + num_recv_tokens);
     // if (blockIdx.x == 0 && threadIdx.x == 0)
     //   printf("[dispatch] RECV finished\n");
   }
@@ -551,37 +844,65 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
   auto const num_warps = num_warp_groups * num_warps_per_group;
   auto const num_sms = ceil_div(num_experts, num_warp_groups);
   EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
+  // Low-latency dispatch encodes write size in an extended 26-bit field.
+  // Guard here to fail fast for oversized rank slices.
+  {
+    auto const num_local_experts = num_experts / num_ranks;
+    EP_HOST_ASSERT(num_local_experts <= 32 &&
+                   "dispatch dedup metadata packs 5-bit local expert ids");
+    auto const num_scales = hidden / 128;
+    size_t const num_bytes_per_msg =
+        sizeof(int4) + (use_fp8 ? (hidden + num_scales * sizeof(float))
+                                : (hidden * sizeof(nv_bfloat16)));
+    size_t const rank_header_bytes = align(
+        static_cast<size_t>(num_local_experts) *
+            sizeof(PackedDispatchExpertHeader),
+        sizeof(int4));
+    size_t const rank_region_bytes =
+        rank_header_bytes +
+        static_cast<size_t>(num_local_experts) *
+            static_cast<size_t>(num_max_dispatch_tokens_per_rank) *
+            num_bytes_per_msg;
+    constexpr size_t kMaxCmdBytes = (1u << 26) - 1;
+    EP_HOST_ASSERT(
+        rank_region_bytes <= kMaxCmdBytes &&
+        "Low-latency dispatch rank slice exceeds command byte encoding");
+  }
 
   // Workspace checks
   auto atomic_counter_per_expert = static_cast<int*>(workspace);
   auto atomic_finish_counter_per_expert =
       atomic_counter_per_expert + num_experts;
-  auto grid_sync_barrier_ptr = atomic_finish_counter_per_expert + num_experts;
-  EP_HOST_ASSERT((num_experts * 2 + 1) * sizeof(int) <= NUM_WORKSPACE_BYTES);
+  auto atomic_send_counter_per_expert =
+      atomic_finish_counter_per_expert + num_experts;
+  auto grid_sync_barrier_ptr = atomic_send_counter_per_expert + num_experts;
+  EP_HOST_ASSERT((num_experts * 3 + 1 + num_ranks) * sizeof(int) <=
+                 NUM_WORKSPACE_BYTES);
 
   // FP8 checks
   if (use_ue8m0)
     EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
 
 #define DISPATCH_LAUNCH_CASE(hidden) \
   { \
     auto dispatch_func = dispatch