diff --git a/fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu
index 6df66042aa..f3af1a2497 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu
@@ -35,7 +35,8 @@ __global__ inline void _float_to_paddedFP8rowwise_cuda_kernel(
   const int output_columns =
       ncols_aligned + (ncols + row_dim - 1) / row_dim * 8;
 
-  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t row =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   // for 1D case, unsqueezing needed
   if (nrows == 1) {
     const auto threads = (ncols + row_dim - 1) / row_dim;
@@ -96,10 +97,11 @@ __global__ inline void _get_padding_value_kernel(
     const int row_dim,
     const std::uint8_t* const __restrict__ input,
     int* const __restrict__ offsets) {
-  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t row =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   const int row_ext = row_dim + 8;
   const auto threads = (ncols + row_ext - 1) / row_ext;
-  if (row > threads)
+  if (row >= threads)
     return;
   const std::uint8_t* const input_row = input + row * row_ext;
   int pad = *reinterpret_cast<const int*>(input_row + row_dim + 4);
@@ -114,7 +116,8 @@ __global__ inline void _single_thread_sum_padding_kernel(
     int* __restrict__ total_pad) {
   // this is to count the sum of padding in the first row of 2D input
   // in one kernel launch to remove multiple H to D Syncs.
-  const auto tid = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const auto tid =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (tid != 0) {
     return;
   }
@@ -125,7 +128,12 @@
   while (offset + 4 <= ncols) {
     pad = *reinterpret_cast<const int*>(input + offset);
     if (pad < 0) {
-      offset += -pad * row_ext;
+      // Widen before negating so pad == INT_MIN doesn't overflow.
+      const int64_t step = -static_cast<int64_t>(pad) * row_ext;
+      if (step > ncols - offset) {
+        break;
+      }
+      offset += static_cast<int>(step);
     } else {
       total_pad[0] += pad;
       offset += row_ext;
@@ -153,7 +161,7 @@ __global__ inline void _PaddedFP8rowwise_to_float_1d_cuda_kernel(
       reinterpret_cast<const float*>(input_row + row_dim);
   const auto scale = input_row_scale[0];
   int pad = *reinterpret_cast<const int*>(&input_row_scale[1]);
-  pad = (pad > 0) ? pad : 0;
+  pad = ::max(0, ::min(pad, row_dim));
   const auto pad_offset = offsets[row];
   output_t* output_row = output + row * row_dim - pad_offset;
   for (auto col = threadIdx.x; col < row_dim - pad; col += blockDim.x) {
@@ -176,7 +184,8 @@ __global__ inline void _PaddedFP8rowwise_to_float_2d_cuda_kernel(
   const int ebit = forward ? 4 : 5;
   const int bias = forward ? 15 : 31;
 
-  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t row =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (row >= nrows) {
     return;
   }
@@ -189,7 +198,7 @@
   int pad = *reinterpret_cast<const int*>(&input_row_scale[1]);
   // if pad is negative it's used to indidate indices of the next padded
   // bucket
-  pad = (pad > 0) ? pad : 0;
+  pad = ::max(0, ::min(pad, row_dim));
   for (int bi = 0; bi < row_dim - pad; ++bi) {
     const auto output_ =
         hfp8_to_float(input_row[col + bi], ebit, bias) / input_row_scale[0];
@@ -208,6 +217,11 @@ Tensor _float_to_paddedFP8rowwise_gpu_t(
     const int64_t row_dim) {
   TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(input);
   CUDA_DEVICE_GUARD(input);
+  TORCH_CHECK(
+      row_dim > 0 && row_dim % 4 == 0,
+      "row_dim (",
+      row_dim,
+      ") must be a positive multiple of 4 to keep the appended scale/pad words 4-byte aligned");
 
   const auto input_sizes = input.sizes();
   const auto last_dim = input_sizes.size() - 1;
@@ -264,13 +278,26 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
   TENSOR_ON_CUDA_GPU(input);
   TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
   CUDA_DEVICE_GUARD(input);
+  TORCH_CHECK(
+      row_dim > 0 && row_dim % 4 == 0,
+      "row_dim (",
+      row_dim,
+      ") must be a positive multiple of 4 to keep the appended scale/pad words 4-byte aligned");
   const auto input_sizes = input.sizes();
   const auto last_dim = input_sizes.size() - 1;
   const int nrows = c10::size_to_dim_(last_dim, input_sizes);
   const int ncols = input_sizes[last_dim];
   const int row_ext = row_dim + 8;
-  int output_columns = ncols - (ncols + row_ext - 1) / row_ext * 8;
+  TORCH_CHECK(
+      ncols % row_ext == 0,
+      "ncols (",
+      ncols,
+      ") must be a multiple of row_ext (",
+      row_ext,
+      ")");
+  const int num_buckets = ncols / row_ext;
+  int output_columns = ncols - num_buckets * 8;
 
   // Global memory instructions support reading or writing words of size equal
   // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
   // data residing in global memory compiles to a single global memory
@@ -281,12 +308,9 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
 
   constexpr int threads_per_block = 256;
   const auto num_blocks = cuda_calc_xblock_count(
-      (nrows == 1) ? (ncols + row_ext - 1) / row_ext + 1 : nrows,
-      threads_per_block);
+      (nrows == 1) ? num_buckets : nrows, threads_per_block);
   Tensor offsets = at::empty(
-      (nrows == 1) ? num_blocks * threads_per_block + 1
-                   : 0, // 4 = sizeof(float)
-      input.options().dtype(at::kInt));
+      (nrows == 1) ? num_buckets : 0, input.options().dtype(at::kInt));
   int total_pad = 0;
   if (nrows == 1) {
     FBGEMM_LAUNCH_KERNEL(
@@ -316,7 +340,7 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
           total_pad_tensor.data_ptr<int>());
      total_pad = total_pad_tensor[0].item<int>();
    } else {
-      total_pad = offsets[((ncols + row_ext - 1) / row_ext)].item<int>();
+      total_pad = offsets[num_buckets].item<int>();
    }
    output_columns -= total_pad;
  } else {
@@ -340,13 +364,6 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
 
   if (nrows == 1) {
     // Use one thread block to work on 1 row for nrows == 1
-    TORCH_CHECK(
-        ncols % row_ext == 0,
-        "ncols (",
-        ncols,
-        ") must be multiple of ",
-        row_ext)
-    const int num_rows = ncols / row_ext;
     const int ebit = forward ? 4 : 5;
     const int bias = forward ? 15 : 31;
     constexpr int kMaxThreads = 1024;
@@ -356,7 +373,7 @@
         output.scalar_type(), "PaddedFP8rowwise_to_float_1d_cuda_kernel", [&] {
           FBGEMM_LAUNCH_KERNEL(
               (_PaddedFP8rowwise_to_float_1d_cuda_kernel),
-              num_rows,
+              num_buckets,
               threads_per_block,
               0,
               at::cuda::getCurrentCUDAStream(),
diff --git a/fbgemm_gpu/test/quantize/fp8_rowwise_test.py b/fbgemm_gpu/test/quantize/fp8_rowwise_test.py
index ebf40e379c..d7dd90f2ce 100644
--- a/fbgemm_gpu/test/quantize/fp8_rowwise_test.py
+++ b/fbgemm_gpu/test/quantize/fp8_rowwise_test.py
@@ -227,6 +227,47 @@ def test_quantize_and_dequantize_op_padded_fp8_rowwise(
 
         torch.testing.assert_allclose(dqcat, qref, rtol=0.1, atol=0.05)
 
+    @unittest.skipIf(*gpu_unavailable)
+    def test_padded_fp8_rowwise_input_validation(self) -> None:
+        fp32 = SparseType.FP32.as_int()
+        x = torch.rand(2, 32, device="cuda")
+        # row_dim must be a positive multiple of 4.
+        for bad in (0, -4, 3, 6):
+            with self.assertRaises(RuntimeError):
+                torch.ops.fbgemm.FloatToPaddedFP8RowwiseQuantized(
+                    x, forward=True, row_dim=bad
+                )
+            with self.assertRaises(RuntimeError):
+                torch.ops.fbgemm.PaddedFP8RowwiseQuantizedToFloat(
+                    torch.zeros(2, 24, device="cuda", dtype=torch.uint8),
+                    forward=True,
+                    row_dim=bad,
+                    output_dtype=fp32,
+                )
+        # Dequant ncols must be a multiple of row_dim + 8.
+        with self.assertRaises(RuntimeError):
+            torch.ops.fbgemm.PaddedFP8RowwiseQuantizedToFloat(
+                torch.zeros(2, 25, device="cuda", dtype=torch.uint8),
+                forward=True,
+                row_dim=16,
+                output_dtype=fp32,
+            )
+
+    @unittest.skipIf(*gpu_unavailable)
+    def test_padded_fp8_rowwise_1d_roundtrip(self) -> None:
+        # Exercises the nrows == 1 path where _get_padding_value_kernel's
+        # boundary thread used to read past the end of the quantized input.
+        fp32 = SparseType.FP32.as_int()
+        for row_dim, num_buckets in [(4, 1), (16, 7), (256, 33)]:
+            x = torch.rand(row_dim * num_buckets, device="cuda")
+            q = torch.ops.fbgemm.FloatToPaddedFP8RowwiseQuantized(
+                x, forward=True, row_dim=row_dim
+            )
+            dq = torch.ops.fbgemm.PaddedFP8RowwiseQuantizedToFloat(
+                q, forward=True, row_dim=row_dim, output_dtype=fp32
+            )
+            torch.testing.assert_close(dq.cpu(), x.cpu(), rtol=0.1, atol=0.05)
+
 
 if __name__ == "__main__":
     unittest.main()
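
----
Notes (illustrative sketches, not part of the patch):

The two new TORCH_CHECKs encode the padded FP8 rowwise layout: every
row_dim-byte bucket of FP8 data is followed by a 4-byte float scale and a
4-byte int pad word, so one encoded bucket spans row_ext = row_dim + 8 bytes.
A minimal Python model of the host-side shape arithmetic that the checks and
the num_buckets rewrite rely on (the helper name is illustrative, not an
fbgemm API):

    def padded_fp8_shapes(ncols: int, row_dim: int) -> tuple[int, int]:
        # row_dim must be a positive multiple of 4 so the trailing
        # scale/pad words stay 4-byte aligned (first TORCH_CHECK).
        assert row_dim > 0 and row_dim % 4 == 0
        row_ext = row_dim + 8
        # The encoded row must consist of whole buckets (second TORCH_CHECK).
        assert ncols % row_ext == 0
        num_buckets = ncols // row_ext
        # Stripping 8 metadata bytes per bucket gives the dequantized width
        # before the recorded padding is subtracted from it.
        return num_buckets, ncols - num_buckets * 8

    # row_dim=16 -> row_ext=24: a 72-byte encoded row holds 3 buckets and
    # dequantizes to at most 48 elements.
    assert padded_fp8_shapes(72, 16) == (3, 48)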
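The widening fix in _single_thread_sum_padding_kernel keeps the same walk
over the per-bucket pad words; only the step computation changes. Below is a
bucket-indexed Python model of that loop, assuming, per the kernel's own
comment, that a negative pad word marks buckets of a padded region that
should be skipped rather than summed:

    def sum_padding(pads: list[int]) -> int:
        # pads[i] is the pad word of bucket i: non-negative values count
        # padded elements, negative values jump over -pads[i] buckets
        # (the int64_t widening in the kernel guards exactly this jump).
        total, i = 0, 0
        while i < len(pads):
            if pads[i] < 0:
                i += -pads[i]
            else:
                total += pads[i]
                i += 1
        return total

    # Buckets 0 and 1 contribute 0 + 3; the word -2 skips buckets 2 and 3;
    # bucket 4 contributes 5.
    assert sum_padding([0, 3, -2, 7, 5]) == 8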