Fix intra-warp and inter-warp race conditions in bounds_check_indices v1 and v2 CUDA kernels #5638
Open
gchalump wants to merge 1 commit into pytorch:main
… v1 and v2 CUDA kernels

Summary:
Fix race conditions in `bounds_check_indices_kernel_v1` and `bounds_check_indices_kernel_v2`.

Both kernels launch with `dim3(kWarpSize, kNumThreads/kWarpSize)`: all 32 lanes (`threadIdx.x`) in a warp share the same `(b, t)` pair (determined by `threadIdx.y`).

**Intra-warp race fix (v1 and v2):**
- Gate offset check/correction on `threadIdx.x == 0` so only lane 0 performs the check, warning print, and `adjust_offset_kernel` write.
- Broadcast corrected `indices_start` and `indices_end` from lane 0 to all lanes via `__shfl_sync(0xFFFFFFFF, ..., 0)` before entering the index loop.

**Inter-warp race fix, last-element correction (v1 only):**
- Previously every warp ran the last-element correction (`offsets[total_B]`), causing N concurrent writes to the same address and inflated warning counts.
- Gate on `warp_idx == 0` so exactly one warp in the grid performs the check, matching v2's existing `b_t_start == 0` guard.

**Inter-warp race fix, IGNORE mode (v1 and v2):**
- IGNORE mode previously called `adjust_offset_kernel` unconditionally, causing warp N+1 to write back its valid-but-uncorrected `offsets[N+1]` and overwrite warp N's correction of the same element.
- Add the same bounds-check guard used by WARNING mode so `adjust_offset_kernel` only runs when offsets are actually out of range.

**Correction ordering (v1 only):**
- Reorder corrections so the last-element check (`offsets[total_B]`) runs before the per-b_t offset correction within the same lane-0 block. This ensures `indices_end` reflects the corrected `offsets[total_B]` when `b_t + 1 == total_B`.

**Post-correction assertions (v1 and v2):**
- Add post-broadcast `CUDA_KERNEL_ASSERT` checks on every lane to validate `indices_start >= 0`, `indices_start <= indices_end`, and `indices_end <= num_indices` after correction.
- Add best-effort backward monotonicity check: assert that `indices_start >= offsets[b_t - 1]` to detect non-monotonic offsets that per-pair checks miss. This read may race with the previous warp's correction but catches the common case.

The `__shfl_sync` is safe because the loop/branch conditions depend only on `threadIdx.y` (not `threadIdx.x`), so all lanes in a warp always take the same code paths and reach the shuffle together.

Differential Revision: D100898565
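The lane-0 gating, guarded write-back, and shuffle broadcast described above can be sketched roughly as follows. This is a minimal illustration and not the FBGEMM kernels themselves: `check_offsets_sketch`, `run_sketch`, `SketchResult`, `kWarpsPerBlock`, and the `lane_end` debug buffer are hypothetical names, the clamping logic is a simplified stand-in for `adjust_offset_kernel`, plain device-side `assert` stands in for `CUDA_KERNEL_ASSERT`, and a warp width of 32 is assumed. CUDA API error checking is omitted for brevity.

```cuda
#include <cassert>
#include <cstdint>
#include <vector>
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;      // assumption: 32-lane warps
constexpr int kWarpsPerBlock = 4;  // hypothetical block shape for this sketch

// Hypothetical sketch: one warp per b_t, lane 0 checks and corrects,
// then every lane receives the corrected range via __shfl_sync.
__global__ void check_offsets_sketch(int64_t* offsets, int64_t total_B,
                                     int64_t num_indices, int64_t* lane_end) {
  // b_t depends only on blockIdx.x and threadIdx.y, never on threadIdx.x,
  // so all lanes of a warp take the same branches below.
  const int64_t b_t =
      static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
  if (b_t >= total_B) return;  // warp-uniform exit

  int64_t indices_start = 0;
  int64_t indices_end = 0;
  if (threadIdx.x == 0) {  // only lane 0 reads, checks, and writes
    const int64_t old_start = offsets[b_t];
    const int64_t old_end = offsets[b_t + 1];
    indices_start = old_start;
    indices_end = old_end;
    // Guard: correct only when the pair is actually out of range, so a
    // valid pair never writes back over a neighbouring warp's correction.
    if (old_start < 0 || old_start > old_end || old_end > num_indices) {
      if (indices_start < 0) indices_start = 0;
      if (indices_start > num_indices) indices_start = num_indices;
      if (indices_end < indices_start) indices_end = indices_start;
      if (indices_end > num_indices) indices_end = num_indices;
      if (indices_start != old_start) offsets[b_t] = indices_start;
      if (indices_end != old_end) offsets[b_t + 1] = indices_end;
    }
  }
  // Broadcast lane 0's (possibly corrected) values to all 32 lanes.
  indices_start = __shfl_sync(0xFFFFFFFF, indices_start, 0);
  indices_end = __shfl_sync(0xFFFFFFFF, indices_end, 0);

  // Post-broadcast invariants, checked on every lane.
  assert(indices_start >= 0);
  assert(indices_start <= indices_end);
  assert(indices_end <= num_indices);

  // Record what each lane sees so the host can inspect the broadcast.
  lane_end[b_t * kWarpSize + threadIdx.x] = indices_end;
}

struct SketchResult {
  std::vector<int64_t> offsets;   // offsets after correction
  std::vector<int64_t> lane_end;  // indices_end as seen by each lane
};

// Hypothetical host helper: copies inputs, launches the sketch, copies back.
SketchResult run_sketch(std::vector<int64_t> offsets, int64_t num_indices) {
  const int64_t total_B = static_cast<int64_t>(offsets.size()) - 1;
  SketchResult result;
  result.lane_end.assign(total_B * kWarpSize, -1);

  int64_t* d_offsets = nullptr;
  int64_t* d_lane_end = nullptr;
  cudaMalloc(&d_offsets, offsets.size() * sizeof(int64_t));
  cudaMalloc(&d_lane_end, result.lane_end.size() * sizeof(int64_t));
  cudaMemcpy(d_offsets, offsets.data(), offsets.size() * sizeof(int64_t),
             cudaMemcpyHostToDevice);

  const dim3 block(kWarpSize, kWarpsPerBlock);
  const unsigned grid =
      static_cast<unsigned>((total_B + kWarpsPerBlock - 1) / kWarpsPerBlock);
  check_offsets_sketch<<<grid, block>>>(d_offsets, total_B, num_indices,
                                        d_lane_end);

  // These blocking copies also wait for the kernel to finish.
  cudaMemcpy(offsets.data(), d_offsets, offsets.size() * sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  cudaMemcpy(result.lane_end.data(), d_lane_end,
             result.lane_end.size() * sizeof(int64_t), cudaMemcpyDeviceToHost);
  cudaFree(d_offsets);
  cudaFree(d_lane_end);
  result.offsets = offsets;
  return result;
}
```

For example, calling `run_sketch({0, 3, 5, 20}, 10)` gives the last warp an out-of-range end offset, while the other two warps see valid pairs and skip the write-back entirely. The sketch does not attempt the last-element ordering or the backward monotonicity check from the PR, and two adjacent out-of-range pairs would still race on their shared offset, as they would in any per-pair scheme.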
@gchalump has exported this pull request. If you are a Meta employee, you can view the originating Diff in D100898565.