diff --git a/ep/src/uccl_ep.cc b/ep/src/uccl_ep.cc index ccdcf6bb2..c873ce0d7 100644 --- a/ep/src/uccl_ep.cc +++ b/ep/src/uccl_ep.cc @@ -518,10 +518,6 @@ class Buffer { std::uintptr_t is_token_in_rank_ptr, std::optional& previous_event, bool async, bool allocate_on_comm_stream, std::uintptr_t compute_stream_ptr) { - EP_HOST_ASSERT(topk_idx_ptr != 0); - EP_HOST_ASSERT(num_tokens_per_rank_ptr != 0); - EP_HOST_ASSERT(num_tokens_per_expert_ptr != 0); - EP_HOST_ASSERT(is_token_in_rank_ptr != 0); EP_HOST_ASSERT(num_experts > 0); auto compute_stream = reinterpret_cast(compute_stream_ptr); @@ -536,6 +532,27 @@ class Buffer { stream_wait(comm_stream, compute_stream); } + // Early return for 0-token case (rank has no tokens to dispatch) + if (num_tokens <= 0 || topk_idx_ptr == 0) { + // Zero-fill output arrays so callers see "nothing to send" + if (num_tokens_per_rank_ptr != 0) + CUDA_CHECK(cudaMemsetAsync(reinterpret_cast(num_tokens_per_rank_ptr), + 0, num_ranks * sizeof(int), comm_stream)); + if (num_tokens_per_rdma_rank_ptr != 0) + CUDA_CHECK(cudaMemsetAsync(reinterpret_cast(num_tokens_per_rdma_rank_ptr), + 0, get_num_rdma_ranks() * sizeof(int), comm_stream)); + if (num_tokens_per_expert_ptr != 0) + CUDA_CHECK(cudaMemsetAsync(reinterpret_cast(num_tokens_per_expert_ptr), + 0, num_experts * sizeof(int), comm_stream)); + std::optional event; + if (async) { + event = EventHandle(comm_stream); + } else { + stream_wait(compute_stream, comm_stream); + } + return event; + } + auto* topk_idx = reinterpret_cast(topk_idx_ptr); auto* num_tokens_per_rank = reinterpret_cast(num_tokens_per_rank_ptr); auto* num_tokens_per_rdma_rank = @@ -637,13 +654,7 @@ class Buffer { std::optional& previous_event, bool async, bool allocate_on_comm_stream, std::uintptr_t compute_stream_ptr) { - EP_HOST_ASSERT(num_tokens > 0); EP_HOST_ASSERT(num_experts > 0); - EP_HOST_ASSERT(num_tokens_per_rank_ptr != 0); - EP_HOST_ASSERT(is_token_in_rank_ptr != 0); - EP_HOST_ASSERT(num_tokens_per_expert_ptr != 0); - 
EP_HOST_ASSERT(rank_prefix_matrix_ptr != 0); - EP_HOST_ASSERT(channel_prefix_matrix_ptr != 0); EP_HOST_ASSERT(config.num_sms % 2 == 0); int num_channels = config.num_sms / 2; @@ -658,8 +669,17 @@ class Buffer { stream_wait(comm_stream, compute_stream); } + // No early return for 0-token case — must participate in collective notify + EP_HOST_ASSERT(num_tokens_per_rank_ptr != 0); + EP_HOST_ASSERT(num_tokens_per_expert_ptr != 0); + EP_HOST_ASSERT(rank_prefix_matrix_ptr != 0); + EP_HOST_ASSERT(channel_prefix_matrix_ptr != 0); + int* num_tokens_per_rank = reinterpret_cast(num_tokens_per_rank_ptr); - bool* is_token_in_rank = reinterpret_cast(is_token_in_rank_ptr); + // is_token_in_rank may be 0 when num_tokens == 0 (empty tensor); fall back to nullptr + bool* is_token_in_rank = is_token_in_rank_ptr != 0 + ? reinterpret_cast(is_token_in_rank_ptr) + : nullptr; int* num_tokens_per_expert = reinterpret_cast(num_tokens_per_expert_ptr); int* rank_prefix_matrix = reinterpret_cast(rank_prefix_matrix_ptr); @@ -725,11 +745,8 @@ class Buffer { std::uintptr_t recv_src_idx_ptr, std::uintptr_t send_head_ptr, std::optional& previous_event, bool async, bool allocate_on_comm_stream, std::uintptr_t compute_stream_ptr) { - EP_HOST_ASSERT(x_ptr != 0 && is_token_in_rank_ptr != 0); - EP_HOST_ASSERT(channel_prefix_matrix_ptr != 0); - EP_HOST_ASSERT(recv_x_ptr != 0 && recv_channel_prefix_matrix_ptr != 0); - EP_HOST_ASSERT(recv_src_idx_ptr != 0 && send_head_ptr != 0); - EP_HOST_ASSERT(num_tokens > 0 && hidden > 0 && num_recv_tokens >= 0); + // x_ptr and is_token_in_rank_ptr may be 0 when num_tokens == 0 + EP_HOST_ASSERT(hidden > 0); EP_HOST_ASSERT((hidden * x_element_size) % static_cast(sizeof(int4)) == 0); @@ -742,16 +759,20 @@ class Buffer { } else { stream_wait(comm_stream, compute_stream); } + + // No early return — must participate in collective even with 0 tokens + if (cached_mode) { - EP_HOST_ASSERT(rank_prefix_matrix_ptr != 0); - int num_memset_int = num_channels * num_ranks * 4; 
- uccl::intranode::cached_notify_dispatch( - reinterpret_cast(rank_prefix_matrix_ptr), num_memset_int, - buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks, - comm_stream); + if (rank_prefix_matrix_ptr != 0) { + int num_memset_int = num_channels * num_ranks * 4; + uccl::intranode::cached_notify_dispatch( + reinterpret_cast(rank_prefix_matrix_ptr), num_memset_int, + buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks, + comm_stream); + } } - auto* x = reinterpret_cast(x_ptr); + auto* x = x_ptr == 0 ? nullptr : reinterpret_cast(x_ptr); auto* x_scales = x_scales_ptr == 0 ? nullptr : reinterpret_cast(x_scales_ptr); auto* topk_idx = @@ -761,19 +782,22 @@ class Buffer { : reinterpret_cast(topk_weights_ptr); uccl::intranode::dispatch( - reinterpret_cast(recv_x_ptr), + recv_x_ptr == 0 ? nullptr : reinterpret_cast(recv_x_ptr), recv_x_scales_ptr == 0 ? nullptr : reinterpret_cast(recv_x_scales_ptr), - reinterpret_cast(recv_src_idx_ptr), + recv_src_idx_ptr == 0 ? nullptr : reinterpret_cast(recv_src_idx_ptr), recv_topk_idx_ptr == 0 ? nullptr : reinterpret_cast(recv_topk_idx_ptr), recv_topk_weights_ptr == 0 ? nullptr : reinterpret_cast(recv_topk_weights_ptr), - reinterpret_cast(recv_channel_prefix_matrix_ptr), - reinterpret_cast(send_head_ptr), x, x_scales, topk_idx, - topk_weights, reinterpret_cast(is_token_in_rank_ptr), - reinterpret_cast(channel_prefix_matrix_ptr), num_tokens, + recv_channel_prefix_matrix_ptr == 0 ? reinterpret_cast(workspace) : reinterpret_cast(recv_channel_prefix_matrix_ptr), + send_head_ptr == 0 ? nullptr : reinterpret_cast(send_head_ptr), + x, x_scales, topk_idx, + topk_weights, + is_token_in_rank_ptr == 0 ? nullptr : reinterpret_cast(is_token_in_rank_ptr), + channel_prefix_matrix_ptr == 0 ? 
reinterpret_cast(workspace) : reinterpret_cast(channel_prefix_matrix_ptr), + num_tokens, num_worst_tokens, static_cast(hidden * x_element_size / sizeof(int4)), num_topk, num_experts, num_scales, scale_token_stride, scale_hidden_stride, @@ -801,10 +825,6 @@ class Buffer { std::uintptr_t recv_topk_weights_ptr, std::optional& previous_event, bool async, bool allocate_on_comm_stream, std::uintptr_t compute_stream_ptr) { - EP_HOST_ASSERT(x_ptr != 0 && src_idx_ptr != 0 && - rank_prefix_matrix_ptr != 0); - EP_HOST_ASSERT(channel_prefix_matrix_ptr != 0 && send_head_ptr != 0); - EP_HOST_ASSERT(recv_x_ptr != 0); EP_HOST_ASSERT((hidden * x_element_size) % static_cast(sizeof(int4)) == 0); @@ -819,10 +839,14 @@ class Buffer { stream_wait(comm_stream, compute_stream); } + // No early return — must participate in collective even with 0 tokens + EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= static_cast(num_nvl_bytes)); uccl::intranode::cached_notify_combine( - buffer_ptrs_gpu, reinterpret_cast(send_head_ptr), num_channels, + buffer_ptrs_gpu, + send_head_ptr == 0 ? nullptr : reinterpret_cast(send_head_ptr), + num_channels, num_recv_tokens, num_channels * num_ranks * 2, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream); @@ -831,17 +855,20 @@ class Buffer { bias_1_ptr == 0 ? nullptr : reinterpret_cast(bias_1_ptr), }; uccl::intranode::combine( - cuda_dtype_from_code(x_dtype_code), reinterpret_cast(recv_x_ptr), + cuda_dtype_from_code(x_dtype_code), + recv_x_ptr == 0 ? nullptr : reinterpret_cast(recv_x_ptr), recv_topk_weights_ptr == 0 ? nullptr : reinterpret_cast(recv_topk_weights_ptr), - reinterpret_cast(x_ptr), + x_ptr == 0 ? nullptr : reinterpret_cast(x_ptr), topk_weights_ptr == 0 ? 
nullptr : reinterpret_cast(topk_weights_ptr), - bias_ptrs[0], bias_ptrs[1], reinterpret_cast(src_idx_ptr), - reinterpret_cast(rank_prefix_matrix_ptr), - reinterpret_cast(channel_prefix_matrix_ptr), - reinterpret_cast(send_head_ptr), num_tokens, num_recv_tokens, + bias_ptrs[0], bias_ptrs[1], + src_idx_ptr == 0 ? nullptr : reinterpret_cast(src_idx_ptr), + rank_prefix_matrix_ptr == 0 ? nullptr : reinterpret_cast(rank_prefix_matrix_ptr), + channel_prefix_matrix_ptr == 0 ? nullptr : reinterpret_cast(channel_prefix_matrix_ptr), + send_head_ptr == 0 ? nullptr : reinterpret_cast(send_head_ptr), + num_tokens, num_recv_tokens, hidden, num_topk, buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); @@ -871,15 +898,7 @@ class Buffer { bool allocate_on_comm_stream, std::uintptr_t compute_stream_ptr) { nb::gil_scoped_release release; - EP_HOST_ASSERT(num_tokens_per_rank_ptr != 0); - EP_HOST_ASSERT(num_tokens_per_rdma_rank_ptr != 0); - EP_HOST_ASSERT(num_tokens_per_expert_ptr != 0); - EP_HOST_ASSERT(is_token_in_rank_ptr != 0); - EP_HOST_ASSERT(rdma_channel_prefix_matrix_ptr != 0); - EP_HOST_ASSERT(recv_rdma_rank_prefix_sum_ptr != 0); - EP_HOST_ASSERT(gbl_channel_prefix_matrix_ptr != 0); - EP_HOST_ASSERT(recv_gbl_rank_prefix_sum_ptr != 0); - EP_HOST_ASSERT(num_tokens > 0 && hidden > 0 && num_experts > 0); + EP_HOST_ASSERT(hidden > 0 && num_experts > 0); EP_HOST_ASSERT(config.num_sms % 2 == 0); EP_HOST_ASSERT(0 < get_num_rdma_ranks() && get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); @@ -898,6 +917,16 @@ class Buffer { stream_wait(comm_stream, compute_stream); } + // No early return — must call notify_dispatch so other ranks see our token counts + EP_HOST_ASSERT(num_tokens_per_rank_ptr != 0); + EP_HOST_ASSERT(num_tokens_per_rdma_rank_ptr != 0); + EP_HOST_ASSERT(num_tokens_per_expert_ptr != 0); + // is_token_in_rank_ptr may be 0 when num_tokens == 0 (empty tensor) + 
EP_HOST_ASSERT(rdma_channel_prefix_matrix_ptr != 0); + EP_HOST_ASSERT(recv_rdma_rank_prefix_sum_ptr != 0); + EP_HOST_ASSERT(gbl_channel_prefix_matrix_ptr != 0); + EP_HOST_ASSERT(recv_gbl_rank_prefix_sum_ptr != 0); + *moe_recv_counter = -1; *moe_recv_rdma_counter = -1; for (int i = 0; i < num_local_experts; ++i) moe_recv_expert_counter[i] = -1; @@ -909,7 +938,11 @@ moe_recv_rdma_counter_mapped, reinterpret_cast(num_tokens_per_expert_ptr), moe_recv_expert_counter_mapped, num_experts, - reinterpret_cast(is_token_in_rank_ptr), num_tokens, + // is_token_in_rank may be null for 0-token case; fall back to nullptr + is_token_in_rank_ptr != 0 + ? reinterpret_cast(is_token_in_rank_ptr) + : nullptr, + num_tokens, num_worst_tokens, num_channels, hidden_int4, num_scales, num_topk, expert_alignment, reinterpret_cast(rdma_channel_prefix_matrix_ptr), @@ -977,12 +1010,10 @@ class Buffer { std::uintptr_t send_rdma_head_ptr, std::uintptr_t send_nvl_head_ptr, std::optional& previous_event, bool async, bool allocate_on_comm_stream, std::uintptr_t compute_stream_ptr) { - EP_HOST_ASSERT(x_ptr != 0 && recv_x_ptr != 0 && is_token_in_rank_ptr != 0); - EP_HOST_ASSERT(rdma_channel_prefix_matrix_ptr != 0); - EP_HOST_ASSERT(recv_rdma_rank_prefix_sum_ptr != 0); - EP_HOST_ASSERT(gbl_channel_prefix_matrix_ptr != 0); - EP_HOST_ASSERT(recv_gbl_rank_prefix_sum_ptr != 0); - EP_HOST_ASSERT(num_tokens > 0 && hidden > 0); + // recv pointers may be 0 when num_recv_tokens or num_rdma_recv_tokens == 0 + // (torch.empty((0, ...)).data_ptr() returns 0 in PyTorch) + // x_ptr and is_token_in_rank_ptr may be 0 when num_tokens == 0 + EP_HOST_ASSERT(hidden > 0); EP_HOST_ASSERT(config.num_sms % 2 == 0); int const num_channels = config.num_sms / 2; @@ -996,6 +1027,8 @@ class Buffer { stream_wait(comm_stream, compute_stream); } + // No early return — must participate in collective dispatch even with 0 tokens + if (cached_mode) { uccl::internode::cached_notify( hidden_int4, num_scales, num_topk, 
num_topk, num_ranks, num_channels, @@ -1009,16 +1042,13 @@ class Buffer { num_ranks), num_nvl_bytes, true, low_latency_mode, d_handles, num_d2h_channel_addrs, atomic_buffer_ptr); - } else { - EP_HOST_ASSERT(recv_src_meta_ptr != 0); - EP_HOST_ASSERT(send_rdma_head_ptr != 0); - EP_HOST_ASSERT(send_nvl_head_ptr != 0); - EP_HOST_ASSERT(recv_rdma_channel_prefix_matrix_ptr != 0); - EP_HOST_ASSERT(recv_gbl_channel_prefix_matrix_ptr != 0); } + // No asserts on recv_src_meta_ptr, send_rdma_head_ptr, send_nvl_head_ptr, + // recv_rdma/gbl_channel_prefix_matrix_ptr — they may be 0 when + // num_recv_tokens or num_rdma_recv_tokens == 0. uccl::internode::dispatch( - reinterpret_cast(recv_x_ptr), + recv_x_ptr == 0 ? nullptr : reinterpret_cast(recv_x_ptr), recv_x_scales_ptr == 0 ? nullptr : reinterpret_cast(recv_x_scales_ptr), recv_topk_idx_ptr == 0 ? nullptr @@ -1026,8 +1056,8 @@ class Buffer { recv_topk_weights_ptr == 0 ? nullptr : reinterpret_cast(recv_topk_weights_ptr), - cached_mode ? nullptr : reinterpret_cast(recv_src_meta_ptr), - reinterpret_cast(x_ptr), + recv_src_meta_ptr == 0 ? nullptr : reinterpret_cast(recv_src_meta_ptr), + x_ptr == 0 ? nullptr : reinterpret_cast(x_ptr), x_scales_ptr == 0 ? nullptr : reinterpret_cast(x_scales_ptr), topk_idx_ptr == 0 ? nullptr @@ -1035,19 +1065,22 @@ class Buffer { topk_weights_ptr == 0 ? nullptr : reinterpret_cast(topk_weights_ptr), - cached_mode ? nullptr : reinterpret_cast(send_rdma_head_ptr), - cached_mode ? nullptr : reinterpret_cast(send_nvl_head_ptr), - cached_mode + send_rdma_head_ptr == 0 ? nullptr : reinterpret_cast(send_rdma_head_ptr), + send_nvl_head_ptr == 0 ? nullptr : reinterpret_cast(send_nvl_head_ptr), + recv_rdma_channel_prefix_matrix_ptr == 0 ? nullptr : reinterpret_cast(recv_rdma_channel_prefix_matrix_ptr), - cached_mode + recv_gbl_channel_prefix_matrix_ptr == 0 ? 
nullptr : reinterpret_cast(recv_gbl_channel_prefix_matrix_ptr), reinterpret_cast(rdma_channel_prefix_matrix_ptr), reinterpret_cast(recv_rdma_rank_prefix_sum_ptr), reinterpret_cast(gbl_channel_prefix_matrix_ptr), reinterpret_cast(recv_gbl_rank_prefix_sum_ptr), - reinterpret_cast(is_token_in_rank_ptr), num_tokens, + is_token_in_rank_ptr == 0 + ? nullptr + : reinterpret_cast(is_token_in_rank_ptr), + num_tokens, num_worst_tokens, hidden_int4, num_scales, num_topk, num_experts, scale_token_stride, scale_hidden_stride, rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, @@ -1080,13 +1113,6 @@ class Buffer { std::uintptr_t combined_x_ptr, std::uintptr_t combined_topk_weights_ptr, std::optional& previous_event, bool async, bool allocate_on_comm_stream, std::uintptr_t compute_stream_ptr) { - EP_HOST_ASSERT(x_ptr != 0 && src_meta_ptr != 0 && combined_x_ptr != 0); - EP_HOST_ASSERT(is_combined_token_in_rank_ptr != 0); - EP_HOST_ASSERT(rdma_channel_prefix_matrix_ptr != 0); - EP_HOST_ASSERT(rdma_rank_prefix_sum_ptr != 0); - EP_HOST_ASSERT(gbl_channel_prefix_matrix_ptr != 0); - EP_HOST_ASSERT(combined_rdma_head_ptr != 0); - EP_HOST_ASSERT(combined_nvl_head_ptr != 0); EP_HOST_ASSERT(config.num_sms % 2 == 0); int const num_channels = config.num_sms / 2; @@ -1100,12 +1126,16 @@ class Buffer { stream_wait(comm_stream, compute_stream); } + // No early return — must participate in collective combine even with 0 tokens + uccl::internode::cached_notify( hidden_int4, 0, 0, num_topk, num_ranks, num_channels, - num_combined_tokens, reinterpret_cast(combined_rdma_head_ptr), - reinterpret_cast(rdma_channel_prefix_matrix_ptr), - reinterpret_cast(rdma_rank_prefix_sum_ptr), - reinterpret_cast(combined_nvl_head_ptr), rdma_buffer_ptr, + num_combined_tokens, + combined_rdma_head_ptr == 0 ? nullptr : reinterpret_cast(combined_rdma_head_ptr), + rdma_channel_prefix_matrix_ptr == 0 ? nullptr : reinterpret_cast(rdma_channel_prefix_matrix_ptr), + rdma_rank_prefix_sum_ptr == 0 ? 
nullptr : reinterpret_cast(rdma_rank_prefix_sum_ptr), + combined_nvl_head_ptr == 0 ? nullptr : reinterpret_cast(combined_nvl_head_ptr), + rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, comm_stream, @@ -1119,22 +1149,22 @@ class Buffer { }; uccl::internode::combine( cuda_dtype_from_code(x_dtype_code), - reinterpret_cast(combined_x_ptr), + combined_x_ptr == 0 ? nullptr : reinterpret_cast(combined_x_ptr), combined_topk_weights_ptr == 0 ? nullptr : reinterpret_cast(combined_topk_weights_ptr), - reinterpret_cast(is_combined_token_in_rank_ptr), - reinterpret_cast(x_ptr), + is_combined_token_in_rank_ptr == 0 ? nullptr : reinterpret_cast(is_combined_token_in_rank_ptr), + x_ptr == 0 ? nullptr : reinterpret_cast(x_ptr), topk_weights_ptr == 0 ? nullptr : reinterpret_cast(topk_weights_ptr), bias_ptrs[0], bias_ptrs[1], - reinterpret_cast(combined_rdma_head_ptr), - reinterpret_cast(combined_nvl_head_ptr), - reinterpret_cast(src_meta_ptr), - reinterpret_cast(rdma_channel_prefix_matrix_ptr), - reinterpret_cast(rdma_rank_prefix_sum_ptr), - reinterpret_cast(gbl_channel_prefix_matrix_ptr), num_tokens, + combined_rdma_head_ptr == 0 ? nullptr : reinterpret_cast(combined_rdma_head_ptr), + combined_nvl_head_ptr == 0 ? nullptr : reinterpret_cast(combined_nvl_head_ptr), + src_meta_ptr == 0 ? nullptr : reinterpret_cast(src_meta_ptr), + rdma_channel_prefix_matrix_ptr == 0 ? nullptr : reinterpret_cast(rdma_channel_prefix_matrix_ptr), + rdma_rank_prefix_sum_ptr == 0 ? nullptr : reinterpret_cast(rdma_rank_prefix_sum_ptr), + gbl_channel_prefix_matrix_ptr == 0 ? nullptr : reinterpret_cast(gbl_channel_prefix_matrix_ptr), num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu,