diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu
index 6dbfc24632..6bc8fa5409 100644
--- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu
+++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu
@@ -21,7 +21,8 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v1(
     FixedDivisor fd,
     TORCH_DSA_KERNEL_ARGS) {
   int32_t T = rows_per_table.size(0);
-  auto b_t = blockIdx.x * blockDim.y + threadIdx.y;
+  const auto warp_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  auto b_t = warp_idx;
   int32_t b;
   int32_t t;
   int32_t B = 0;
@@ -55,46 +56,116 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v1(
   auto indices_end = offsets[b_t + 1];
   const index_t num_indices = indices.size(0);
 
-  if (bounds_check_mode == BoundsCheckMode::FATAL) {
-    CUDA_KERNEL_ASSERT(
-        indices_start >= 0 && "indices_start must be non-negative");
-    CUDA_KERNEL_ASSERT(
-        indices_start <= indices_end &&
-        "indices_start must not exceed indices_end");
-    CUDA_KERNEL_ASSERT(
-        indices_end <= num_indices &&
-        "indices_end must not exceed num_indices");
-  } else if (bounds_check_mode == BoundsCheckMode::WARNING) {
-    if (indices_start < 0 || indices_start > indices_end ||
-        indices_end > num_indices) {
-      if (gpuAtomicIncrement(&warning[0]) == 0) {
-        printf(
-            "EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for "
-            "batch: %d, table: %d, indices_start: %lld, indices_end: %lld,"
-            " num_indices: %lld. Setting indices_start and indices_end within "
-            "the range.\n",
-            vbe ? "true" : "false",
-            b,
-            t,
-            static_cast<int64_t>(indices_start),
-            static_cast<int64_t>(indices_end),
-            static_cast<int64_t>(num_indices));
+  // Only lane 0 checks and corrects offsets. All 32 lanes in the warp share
+  // the same (b, t) pair, so having every lane write the same correction to
+  // offsets[b_t] / offsets[b_t + 1] is redundant and constitutes a data race
+  // in the CUDA memory model (even though the values are identical).
+  if (threadIdx.x == 0) {
+    // Correct the last element first so that the per-b_t correction below
+    // (when b_t + 1 == total_B) reads the already-fixed offsets[total_B].
+    // Gate on warp_idx == 0 so only one warp in the grid runs this check,
+    // avoiding a multi-warp race on offsets[total_B] and warning inflation.
+    if (warp_idx == 0) {
+      if (bounds_check_mode == BoundsCheckMode::FATAL) {
+        CUDA_KERNEL_ASSERT(
+            num_indices == offsets[total_B] &&
+            "num_indices must match the last element in offsets");
+      } else if (bounds_check_mode == BoundsCheckMode::WARNING) {
+        if (num_indices != offsets[total_B]) {
+          if (gpuAtomicIncrement(&warning[0]) == 0) {
+            printf(
+                "EmbeddingBoundsCheck (VBE %s): the last element in offsets is incorrect for "
+                "total batch size %s: %d, total table num T: %d, "
+                " last element in offsets: %lld, indices size: %lld. "
+                " Setting the last element in offsets to be indices size.\n",
+                vbe ? "true" : "false",
+                vbe ? "total_B" : "B",
+                vbe ? total_B : B,
+                T,
+                static_cast<int64_t>(offsets[total_B]),
+                static_cast<int64_t>(num_indices));
+          }
+          offsets[total_B] = num_indices;
+        }
+      } else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
+        if (num_indices != offsets[total_B]) {
+          offsets[total_B] = num_indices;
+        }
+      }
+    }
+
+    // Re-read indices_end in case offsets[b_t + 1] was the last element and
+    // was just corrected above.
+    indices_end = offsets[b_t + 1];
+
+    // Per-b_t offset correction.
+    if (bounds_check_mode == BoundsCheckMode::FATAL) {
+      CUDA_KERNEL_ASSERT(
+          indices_start >= 0 && "indices_start must be non-negative");
+      CUDA_KERNEL_ASSERT(
+          indices_start <= indices_end &&
+          "indices_start must not exceed indices_end");
+      CUDA_KERNEL_ASSERT(
+          indices_end <= num_indices &&
+          "indices_end must not exceed num_indices");
+    } else if (bounds_check_mode == BoundsCheckMode::WARNING) {
+      if (indices_start < 0 || indices_start > indices_end ||
+          indices_end > num_indices) {
+        if (gpuAtomicIncrement(&warning[0]) == 0) {
+          printf(
+              "EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for "
+              "batch: %d, table: %d, indices_start: %lld, indices_end: %lld,"
+              " num_indices: %lld. Setting indices_start and indices_end within "
+              "the range.\n",
+              vbe ? "true" : "false",
+              b,
+              t,
+              static_cast<int64_t>(indices_start),
+              static_cast<int64_t>(indices_end),
+              static_cast<int64_t>(num_indices));
+        }
+        adjust_offset_kernel(
+            indices_start,
+            indices_end,
+            num_indices,
+            &offsets[b_t],
+            &offsets[b_t + 1]);
+      }
+    } else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
+      if (indices_start < 0 || indices_start > indices_end ||
+          indices_end > num_indices) {
+        adjust_offset_kernel(
+            indices_start,
+            indices_end,
+            num_indices,
+            &offsets[b_t],
+            &offsets[b_t + 1]);
       }
-      adjust_offset_kernel(
-          indices_start,
-          indices_end,
-          num_indices,
-          &offsets[b_t],
-          &offsets[b_t + 1]);
     }
-  } else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
-    adjust_offset_kernel(
-        indices_start,
-        indices_end,
-        num_indices,
-        &offsets[b_t],
-        &offsets[b_t + 1]);
   }
+  // Broadcast corrected indices_start/indices_end from lane 0 to all lanes.
+  indices_start = shfl_sync(indices_start, 0);
+  indices_end = shfl_sync(indices_end, 0);
+
+  // Assert post-broadcast invariants on every lane so that no thread enters
+  // the index loop with out-of-range offsets (guards against missed
+  // corrections and silent OOB memory accesses).
+  CUDA_KERNEL_ASSERT(
+      indices_start >= 0 &&
+      "indices_start must be non-negative after correction");
+  CUDA_KERNEL_ASSERT(
+      indices_start <= indices_end &&
+      "indices_start must not exceed indices_end after correction");
+  CUDA_KERNEL_ASSERT(
+      indices_end <= num_indices &&
+      "indices_end must not exceed num_indices after correction");
+  // Best-effort backward monotonicity check: indices_start should not be
+  // less than the previous segment's start. This read of offsets[b_t - 1]
+  // may race with the previous warp's correction, but catches the common
+  // case of non-monotonic offsets that per-pair checks miss.
+  CUDA_KERNEL_ASSERT(
+      (b_t == 0 || indices_start >= offsets[b_t - 1]) &&
+      "offsets are not monotonically non-decreasing after correction");
 
   const auto L = indices_end - indices_start;
   for (index_t i = static_cast<index_t>(threadIdx.x); i < L;
@@ -133,33 +204,6 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v1(
       }
     }
   }
-
-  if (bounds_check_mode == BoundsCheckMode::FATAL) {
-    CUDA_KERNEL_ASSERT(
-        num_indices == offsets[total_B] &&
-        "num_indices must match the last element in offsets");
-  } else if (bounds_check_mode == BoundsCheckMode::WARNING) {
-    if (num_indices != offsets[total_B]) {
-      if (gpuAtomicIncrement(&warning[0]) == 0) {
-        printf(
-            "EmbeddingBoundsCheck (VBE %s): the last element in offsets is incorrect for "
-            "total batch size %s: %d, total table num T: %d, "
-            " last element in offsets: %lld, indices size: %lld. "
-            " Setting the last element in offsets to be indices size.\n",
-            vbe ? "true" : "false",
-            vbe ? "total_B" : "B",
-            vbe ? total_B : B,
-            T,
-            static_cast<int64_t>(offsets[total_B]),
-            static_cast<int64_t>(num_indices));
-      }
-      offsets[total_B] = num_indices;
-    }
-  } else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
-    if (num_indices != offsets[total_B]) {
-      offsets[total_B] = num_indices;
-    }
-  }
 }
 
 void _bounds_check_indices_cuda_v1(
diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu
index ed53391860..654bdbf622 100644
--- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu
+++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu
@@ -85,46 +85,78 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
   auto indices_start = offsets[b_t];
   auto indices_end = offsets[b_t + 1];
 
-  if (bounds_check_mode == BoundsCheckMode::FATAL) {
-    CUDA_KERNEL_ASSERT(
-        indices_start >= 0 && "indices_start must be non-negative");
-    CUDA_KERNEL_ASSERT(
-        indices_start <= indices_end &&
-        "indices_start must not exceed indices_end");
-    CUDA_KERNEL_ASSERT(
-        indices_end <= num_indices &&
-        "indices_end must not exceed num_indices");
-  } else if (bounds_check_mode == BoundsCheckMode::WARNING) {
-    if (indices_start < 0 || indices_start > indices_end ||
-        indices_end > num_indices) {
-      if (threadIdx.x == 0 && gpuAtomicIncrement(&warning[0]) == 0) {
-        printf(
-            "EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for "
-            "batch: %d, table: %d, indices_start: %lld, indices_end: %lld,"
-            " num_indices: %lld. Setting indices_start and indices_end within "
-            "the range.\n",
-            vbe ? "true" : "false",
-            b,
-            t,
-            static_cast<int64_t>(indices_start),
-            static_cast<int64_t>(indices_end),
-            static_cast<int64_t>(num_indices));
+  // Only lane 0 checks and corrects offsets. All 32 lanes in the warp share
+  // the same (b, t) pair, so having every lane write the same correction to
+  // offsets[b_t] / offsets[b_t + 1] is redundant and constitutes a data race
+  // in the CUDA memory model (even though the values are identical).
+  if (threadIdx.x == 0) {
+    if (bounds_check_mode == BoundsCheckMode::FATAL) {
+      CUDA_KERNEL_ASSERT(
+          indices_start >= 0 && "indices_start must be non-negative");
+      CUDA_KERNEL_ASSERT(
+          indices_start <= indices_end &&
+          "indices_start must not exceed indices_end");
+      CUDA_KERNEL_ASSERT(
+          indices_end <= num_indices &&
+          "indices_end must not exceed num_indices");
+    } else if (bounds_check_mode == BoundsCheckMode::WARNING) {
+      if (indices_start < 0 || indices_start > indices_end ||
+          indices_end > num_indices) {
+        if (gpuAtomicIncrement(&warning[0]) == 0) {
+          printf(
+              "EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for "
+              "batch: %d, table: %d, indices_start: %lld, indices_end: %lld,"
+              " num_indices: %lld. Setting indices_start and indices_end within "
+              "the range.\n",
"true" : "false", + b, + t, + static_cast(indices_start), + static_cast(indices_end), + static_cast(num_indices)); + } + adjust_offset_kernel( + indices_start, + indices_end, + num_indices, + &offsets[b_t], + &offsets[b_t + 1]); + } + } else if (bounds_check_mode == BoundsCheckMode::IGNORE) { + if (indices_start < 0 || indices_start > indices_end || + indices_end > num_indices) { + adjust_offset_kernel( + indices_start, + indices_end, + num_indices, + &offsets[b_t], + &offsets[b_t + 1]); } - adjust_offset_kernel( - indices_start, - indices_end, - num_indices, - &offsets[b_t], - &offsets[b_t + 1]); } - } else if (bounds_check_mode == BoundsCheckMode::IGNORE) { - adjust_offset_kernel( - indices_start, - indices_end, - num_indices, - &offsets[b_t], - &offsets[b_t + 1]); } + // Broadcast corrected indices_start/indices_end from lane 0 to all lanes. + indices_start = shfl_sync(indices_start, 0); + indices_end = shfl_sync(indices_end, 0); + + // Assert post-broadcast invariants on every lane so that no thread enters + // the index loop with out-of-range offsets (guards against missed + // corrections and silent OOB memory accesses). + CUDA_KERNEL_ASSERT( + indices_start >= 0 && + "indices_start must be non-negative after correction"); + CUDA_KERNEL_ASSERT( + indices_start <= indices_end && + "indices_start must not exceed indices_end after correction"); + CUDA_KERNEL_ASSERT( + indices_end <= num_indices && + "indices_end must not exceed num_indices after correction"); + // Best-effort backward monotonicity check: indices_start should not be + // less than the previous segment's start. This read of offsets[b_t - 1] + // may race with the previous warp's correction, but catches the common + // case of non-monotonic offsets that per-pair checks miss. + CUDA_KERNEL_ASSERT( + (b_t == 0 || indices_start >= offsets[b_t - 1]) && + "offsets are not monotonically non-decreasing after correction"); const auto L = indices_end - indices_start; for (index_t i = static_cast(threadIdx.x); i < L;