diff --git a/ep/include/uccl_ibgda.cuh b/ep/include/uccl_ibgda.cuh index 3c85f6566..a334592eb 100644 --- a/ep/include/uccl_ibgda.cuh +++ b/ep/include/uccl_ibgda.cuh @@ -70,6 +70,8 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp( cmd.atomic_val = atomic_val; } else { cmd.expert_idx = expert_idx; + if (atomic_val != 0) + cmd.atomic_val = atomic_val; } h->atomic_set_and_commit(cmd, &slot); } diff --git a/ep/src/internode_ll.cu b/ep/src/internode_ll.cu index 281f1f1fc..94dacba14 100644 --- a/ep/src/internode_ll.cu +++ b/ep/src/internode_ll.cu @@ -51,9 +51,10 @@ __global__ __launch_bounds__(1024, 1) void dispatch( int64_t* packed_recv_layout_range, int* packed_recv_count, int* cumulative_local_expert_recv_stats, int64_t* dispatch_wait_recv_cost_stats, void* rdma_recv_x, - int* rdma_recv_count, void* rdma_x, void const* x, int64_t const* topk_idx, - int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, - int* next_clean, int* next_clean_second, int num_next_clean_int, + int* rdma_recv_count, void* rdma_x, void const* x, int64_t const* topk_idx, + int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, + int* token_ready_flags, int* next_clean, int* next_clean_second, + int num_next_clean_int, int num_tokens, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, int num_warp_groups, int num_warps_per_group, bool round_scale, int phases, @@ -92,10 +93,13 @@ __global__ __launch_bounds__(1024, 1) void dispatch( sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16))); size_t const num_int4_per_msg = num_bytes_per_msg / sizeof(int4); + constexpr int kPrevSlotsToCheck = 6; + EP_STATIC_ASSERT(kPrevSlotsToCheck > 0, "Chunk size must be positive"); EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); // Expert counts __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups]; + __shared__ int full_expert_count_shared[300]; #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) // initialize barrier @@ -105,6 +109,21 @@ __global__ __launch_bounds__(1024, 1) void dispatch( // Sending phase if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_DISPATCH_RECV; + if (threadIdx.x < num_experts) { + full_expert_count_shared[threadIdx.x] = 0; + } + __syncthreads(); + + if (warp_id == num_warps - 1) { + #pragma unroll 8 + for (int i = lane_id; i < num_tokens * num_topk; i += WARP_SIZE) { + auto idx = static_cast(__ldg(topk_idx + i)); + if (idx >= 0 && idx < num_experts) + atomicAdd(&full_expert_count_shared[idx], 1); + } + } + __syncthreads(); + // There are 2 kinds of warps in this part: // 1. The first-kind warps for FP8 cast and sending top-k tokens // 2. The last warp for reading `topk_idx` and count for per-expert @@ -115,29 +134,50 @@ __global__ __launch_bounds__(1024, 1) void dispatch( "Invalid hidden"); EP_STATIC_ASSERT(kNumElemsPerRead * WARP_SIZE % kNumPerChannels == 0, "Invalid vectorization"); - auto const num_threads = (num_warps - 1) * WARP_SIZE; size_t const hidden_bf16_int4 = kHidden / kNumElemsPerRead; for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { auto const x_int4 = static_cast(x) + token_idx * hidden_bf16_int4; - auto const rdma_x_src_idx = reinterpret_cast( - static_cast(rdma_x) + token_idx * num_bytes_per_msg); - auto const rdma_x_vec = reinterpret_cast( - reinterpret_cast(rdma_x_src_idx) + sizeof(int4)); - auto const rdma_x_scales = reinterpret_cast( - reinterpret_cast(rdma_x_vec) + hidden_bytes); - // Overlap top-k index read and source token index writes auto dst_expert_idx = warp_id < num_topk ? static_cast(__ldg( topk_idx + token_idx * num_topk + warp_id)) : -1; - thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0; + int slot_idx = 0; + uint8_t* rdma_x_msg_ptr = nullptr; + int* rdma_x_src_idx = nullptr; + vec_t* rdma_x_vec = nullptr; + [[maybe_unused]] float* rdma_x_scales = nullptr; + if (dst_expert_idx >= 0) { + slot_idx = lane_id == 0 + ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, + 1) + : 0; + slot_idx = __shfl_sync(WARP_MASK, slot_idx, 0); + EP_DEVICE_ASSERT(slot_idx < num_max_dispatch_tokens_per_rank); + auto const message_idx = + dst_expert_idx * num_max_dispatch_tokens_per_rank + slot_idx; + rdma_x_msg_ptr = static_cast(rdma_x) + + message_idx * num_bytes_per_msg; + rdma_x_src_idx = reinterpret_cast(rdma_x_msg_ptr); + rdma_x_vec = reinterpret_cast(rdma_x_msg_ptr + sizeof(int4)); + if constexpr (kUseFP8) + rdma_x_scales = reinterpret_cast( + reinterpret_cast(rdma_x_vec) + hidden_bytes); + if (lane_id == 0) *rdma_x_src_idx = token_idx; + } + + if (dst_expert_idx >= 0) { + auto* ready_flags_per_expert = + token_ready_flags + + dst_expert_idx * num_max_dispatch_tokens_per_rank; + if (lane_id == 0) st_release_sys_global( + ready_flags_per_expert + slot_idx, 0); // FP8 cast #pragma unroll - for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) { + for (int i = lane_id; i < hidden_bf16_int4; i += WARP_SIZE) { // Read auto int4_value = __ldg(x_int4 + i); @@ -185,52 +225,110 @@ __global__ __launch_bounds__(1024, 1) void dispatch( // Reinterpret-cast is for C++14 compatibility rdma_x_vec[i] = *reinterpret_cast(&int4_value); } - } - sync_barrier_1(num_threads); + } + __syncwarp(); + if (lane_id == 0) st_release_sys_global( + ready_flags_per_expert + slot_idx, 1); + __syncwarp(); - // Issue IBGDA sends - if (dst_expert_idx >= 0) { - int slot_idx = - lane_id == 0 - ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) - : 0; - slot_idx = __shfl_sync(WARP_MASK, slot_idx, 0); + bool issue_chunk_send = false; + int chunk_start_slot = slot_idx; + int chunk_len_slots = 0; + + if (lane_id == 0) { + if (kPrevSlotsToCheck > 0) { + int num_tokens_for_expert = 0; + if (dst_expert_idx >= 0 && dst_expert_idx < num_experts) + num_tokens_for_expert = + full_expert_count_shared[dst_expert_idx]; + + issue_chunk_send = ((slot_idx + 1) % kPrevSlotsToCheck) == 0 || + (num_tokens_for_expert > 0 && + slot_idx == num_tokens_for_expert - 1); + + if (issue_chunk_send) { + chunk_len_slots = + (num_tokens_for_expert > 0 && slot_idx == num_tokens_for_expert - 1) + ? ((slot_idx % kPrevSlotsToCheck) + 1) + : kPrevSlotsToCheck; + chunk_start_slot = slot_idx - (chunk_len_slots - 1); + + // printf("[dispatch] warp_id: %d, token_idx: %d, dst_expert_idx: %d, slot_idx:" + // " %d, issue_chunk_send: %d, chunk_len_slots: %d, num_tokens_for_expert: %d\n", + // warp_id, token_idx, dst_expert_idx, slot_idx, issue_chunk_send, chunk_len_slots, num_tokens_for_expert); + } + + } + } + issue_chunk_send = __shfl_sync(WARP_MASK, issue_chunk_send, 0); + chunk_start_slot = __shfl_sync(WARP_MASK, chunk_start_slot, 0); + chunk_len_slots = __shfl_sync(WARP_MASK, chunk_len_slots, 0); + + if (issue_chunk_send) { + if (lane_id == 0) { + int const num_prev_slots = + chunk_len_slots > 1 ? min(slot_idx, chunk_len_slots - 1) + : 0; + for (int prev = num_prev_slots; prev > 0; --prev) { + auto const* prev_flag_ptr = + ready_flags_per_expert + slot_idx - prev; + while (ld_acquire_sys_global(prev_flag_ptr) == 0) {} + } + } + __syncwarp(); + + auto* chunk_msg_ptr = + rdma_x_msg_ptr - (chunk_len_slots - 1) * num_bytes_per_msg; + auto const src_ptr = + reinterpret_cast(chunk_msg_ptr); auto const dst_rank = dst_expert_idx / num_local_experts; - auto const dst_expert_local_idx = dst_expert_idx % num_local_experts; - auto const src_ptr = reinterpret_cast(rdma_x_src_idx); + auto const dst_expert_local_idx = + dst_expert_idx % num_local_experts; auto const dst_ptr = reinterpret_cast(rdma_recv_x) + dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + - slot_idx * num_bytes_per_msg; + chunk_start_slot * num_bytes_per_msg; auto const dst_p2p_ptr = ipc_rdma_base_ptrs ? uccl::get_ipc_p2p_ptr(dst_ptr, ipc_rdma_base_ptrs, rank, dst_rank, max_nvl_peers, 0) : 0; + auto const chunk_bytes = + static_cast(chunk_len_slots) * num_bytes_per_msg; if (dst_p2p_ptr == 0) { __threadfence_system(); + // if (lane_id == 0) { + // printf("[dispatch] IBGDA PUT dst_rank: %d, dst_expert_idx:" + // " %d, slot_idx: %d, chunk_bytes: %llu\n", dst_rank, dst_expert_idx, slot_idx, (unsigned long long)chunk_bytes); + // } uccl::nvshmemi_ibgda_put_nbi_warp( dst_ptr - reinterpret_cast(rdma_buffer_ptr), src_ptr - reinterpret_cast(rdma_buffer_ptr), - num_bytes_per_msg, dst_rank, + chunk_bytes, dst_rank, /*warp_id=*/dst_expert_local_idx, // NOTE(Yang): for selecting // rb. lane_id, slot_idx, d2h_channel_addrs, num_d2h_channel_addrs, - false, low_latency_buffer_idx); + false, low_latency_buffer_idx, 0, chunk_len_slots); } else { // Intra-node: use direct memory copy via IPC auto const* src_int4_ptr = reinterpret_cast(src_ptr); auto* dst_int4_ptr = reinterpret_cast(dst_p2p_ptr); - UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, - src_int4_ptr, ld_nc_global, st_na_global); + UNROLLED_WARP_COPY(8, lane_id, static_cast(num_int4_per_msg * chunk_len_slots), dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); + } } - // Increase counter after finishing __syncwarp(); - lane_id == 0 ? atomic_add_release_global( - atomic_finish_counter_per_expert + dst_expert_idx, 1) - : 0; + + if (chunk_len_slots > 0) { + // if (lane_id == 0) { + // printf("update atomic_finish_counter_per_expert for dst_expert_idx: %d, chunk_len_slots: %d\n", + // dst_expert_idx, chunk_len_slots); + // } + lane_id == 0 ? atomic_add_release_global( + atomic_finish_counter_per_expert + dst_expert_idx, chunk_len_slots) + : 0; + } } } } else if (warp_id == num_warps - 1) { @@ -274,6 +372,9 @@ __global__ __launch_bounds__(1024, 1) void dispatch( atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum); } + // if (lane_id == 0) + // printf("sm_id: %d, atomic_finish_counter_per_expert subtract for expert_idx: %d, sum: %d\n", sm_id, + // i, sum); } } __syncthreads(); @@ -303,6 +404,10 @@ __global__ __launch_bounds__(1024, 1) void dispatch( : 0; if (dst_p2p_ptr == 0) { // Inter-node or no IPC: use IBGDA atomic + // if (lane_id == 0) { + // printf("[dispatch] IBGDA AMO dst_rank: %d, dst_expert_idx: %d, num_tokens_sent: %d\n", + // dst_rank, responsible_expert_idx, num_tokens_sent); + // } uccl::nvshmemi_ibgda_amo_nonfetch_add( dst_ptr_internode, reinterpret_cast(atomic_buffer_ptr), -num_tokens_sent - 1, dst_rank, @@ -520,8 +625,11 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, // Workspace checks auto atomic_counter_per_expert = static_cast(workspace); auto atomic_finish_counter_per_expert = - atomic_counter_per_expert + num_experts; - EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); + atomic_counter_per_expert + num_experts; + auto token_ready_flags = atomic_finish_counter_per_expert + num_experts; + EP_HOST_ASSERT(static_cast(num_experts) * sizeof(int) * + static_cast(2 + num_max_dispatch_tokens_per_rank) <= + NUM_WORKSPACE_BYTES); // FP8 checks if (use_ue8m0) @@ -538,7 +646,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, packed_recv_src_info, packed_recv_layout_range, packed_recv_count, \ cumulative_local_expert_recv_stats, dispatch_wait_recv_cost_stats, \ rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \ - atomic_counter_per_expert, atomic_finish_counter_per_expert, \ + atomic_counter_per_expert, atomic_finish_counter_per_expert, token_ready_flags, \ next_clean, next_clean_second, num_next_clean_int, num_tokens, \ num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, \ num_ranks, num_warp_groups, num_warps_per_group, round_scale, phases, \ @@ -567,9 +675,9 @@ __global__ __launch_bounds__(1024, 1) void combine( void* combined_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, void const* x, int64_t const* topk_idx, float const* topk_weights, int const* src_info, int64_t const* layout_range, - int64_t* combine_wait_recv_cost_stats, int* next_clean, - int* next_clean_second, int num_next_clean_int, int* atomic_clean_flag, - int num_combined_tokens, int hidden, int num_topk, + int64_t* combine_wait_recv_cost_stats, int* next_clean, + int* next_clean_second, int num_next_clean_int, int* atomic_clean_flag, + int* token_ready_flags, int num_combined_tokens, int hidden, int num_topk, int num_max_dispatch_tokens_per_rank, int num_experts, int rank, int num_ranks, int num_warp_groups, int num_warps_per_group, int phases, bool zero_copy, uint64_t const* d2h_channel_addrs, @@ -600,6 +708,8 @@ __global__ __launch_bounds__(1024, 1) void combine( constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16); EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); + constexpr int kPrevSlotsToCheck = 4; + EP_STATIC_ASSERT(kPrevSlotsToCheck > 0, "Chunk size must be positive"); // Sending phase if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_COMBINE_RECV; @@ -634,6 +744,9 @@ __global__ __launch_bounds__(1024, 1) void combine( local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot; + auto ready_flags_per_expert = token_ready_flags + + responsible_expert_idx * + num_max_dispatch_tokens_per_rank; // Unpack layout int offset, num_tokens_to_send; @@ -712,6 +825,12 @@ __global__ __launch_bounds__(1024, 1) void combine( ? uccl::get_ipc_p2p_ptr(dst_ptr, ipc_rdma_base_ptrs, rank, dst_rank, max_nvl_peers, 0) : 0; + auto const local_slot = token_idx - offset; + + if (dst_p2p_ptr == 0 && lane_id == 0) { + EP_DEVICE_ASSERT(local_slot < num_max_dispatch_tokens_per_rank); + st_release_sys_global(ready_flags_per_expert + local_slot, 0); + } if (not zero_copy or dst_p2p_ptr != 0) { // Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr` @@ -848,16 +967,102 @@ __global__ __launch_bounds__(1024, 1) void combine( // NOTES: for zero-copy mode, we assume the data is already in the send // buffer if (dst_p2p_ptr == 0) { - __threadfence_system(); - nvshmemi_ibgda_put_nbi_warp( - dst_ptr - reinterpret_cast(rdma_buffer_ptr), - buf_ptr - reinterpret_cast(rdma_buffer_ptr), - hidden * sizeof(nv_bfloat16), dst_rank, - /*warp_id=*/global_expert_idx, // NOTE(Yang): for selecting rb. - // NOTE(Ziming): this is global_expert_idx because destination is - // indexed by global_expert_idx - lane_id, token_idx - offset, d2h_channel_addrs, - num_d2h_channel_addrs, true, low_latency_buffer_idx); + if (lane_id == 0) + st_release_sys_global(ready_flags_per_expert + local_slot, 1); + __syncwarp(); + + bool issue_chunk_send = false; + int chunk_start_slot = local_slot; + int chunk_len_slots = 0; + int chunk_start_token_idx = 0; + + if (lane_id == 0) { + issue_chunk_send = ((local_slot + 1) % kPrevSlotsToCheck) == 0 || + (local_slot == num_tokens_to_send - 1); + if (issue_chunk_send) { + chunk_len_slots = + (local_slot == num_tokens_to_send - 1) + ? ((local_slot % kPrevSlotsToCheck) + 1) + : kPrevSlotsToCheck; + chunk_start_slot = local_slot - (chunk_len_slots - 1); + chunk_start_token_idx = offset + chunk_start_slot; + for (int i = 0; i < chunk_len_slots; ++i) { + auto const* flag_ptr = + ready_flags_per_expert + chunk_start_slot + i; + while (ld_acquire_sys_global(flag_ptr) == 0) {} + } + } + } + + issue_chunk_send = __shfl_sync(WARP_MASK, issue_chunk_send, 0); + chunk_start_slot = __shfl_sync(WARP_MASK, chunk_start_slot, 0); + chunk_len_slots = __shfl_sync(WARP_MASK, chunk_len_slots, 0); + chunk_start_token_idx = __shfl_sync(WARP_MASK, chunk_start_token_idx, 0); + + if (issue_chunk_send) { + int remaining_slots = chunk_len_slots; + int segment_token_idx = chunk_start_token_idx; + int segment_slot = chunk_start_slot; + while (remaining_slots > 0) { + int segment_src_idx = 0; + int segment_len_slots = 1; + if (lane_id == 0) { + segment_src_idx = __ldg(local_src_info + segment_token_idx); + for (int i = 1; i < remaining_slots; ++i) { + auto next_src_idx = + __ldg(local_src_info + segment_token_idx + i); + if (next_src_idx != segment_src_idx + i) break; + ++segment_len_slots; + } + } + segment_src_idx = __shfl_sync(WARP_MASK, segment_src_idx, 0); + segment_len_slots = + __shfl_sync(WARP_MASK, segment_len_slots, 0); + + auto const segment_dst_ptr = + reinterpret_cast(rdma_recv_x) + + (global_expert_idx * num_max_dispatch_tokens_per_rank + + segment_src_idx) * + num_bytes_per_slot; + auto const segment_src_ptr = + reinterpret_cast(rdma_send_x_vec) + + static_cast(segment_token_idx) * + num_bytes_per_slot; + auto const segment_bytes = + static_cast(segment_len_slots) * num_bytes_per_slot; + + __threadfence_system(); + // if (lane_id == 0) + // if (segment_len_slots > 1) + // printf("[combine] IBGDA put_nbi_warp: expert %d, dst_rank " + // "%d, slot %d, len %d\n", + // global_expert_idx, dst_rank, segment_slot, + // segment_len_slots); + uccl::nvshmemi_ibgda_put_nbi_warp( + segment_dst_ptr - + reinterpret_cast(rdma_buffer_ptr), + segment_src_ptr - + reinterpret_cast(rdma_buffer_ptr), + segment_bytes, dst_rank, + /*expert_idx=*/global_expert_idx, lane_id, + segment_slot + segment_len_slots - 1, d2h_channel_addrs, + num_d2h_channel_addrs, true, low_latency_buffer_idx, 0, + segment_len_slots); + + + if (lane_id == 0) { + for (int i = 0; i < segment_len_slots; ++i) + st_release_sys_global(ready_flags_per_expert + + segment_slot + i, + 0); + } + __syncwarp(); + + remaining_slots -= segment_len_slots; + segment_token_idx += segment_len_slots; + segment_slot += segment_len_slots; + } + } } } @@ -1035,7 +1240,12 @@ void combine(void* combined_x, void* rdma_recv_x, int* rdma_recv_flag, // Check workspace auto atomic_clean_flag = static_cast(workspace); - EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES); + auto token_ready_flags = atomic_clean_flag + 1; + auto const required_workspace_ints = + static_cast(1) + + static_cast(num_experts) * + static_cast(num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(required_workspace_ints * sizeof(int) <= NUM_WORKSPACE_BYTES); EP_HOST_ASSERT(num_topk <= kNumMaxTopk); // Online cast cannot use zero-copy @@ -1050,11 +1260,12 @@ void combine(void* combined_x, void* rdma_recv_x, int* rdma_recv_flag, auto combine_func = use_logfmt ? combine \ : combine; \ SET_SHARED_MEMORY_FOR_TMA(combine_func); \ - LAUNCH_KERNEL( \ - &cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, \ - rdma_send_x, x, topk_idx, topk_weights, src_info, layout_range, \ - combine_wait_recv_cost_stats, next_clean, next_clean_second, \ - num_next_clean_int, atomic_clean_flag, num_combined_tokens, hidden, \ + LAUNCH_KERNEL( \ + &cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, \ + rdma_send_x, x, topk_idx, topk_weights, src_info, layout_range, \ + combine_wait_recv_cost_stats, next_clean, next_clean_second, \ + num_next_clean_int, atomic_clean_flag, token_ready_flags, \ + num_combined_tokens, hidden, \ num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \ num_ranks, num_warp_groups, num_warps_per_group, phases, zero_copy, \ d2h_channel_addrs, num_d2h_channel_addrs, max_nvl_peers, \ diff --git a/ep/src/rdma.cpp b/ep/src/rdma.cpp index 9027c11bf..6295edf29 100644 --- a/ep/src/rdma.cpp +++ b/ep/src/rdma.cpp @@ -1065,7 +1065,7 @@ static void post_rdma_async_batched_fast_mode( #ifdef USE_RECEIVER_BARRIER uint32_t imm = WriteImm::Pack(get_is_combine(cmd.cmd_type), get_low_latency(cmd.cmd_type), - cmd.expert_idx, 1, my_rank) + cmd.expert_idx, cmd.atomic_val == 0 ? 1 : cmd.atomic_val, my_rank) .GetImmData(); ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); #else