diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
index 8741fd79bd..9396c2effa 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
@@ -14,8 +14,46 @@ namespace fbgemm_gpu {

 namespace {

+// Helper to convert float to the scale_bias type (round-trip for quantization)
+template <typename scale_bias_t>
+__device__ __forceinline__ float float_to_scale_bias(float val);
+
+template <>
+__device__ __forceinline__ float float_to_scale_bias<__half>(float val) {
+  return __half2float(__float2half(val));
+}
+
+template <>
+__device__ __forceinline__ float float_to_scale_bias<__nv_bfloat16>(float val) {
+  __nv_bfloat16 bf = __float2bfloat16(val);
+#ifdef USE_ROCM
+  return float(bf);
+#else
+  return __bfloat162float(bf);
+#endif
+}
+
+// Helper to convert the scale_bias type to float
+template <typename scale_bias_t>
+__device__ __forceinline__ float scale_bias_to_float(scale_bias_t val);
+
+template <>
+__device__ __forceinline__ float scale_bias_to_float<__half>(__half val) {
+  return __half2float(val);
+}
+
+template <>
+__device__ __forceinline__ float scale_bias_to_float<__nv_bfloat16>(
+    __nv_bfloat16 val) {
+#ifdef USE_ROCM
+  return float(val);
+#else
+  return __bfloat162float(val);
+#endif
+}
+
 // FP32/FP16 -> Fused 4/2-bit rowwise kernel
-template <typename input_t>
+template <typename input_t, typename scale_bias_t>
 __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
     const int bit_rate,
     const input_t* __restrict__ input,
@@ -24,23 +62,24 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
     std::uint8_t* __restrict__ output) {
   const int num_elem_per_byte = 8 / bit_rate;
   const int output_columns =
-      (ncols + num_elem_per_byte - 1) / num_elem_per_byte + 2 * sizeof(__half);
+      (ncols + num_elem_per_byte - 1) / num_elem_per_byte +
+      2 * sizeof(scale_bias_t);
   int row = (int)blockIdx.x * blockDim.x + threadIdx.x;
   const auto row_incre = blockDim.x * gridDim.x;
   for (/*row*/; row < nrows; row += row_incre) {
     const input_t* input_row = input + row * ncols;
     std::uint8_t* output_row = output + row * output_columns;
-    __half* output_row_scale_bias = reinterpret_cast<__half*>(
+    scale_bias_t* output_row_scale_bias = reinterpret_cast<scale_bias_t*>(
         output_row + (ncols + num_elem_per_byte - 1) / num_elem_per_byte);

     float minimum_element = fbgemm_gpu::min(input_row, input_row + ncols);
     float maximum_element = fbgemm_gpu::max(input_row, input_row + ncols);
-    minimum_element = __half2float(__float2half(minimum_element));
+    minimum_element = float_to_scale_bias<scale_bias_t>(minimum_element);
     const float range = maximum_element - minimum_element;

-    float scale = __half2float(
-        __float2half(range == 0 ? 1.0f : range / ((1 << bit_rate) - 1)));
+    float scale = float_to_scale_bias<scale_bias_t>(
+        range == 0 ? 1.0f : range / ((1 << bit_rate) - 1));
     if (scale == 0) {
       // Corner case handling when maximum_element == minimum_element
       // Any scale would work because X - minimum_element will be 0 for all X
@@ -52,8 +91,13 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
       inverse_scale = 1.0f;
     }

-    output_row_scale_bias[0] = __float2half(scale);
-    output_row_scale_bias[1] = __float2half(minimum_element);
+    if constexpr (std::is_same_v<scale_bias_t, __half>) {
+      output_row_scale_bias[0] = __float2half(scale);
+      output_row_scale_bias[1] = __float2half(minimum_element);
+    } else {
+      output_row_scale_bias[0] = __float2bfloat16(scale);
+      output_row_scale_bias[1] = __float2bfloat16(minimum_element);
+    }

     for (std::size_t col = 0; col < ncols; ++col) {
       const float X = input_row[col];
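A note on the storage-type choice above: the kernel rounds both the bias (row minimum) and the scale through the scale/bias type before computing codes, so the representable range of those two values is set by that type. bf16 keeps fp32's exponent range at the cost of a coarser mantissa (7 explicit bits vs fp16's 10). A standalone PyTorch sketch of the difference, illustrative only and not part of the patch:

import torch

# fp16 saturates above 65504; a row span that large would store an inf scale.
print(torch.tensor(1.0e20).to(torch.float16))   # inf
print(torch.tensor(1.0e20).to(torch.bfloat16))  # finite, within ~0.4% of 1e20

# fp16 also underflows below its ~6e-8 subnormal floor; bf16 follows fp32
# exponents down to ~1e-38 normals.
print(torch.tensor(1.0e-9).to(torch.float16))   # 0.0
print(torch.tensor(1.0e-9).to(torch.bfloat16))  # finite, small relative error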
@@ -74,7 +118,7 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
 }

 // Fused 4/2-bit rowwise -> FP32/FP16 kernel
-template <typename output_t, bool scale_bias_last>
+template <typename output_t, typename scale_bias_t, bool scale_bias_last>
 __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
     const int bit_rate,
     const std::uint8_t* input,
@@ -82,22 +126,24 @@ __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
     const int ncols,
     output_t* const output) {
   const int num_elem_per_byte = 8 / bit_rate;
-  const int output_columns = (ncols - 2 * sizeof(__half)) * num_elem_per_byte;
+  const int output_columns =
+      (ncols - 2 * sizeof(scale_bias_t)) * num_elem_per_byte;
   int row = (int)blockIdx.y * blockDim.y + threadIdx.y;
   const int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
   const auto row_incre = blockDim.y * gridDim.y;
   for (/*row*/; row < nrows; row += row_incre) {
     if (row < nrows && col < output_columns) {
       const std::uint8_t* input_row = input + row * ncols;
-      const __half* input_row_scale_bias = reinterpret_cast<const __half*>(
-          input_row +
-          (!scale_bias_last
-               ? 0
-               : (output_columns + num_elem_per_byte - 1) / num_elem_per_byte));
-      float scale = __half2float(input_row_scale_bias[0]);
-      float bias = __half2float(input_row_scale_bias[1]);
+      const scale_bias_t* input_row_scale_bias =
+          reinterpret_cast<const scale_bias_t*>(
+              input_row +
+              (!scale_bias_last ? 0
+                                : (output_columns + num_elem_per_byte - 1) /
+                                      num_elem_per_byte));
+      float scale = scale_bias_to_float(input_row_scale_bias[0]);
+      float bias = scale_bias_to_float(input_row_scale_bias[1]);
       if constexpr (!scale_bias_last) {
-        input_row += 2 * sizeof(__half);
+        input_row += 2 * sizeof(scale_bias_t);
       }
       output_t* output_row = output + row * output_columns;
@@ -216,7 +262,71 @@ DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_gpu(
   return output;
 }

-template <typename output_t>
+template <typename scale_bias_t>
+Tensor _float_to_fusednbitrowwise_bf16sb_gpu_t(
+    const Tensor& input,
+    const int64_t bit_rate) {
+  TENSOR_ON_CUDA_GPU(input);
+  TENSOR_NDIM_EQUALS(input, 2);
+  CUDA_DEVICE_GUARD(input);
+
+  const int nrows = input.size(0);
+  const int ncols = input.size(1);
+  const int num_elem_per_byte = 8 / bit_rate;
+  TORCH_CHECK(
+      ncols % (2 * num_elem_per_byte) == 0,
+      "ncols needs to be a multiple of 2 Bytes (bfloat16 type size) to make the address aligned");
+  const int output_columns =
+      (ncols + num_elem_per_byte - 1) / num_elem_per_byte +
+      2 * sizeof(scale_bias_t);
+
+  auto output =
+      at::empty({nrows, output_columns}, input.options().dtype(at::kByte));
+
+  if (nrows == 0 || ncols == 0) {
+    return output;
+  }
+
+  constexpr auto threads_per_block = 256;
+  const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block);
+
+  FBGEMM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(),
+      "_float_to_fusednbitrowwise_bf16sb_cuda_kernel",
+      [&] {
+        FBGEMM_LAUNCH_KERNEL(
+            (_float_to_fusednbitrowwise_cuda_kernel<scalar_t, scale_bias_t>),
+            num_blocks,
+            threads_per_block,
+            0,
+            at::cuda::getCurrentCUDAStream(),
+            bit_rate,
+            input.data_ptr<scalar_t>(),
+            nrows,
+            ncols,
+            output.data_ptr<std::uint8_t>());
+      });
+
+  return output;
+}
+
+/// @ingroup quantize-ops-cuda
+/// Converts a tensor of `float` or `at::Half` values into a tensor of fused
+/// N-bit rowwise values with bf16 scale/bias.
+///
+/// @param input A tensor of `float` or `at::Half` values
+/// @param bit_rate Quantization bit rate (2 or 4 bits per element)
+///
+/// @return A new tensor with values from the input tensor converted to
+/// fused N-bit rowwise with bf16 scale/bias.
+DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_bf16sb_gpu(
+    const Tensor& input,
+    const int64_t bit_rate) {
+  return _float_to_fusednbitrowwise_bf16sb_gpu_t<__nv_bfloat16>(
+      input, bit_rate);
+}
+
+template <typename output_t, typename scale_bias_t = __half>
 Tensor _fusednbitrowwise_to_float_gpu_t(
     const Tensor& input,
     const int64_t bit_rate,
@@ -228,7 +338,8 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
   const int nrows = input.size(0);
   const int ncols = input.size(1);
   const int num_elem_per_byte = 8 / bit_rate;
-  const int output_columns = (ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;
+  const int output_columns =
+      (ncols - 2 * sizeof(scale_bias_t)) * num_elem_per_byte;

   // Global memory instructions support reading or writing words of size equal
   // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
@@ -267,17 +378,20 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
   const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
   const dim3 gridDim(gridDim_x, gridDim_y);

-#define DEQUANT_LAUNCH_NBIT(scale_bias_last)                                \
-  FBGEMM_LAUNCH_KERNEL(                                                     \
-      (_fusednbitrowwise_to_float_cuda_kernel<scalar_t, scale_bias_last>), \
-      gridDim,                                                              \
-      blockDim,                                                             \
-      0,                                                                    \
-      at::cuda::getCurrentCUDAStream(),                                     \
-      bit_rate,                                                             \
-      input.data_ptr<std::uint8_t>(),                                       \
-      nrows,                                                                \
-      ncols,                                                                \
+#define DEQUANT_LAUNCH_NBIT(scale_bias_last)   \
+  FBGEMM_LAUNCH_KERNEL(                        \
+      (_fusednbitrowwise_to_float_cuda_kernel< \
+          scalar_t,                            \
+          scale_bias_t,                        \
+          scale_bias_last>),                   \
+      gridDim,                                 \
+      blockDim,                                \
+      0,                                       \
+      at::cuda::getCurrentCUDAStream(),        \
+      bit_rate,                                \
+      input.data_ptr<std::uint8_t>(),          \
+      nrows,                                   \
+      ncols,                                   \
       output.mutable_data_ptr<scalar_t>())

 FBGEMM_DISPATCH_FLOATING_TYPES(
@@ -365,6 +479,48 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_single_or_half_precision_gpu(
   return output;
 }

+/// @ingroup quantize-ops-cuda
+/// Converts a tensor of fused N-bit rowwise values with bf16 scale/bias into
+/// a tensor of `float`, `at::Half`, or `at::BFloat16` values.
+///
+/// @param input A tensor of fused N-bit rowwise values with bf16 scale/bias
+/// @param bit_rate Quantization bit rate (2 or 4 bits per element)
+/// @param output_dtype The target floating point type, specified as integer
+///                     representation of the `SparseType` enum
+///
+/// @return A new tensor with values from the input tensor converted to
+/// `float`, `at::Half`, or `at::BFloat16`, depending on `output_dtype`.
+///
+/// @throw c10::Error if `output_dtype` is not one of `SparseType::FP32`,
+/// `SparseType::FP16`, or `SparseType::BF16`.
+DLL_PUBLIC at::Tensor _fusednbitrowwise_bf16sb_to_single_or_half_precision_gpu(
+    const at::Tensor& input,
+    const int64_t bit_rate,
+    const int64_t output_dtype,
+    const bool scale_bias_last) {
+  Tensor output;
+
+  SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
+  switch (output_sparse_dtype) {
+    case SparseType::FP32:
+      output = _fusednbitrowwise_to_float_gpu_t<float, __nv_bfloat16>(
+          input, bit_rate, scale_bias_last);
+      break;
+    case SparseType::FP16:
+      output = _fusednbitrowwise_to_float_gpu_t<at::Half, __nv_bfloat16>(
+          input, bit_rate, scale_bias_last);
+      break;
+    case SparseType::BF16:
+      output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16, __nv_bfloat16>(
+          input, bit_rate, scale_bias_last);
+      break;
+    default:
+      TORCH_CHECK(false);
+  }
+
+  return output;
+}
+
 } // namespace fbgemm_gpu

 FBGEMM_OP_DISPATCH(
@@ -379,6 +535,10 @@ FBGEMM_OP_DISPATCH(
     CUDA,
     "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf",
     fbgemm_gpu::_single_or_half_precision_to_fusednbitrowwise_gpu);
+FBGEMM_OP_DISPATCH(
+    CUDA,
+    "FloatOrHalfToFusedNBitRowwiseQuantizedSBBFloat16",
+    fbgemm_gpu::_single_or_half_precision_to_fusednbitrowwise_bf16sb_gpu);
 FBGEMM_OP_DISPATCH(
     CUDA,
     "FusedNBitRowwiseQuantizedSBHalfToFloat",
@@ -391,3 +551,7 @@ FBGEMM_OP_DISPATCH(
     CUDA,
     "FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf",
     fbgemm_gpu::_fusednbitrowwise_to_single_or_half_precision_gpu);
+FBGEMM_OP_DISPATCH(
+    CUDA,
+    "FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf",
+    fbgemm_gpu::_fusednbitrowwise_bf16sb_to_single_or_half_precision_gpu);
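Taken together, the dispatch registrations above expose the new CUDA path through `torch.ops.fbgemm`. A minimal round-trip sketch, assuming a CUDA build of fbgemm_gpu is installed (the op names and the alignment requirement come from this file; the concrete sizes are illustrative):

import torch
import fbgemm_gpu  # noqa: F401  (importing registers the torch.ops.fbgemm ops)

bit_rate = 4
# ncols must be a multiple of 2 * (8 // bit_rate) per the TORCH_CHECK above.
x = torch.rand(16, 64, device="cuda")

# Each row packs to ceil(64 / 2) = 32 payload bytes plus 4 bytes of bf16 scale/bias.
q = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBBFloat16(x, bit_rate)
assert q.shape == (16, 36) and q.dtype == torch.uint8

# output_dtype=0 selects SparseType::FP32 (the schema default); scale_bias_last
# defaults to True, matching the layout the quantize op emits.
y = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf(q, bit_rate, 0)
assert y.shape == x.shape and y.dtype == torch.float32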
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
index 8c75fa1753..be47106d43 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
+++ b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -227,6 +227,66 @@ Tensor _fusednbitrowwise_sbfront_to_float_or_half_cpu(
   return output;
 }

+template <typename output_t>
+Tensor _fusednbitrowwise_bf16sb_to_float_or_half_cpu(
+    const Tensor& input,
+    const int64_t bit_rate,
+    const bool scale_bias_last) {
+  TENSOR_ON_CPU(input);
+  TENSOR_NDIM_EQUALS(input, 2);
+
+  const auto input_sizes = input.sizes();
+  const int64_t nrows = input_sizes[0];
+  const int32_t ncols = nbit_elems_to_bytes(input);
+  const int32_t num_elem_per_byte = 8 / bit_rate;
+  const int32_t output_columns =
+      (ncols - 2 * sizeof(at::BFloat16)) * num_elem_per_byte;
+
+  Tensor output;
+  if constexpr (std::is_same_v<output_t, float>) {
+    output =
+        at::empty({nrows, output_columns}, input.options().dtype(at::kFloat));
+  } else if constexpr (std::is_same_v<output_t, at::Half>) {
+    output =
+        at::empty({nrows, output_columns}, input.options().dtype(at::kHalf));
+  } else if constexpr (std::is_same_v<output_t, at::BFloat16>) {
+    output = at::empty(
+        {nrows, output_columns}, input.options().dtype(at::kBFloat16));
+  } else {
+    TORCH_CHECK(
+        false,
+        "Unsupported output dtype for _fusednbitrowwise_bf16sb_to_float_or_half_cpu");
+  }
+
+  const uint8_t* input_data = input.const_data_ptr<uint8_t>();
+  output_t* output_data = output.mutable_data_ptr<output_t>();
+
+  for (int64_t row = 0; row < nrows; ++row) {
+    const uint8_t* input_row = input_data + row * ncols;
+    const at::BFloat16* input_row_scale_bias =
+        reinterpret_cast<const at::BFloat16*>(
+            input_row +
+            (scale_bias_last
+                 ? (output_columns + num_elem_per_byte - 1) / num_elem_per_byte
+                 : 0));
+    const float scale = static_cast<float>(input_row_scale_bias[0]);
+    const float bias = static_cast<float>(input_row_scale_bias[1]);
+    const uint8_t* nums =
+        scale_bias_last ? input_row : input_row + 2 * sizeof(at::BFloat16);
+    output_t* output_row = output_data + row * output_columns;
+
+    for (int32_t col = 0; col < output_columns; ++col) {
+      uint8_t quantized = nums[col / num_elem_per_byte];
+      quantized >>= (col % num_elem_per_byte) * bit_rate;
+      quantized &= (1 << bit_rate) - 1;
+      const float output_value = scale * quantized + bias;
+      output_row[col] = static_cast<output_t>(output_value);
+    }
+  }
+
+  return output;
+}
+
 /// @ingroup quantize-data-cpu
 ///
 Tensor& _fused8bitrowwise_to_float_cpu_out(
@@ -451,6 +511,32 @@ Tensor fusednbitrowwise_to_float_or_half_cpu(
   return output;
 }

+Tensor fusednbitrowwise_bf16sb_to_float_or_half_cpu(
+    const Tensor& input,
+    const int64_t bit_rate,
+    const int64_t output_dtype,
+    const bool scale_bias_last) {
+  Tensor output;
+  SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
+  switch (output_sparse_dtype) {
+    case SparseType::FP32:
+      output = _fusednbitrowwise_bf16sb_to_float_or_half_cpu<float>(
+          input, bit_rate, scale_bias_last);
+      break;
+    case SparseType::FP16:
+      output = _fusednbitrowwise_bf16sb_to_float_or_half_cpu<at::Half>(
+          input, bit_rate, scale_bias_last);
+      break;
+    case SparseType::BF16:
+      output = _fusednbitrowwise_bf16sb_to_float_or_half_cpu<at::BFloat16>(
+          input, bit_rate, scale_bias_last);
+      break;
+    default:
+      TORCH_CHECK(false);
+  }
+  return output;
+}
+
 Tensor float_to_fusednbitrowwise_cpu(
     const Tensor& input,
     const int64_t bit_rate) {
@@ -643,6 +729,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor");
   m.def(
       "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfWithRowwiseMinMax(Tensor input, int bit_rate, Tensor rowwise_min_max) -> Tensor");
+  m.def(
+      "FloatOrHalfToFusedNBitRowwiseQuantizedSBBFloat16(Tensor input, int bit_rate) -> Tensor");
   m.def(
       "FusedNBitRowwiseQuantizedSBHalfToFloat(Tensor input, int bit_rate) -> Tensor");
   m.def(
@@ -651,6 +739,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "FusedNBitRowwiseQuantizedSBHalfToHalf(Tensor input, int bit_rate) -> Tensor");
   m.def(
       "FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(Tensor input, int bit_rate, int output_dtype=0, bool scale_bias_last=True) -> Tensor");
+  m.def(
+      "FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf(Tensor input, int bit_rate, int output_dtype=0, bool scale_bias_last=True) -> Tensor");
   m.def(
       "FloatToHFP8Quantized(Tensor input, int ebits, int exponent_bias, float max_pos) -> Tensor");
   m.def(
@@ -726,6 +816,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   DISPATCH_TO_CPU(
       "FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf",
      fbgemm_gpu::fusednbitrowwise_to_float_or_half_cpu);
+  DISPATCH_TO_CPU(
+      "FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf",
+      fbgemm_gpu::fusednbitrowwise_bf16sb_to_float_or_half_cpu);
   DISPATCH_TO_CPU("FloatToHFP8Quantized", fbgemm_gpu::_float_to_hfp8_cpu);
   DISPATCH_TO_CPU("HFP8QuantizedToFloat", fbgemm_gpu::_hfp8_to_float_cpu);
 }
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp b/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
index 46e1a75530..4a07c50c2e 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
+++ b/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
@@ -95,6 +95,36 @@ Tensor fusednbitrowwise_to_float_or_half_meta(
   }
 }

+/// @ingroup quantize-data-meta
+///
+Tensor fusednbitrowwise_bf16sb_to_float_or_half_meta(
+    const Tensor& input,
+    const int64_t bit_rate,
+    const int64_t output_dtype,
+    const bool scale_bias_last [[maybe_unused]]) {
+  const at::SymIntArrayRef input_sizes = input.sym_sizes();
+  const at::SymInt& nrows = input_sizes[0];
+  const at::SymInt ncols = nbit_elems_to_bytes_meta(input);
+  const at::SymInt num_elem_per_byte = 8 / bit_rate;
+  const at::SymInt output_columns =
+      (ncols - 2 * sizeof(at::BFloat16)) * num_elem_per_byte;
+
+  SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
+  switch (output_sparse_dtype) {
+    case SparseType::FP32:
+      return at::empty_symint(
+          {nrows, output_columns}, input.options().dtype(at::kFloat));
+    case SparseType::FP16:
+      return at::empty_symint(
+          {nrows, output_columns}, input.options().dtype(at::kHalf));
+    case SparseType::BF16:
+      return at::empty_symint(
+          {nrows, output_columns}, input.options().dtype(at::kBFloat16));
+    default:
+      TORCH_CHECK(false, "Unsupported output dtype");
+  }
+}
+
 } // namespace fbgemm_gpu

 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
@@ -107,4 +137,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl(
       "FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf",
       TORCH_FN(fbgemm_gpu::fusednbitrowwise_to_float_or_half_meta));
+  m.impl(
+      "FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf",
+      TORCH_FN(fbgemm_gpu::fusednbitrowwise_bf16sb_to_float_or_half_meta));
 }
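The CPU, CUDA, and meta implementations all encode the same row layout: ceil(ncols / num_elem_per_byte) payload bytes next to a bf16 scale and a bf16 bias (4 bytes total). A few lines of plain Python, illustrative only, sanity-check that the quantize and dequantize shape formulas above invert each other for aligned ncols:

def quantized_row_bytes(ncols: int, bit_rate: int) -> int:
    # ceil(ncols / num_elem_per_byte) payload bytes + 2 bf16 values (scale, bias).
    num_elem_per_byte = 8 // bit_rate
    return (ncols + num_elem_per_byte - 1) // num_elem_per_byte + 2 * 2

def dequantized_row_elems(nbytes: int, bit_rate: int) -> int:
    # Inverse: drop the 4 scale/bias bytes, then unpack the payload.
    num_elem_per_byte = 8 // bit_rate
    return (nbytes - 2 * 2) * num_elem_per_byte

for ncols in (8, 64, 128):
    for bit_rate in (2, 4):
        nbytes = quantized_row_bytes(ncols, bit_rate)
        assert dequantized_row_elems(nbytes, bit_rate) == ncols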
diff --git a/fbgemm_gpu/test/quantize/fused_nbit_rowwise_test.py b/fbgemm_gpu/test/quantize/fused_nbit_rowwise_test.py
index 19158f539f..0a69097a0e 100644
--- a/fbgemm_gpu/test/quantize/fused_nbit_rowwise_test.py
+++ b/fbgemm_gpu/test/quantize/fused_nbit_rowwise_test.py
@@ -6,7 +6,9 @@

 # pyre-strict

+import os
 import unittest
+from typing import Callable

 import hypothesis.strategies as st
 import numpy as np
@@ -413,5 +415,362 @@ def test_quantize_and_dequantize_op_cpu_and_cuda(
         torch.testing.assert_close(dequantized_data.cpu(), reference)


+def _bf16_round_trip(t: torch.Tensor) -> torch.Tensor:
+    return t.to(torch.bfloat16).to(torch.float32)
+
+
+def _quantize_dequantize_bf16_sb_reference(
+    data: torch.Tensor, bit_rate: int
+) -> torch.Tensor:
+    """Round-trip reference matching the bf16 scale/bias fbgemm kernel.
+
+    Mirrors the fp16 SB reference but rounds scale/bias through bfloat16
+    instead of float16, so the expected value matches what the kernel
+    produces.
+    """
+    data_f32 = data.float()
+    qmax = (1 << bit_rate) - 1
+    minimum = _bf16_round_trip(data_f32.min(dim=1, keepdim=True).values)
+    maximum = data_f32.max(dim=1, keepdim=True).values
+    span = maximum - minimum
+    scale = _bf16_round_trip(
+        torch.where(span == 0, torch.ones_like(span), span / qmax)
+    )
+    inverse_scale = 1.0 / scale
+    quantized = torch.clamp(
+        torch.round((data_f32 - minimum) * inverse_scale), 0.0, float(qmax)
+    )
+    return scale * quantized + minimum
+
+
+class TestFusedNBitRowwiseQuantizationConversionBF16SB(unittest.TestCase):
+    """Tests for FusedNBitRowwiseQuantizedSBBFloat16 ops (bf16 scale/bias)."""
+
+    @unittest.skipUnless(gpu_available, "CUDA required for bf16 SB quantize op")
+    # pyre-ignore [56]
+    @given(
+        nrows=st.integers(min_value=1, max_value=64),
+        ncols=st.integers(min_value=8, max_value=128),
+        bit_rate=st.sampled_from([2, 4]),
+        output_dtype=st.sampled_from(
+            [SparseType.FP32, SparseType.FP16, SparseType.BF16]
+        ),
+    )
+    @settings(deadline=10000, suppress_health_check=[HealthCheck.filter_too_much])
+    def test_bf16_sb_quantize_dequantize_roundtrip_cuda(
+        self,
+        nrows: int,
+        ncols: int,
+        bit_rate: int,
+        output_dtype: SparseType,
+    ) -> None:
+        num_elem_per_byte = 8 // bit_rate
+        assume(ncols % (2 * num_elem_per_byte) == 0)
+
+        torch.manual_seed(0)
+        input_data = torch.rand(nrows, ncols).float().cuda()
+
+        quantized_gpu = (
+            torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBBFloat16(
+                input_data, bit_rate
+            )
+        )
+        dequantized_gpu = (
+            torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf(
+                quantized_gpu, bit_rate, output_dtype.as_int()
+            )
+        )
+
+        reference = _quantize_dequantize_bf16_sb_reference(
+            input_data.cpu(), bit_rate
+        )
+        if output_dtype == SparseType.FP16:
+            reference = reference.half()
+        elif output_dtype == SparseType.BF16:
+            reference = reference.bfloat16()
+        torch.testing.assert_close(dequantized_gpu.cpu(), reference)
+
+    @unittest.skipUnless(
+        gpu_available, "CUDA needed to produce quantized bytes for CPU dequant test"
+    )
+    # pyre-ignore [56]
+    @given(
+        nrows=st.integers(min_value=1, max_value=64),
+        ncols=st.integers(min_value=8, max_value=128),
+        bit_rate=st.sampled_from([2, 4]),
+        output_dtype=st.sampled_from(
+            [SparseType.FP32, SparseType.FP16, SparseType.BF16]
+        ),
+        scale_bias_last=st.booleans(),
+    )
+    @settings(deadline=10000, suppress_health_check=[HealthCheck.filter_too_much])
+    def test_bf16_sb_dequantize_cpu_matches_cuda(
+        self,
+        nrows: int,
+        ncols: int,
+        bit_rate: int,
+        output_dtype: SparseType,
+        scale_bias_last: bool,
+    ) -> None:
+        """CPU dequant produces the same result as CUDA dequant for bf16 SB bytes."""
+        num_elem_per_byte = 8 // bit_rate
+        assume(ncols % (2 * num_elem_per_byte) == 0)
+
+        torch.manual_seed(0)
+        input_data = torch.rand(nrows, ncols).float().cuda()
+        # Quantize on CUDA (the only backend with bf16 SB quantize); dequant on CPU.
+        # The quantize op always emits the scale_bias_last layout, so for the
+        # scale_bias_last=False case we synthesize a front-layout tensor.
+        quantized_last = (
+            torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBBFloat16(
+                input_data, bit_rate
+            )
+        ).cpu()
+
+        if scale_bias_last:
+            quantized = quantized_last
+        else:
+            packed_dim = quantized_last.size(1) - 2 * 2  # 2 * sizeof(bf16)
+            quantized = torch.cat(
+                [quantized_last[:, packed_dim:], quantized_last[:, :packed_dim]],
+                dim=1,
+            ).contiguous()
+
+        dequantized_cpu = (
+            torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf(
+                quantized, bit_rate, output_dtype.as_int(), scale_bias_last
+            )
+        )
+        dequantized_cuda = (
+            torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf(
+                quantized.cuda(), bit_rate, output_dtype.as_int(), scale_bias_last
+            )
+        ).cpu()
+        torch.testing.assert_close(dequantized_cpu, dequantized_cuda)
+
+    def test_bf16_sb_dequantize_meta_dispatch(self) -> None:
+        """Meta dispatch returns correct shape/dtype without materializing data."""
+        bit_rate = 4
+        num_elem_per_byte = 8 // bit_rate
+        nrows, ncols = 8, 64
+        # Quantized layout: ncols/num_elem_per_byte packed bytes + 2 * sizeof(bf16).
+        packed_cols = ncols // num_elem_per_byte + 2 * 2
+        meta_input = torch.empty((nrows, packed_cols), dtype=torch.uint8, device="meta")
+
+        for output_dtype, expected_dtype in [
+            (SparseType.FP32, torch.float32),
+            (SparseType.FP16, torch.float16),
+            (SparseType.BF16, torch.bfloat16),
+        ]:
+            out = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf(
+                meta_input, bit_rate, output_dtype.as_int()
+            )
+            self.assertEqual(out.device, torch.device("meta"))
+            self.assertEqual(out.shape, (nrows, ncols))
+            self.assertEqual(out.dtype, expected_dtype)
+
+    @unittest.skipUnless(
+        gpu_available, "CUDA required for fp16 vs bf16 loss comparison"
+    )
+    def test_fp16_vs_bf16_sb_quantization_loss_comparable(self) -> None:
+        """fp16 and bf16 scale/bias produce comparable quantization loss across
+        a range of input distributions. The 4-bit quantization grid dominates
+        the error; SB precision should change loss by at most ~2x."""
+        torch.manual_seed(0)
+        bit_rate = 4
+        nrows, ncols = 256, 64
+        # The 4-bit grid has 16 levels (15 steps), so the absolute error is
+        # bounded by span / 30 for uniform inputs. We use a looser ceiling of
+        # span / 15 as a safety check and require bf16 to stay within 2x of fp16.
+        ranges = [
+            ("unit_uniform", -1.0, 1.0),
+            ("wide_dynamic", -100.0, 100.0),
+            ("small_positive", 1e-3, 1.0),
+            ("tiny_symmetric", -1e-4, 1e-4),
+        ]
+        for name, lo, hi in ranges:
+            input_data = torch.empty(nrows, ncols).uniform_(lo, hi).float().cuda()
+            span = hi - lo
+
+            q_fp16 = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
+                input_data, bit_rate
+            )
+            d_fp16 = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
+                q_fp16, bit_rate, SparseType.FP32.as_int()
+            )
+            loss_fp16 = (input_data - d_fp16).abs().mean().item()
+
+            q_bf16 = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBBFloat16(
+                input_data, bit_rate
+            )
+            d_bf16 = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf(
+                q_bf16, bit_rate, SparseType.FP32.as_int()
+            )
+            loss_bf16 = (input_data - d_bf16).abs().mean().item()
+
+            ceiling = span / 15.0
+            self.assertLess(
+                loss_fp16,
+                ceiling,
+                f"[{name}] fp16 loss {loss_fp16} exceeds ceiling {ceiling}",
+            )
+            self.assertLess(
+                loss_bf16,
+                ceiling,
+                f"[{name}] bf16 loss {loss_bf16} exceeds ceiling {ceiling}",
+            )
+            self.assertLess(
+                loss_bf16,
+                2.0 * loss_fp16 + 1e-7,
+                f"[{name}] bf16 loss {loss_bf16} > 2x fp16 loss {loss_fp16}",
+            )
+
+    @unittest.skipUnless(gpu_available, "CUDA required for kernel perf benchmark")
+    def test_fp16_vs_bf16_sb_dequantize_perf_no_regression(self) -> None:
+        """bf16 SB dequant kernel must not regress vs fp16 SB by more than 50%.
+
+        Times the GPU dequant kernel for a representative cache shape across
+        several iterations after warmup, using CUDA events for accurate timing.
+        """
+        bit_rate = 4
+        nrows, ncols = 65536, 128
+        warmup_iters = 5
+        bench_iters = 50
+
+        torch.manual_seed(0)
+        input_data = torch.rand(nrows, ncols).float().cuda()
+
+        q_fp16 = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
+            input_data, bit_rate
+        )
+        q_bf16 = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBBFloat16(
+            input_data, bit_rate
+        )
+
+        def bench(op: Callable[[], torch.Tensor]) -> float:
+            for _ in range(warmup_iters):
+                op()
+            torch.cuda.synchronize()
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+            for _ in range(bench_iters):
+                op()
+            end.record()
+            torch.cuda.synchronize()
+            return start.elapsed_time(end) / bench_iters
+
+        out_dtype_int = SparseType.FP32.as_int()
+        fp16_us = (
+            bench(
+                lambda: torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
+                    q_fp16, bit_rate, out_dtype_int
+                )
+            )
+            * 1000.0
+        )
+        bf16_us = (
+            bench(
+                lambda: torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf(
+                    q_bf16, bit_rate, out_dtype_int
+                )
+            )
+            * 1000.0
+        )
+        print(
+            f"[dequant perf] shape=({nrows}x{ncols}) bit_rate={bit_rate} "
+            f"fp16_sb={fp16_us:.2f}us bf16_sb={bf16_us:.2f}us "
+            f"ratio={bf16_us / fp16_us:.3f}x"
+        )
+        self.assertLess(
+            bf16_us,
+            1.5 * fp16_us,
+            f"bf16 SB dequant ({bf16_us:.2f}us) regresses >50% vs fp16 SB "
+            f"({fp16_us:.2f}us)",
+        )
+
+    @unittest.skipUnless(
+        gpu_available and os.environ.get("FBGEMM_BENCH_DEQUANT_SWEEP", "0") == "1",
+        "Set FBGEMM_BENCH_DEQUANT_SWEEP=1 with CUDA to run perf sweep",
+    )
+    def test_fp16_vs_bf16_sb_dequantize_perf_sweep(self) -> None:
+        """Wide sweep of dequant kernel perf across cache-sized shapes and bit
+        rates. Gated behind FBGEMM_BENCH_DEQUANT_SWEEP=1 to keep CI cheap.
+        Prints a markdown table; asserts no >50% regression."""
+        warmup_iters = 10
+        bench_iters = 100
+        out_dtype_int = SparseType.FP32.as_int()
+
+        shapes = [
+            (65536, 64),
+            (65536, 128),
+            (65536, 256),
+            (262144, 64),
+            (262144, 128),
+            (262144, 256),
+            (1048576, 64),
+            (1048576, 128),
+            (1048576, 256),
+            (4194304, 64),
+            (4194304, 128),
+        ]
+        bit_rates = [2, 4]
+
+        def bench(op: Callable[[], torch.Tensor]) -> float:
+            for _ in range(warmup_iters):
+                op()
+            torch.cuda.synchronize()
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+            for _ in range(bench_iters):
+                op()
+            end.record()
+            torch.cuda.synchronize()
+            return start.elapsed_time(end) / bench_iters
+
+        torch.manual_seed(0)
+        print("\n| shape | bit | bytes_in | fp16_sb_us | bf16_sb_us | ratio |")
+        print("|---------------|-----|----------|-----------:|-----------:|------:|")
+        max_ratio = 0.0
+        for nrows, ncols in shapes:
+            input_data = torch.rand(nrows, ncols).float().cuda()
+            for bit_rate in bit_rates:
+                q_fp16 = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
+                    input_data, bit_rate
+                )
+                q_bf16 = (
+                    torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBBFloat16(
+                        input_data, bit_rate
+                    )
+                )
+                bytes_in = q_fp16.numel()
+                fp16_us = (
+                    bench(
+                        lambda q=q_fp16, b=bit_rate: torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
+                            q, b, out_dtype_int
+                        )
+                    )
+                    * 1000.0
+                )
+                bf16_us = (
+                    bench(
+                        lambda q=q_bf16, b=bit_rate: torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf(
+                            q, b, out_dtype_int
+                        )
+                    )
+                    * 1000.0
+                )
+                ratio = bf16_us / fp16_us
+                max_ratio = max(max_ratio, ratio)
+                print(
+                    f"| {nrows:>7}x{ncols:<3} | {bit_rate:>3} | "
+                    f"{bytes_in:>8} | {fp16_us:>10.2f} | {bf16_us:>10.2f} | "
+                    f"{ratio:>5.3f}x |"
+                )
+        print(f"\nWorst ratio (bf16/fp16): {max_ratio:.3f}x")
+        self.assertLess(
+            max_ratio,
+            1.5,
+            f"bf16 SB dequant kernel regresses >50% vs fp16 SB "
+            f"(worst ratio {max_ratio:.3f}x)",
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
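The sweep above is opt-in via the environment variable it reads at import time (the `skipUnless` guard evaluates when the module loads). One way to drive it locally, sketched under the assumption that the command runs from the repo root with a CUDA build installed; the file path comes from this diff, and the `subprocess` wrapper is illustrative:

import os
import subprocess
import sys

# FBGEMM_BENCH_DEQUANT_SWEEP is read when the test module is imported, so set
# it in the child process environment rather than after importing the tests.
env = dict(os.environ, FBGEMM_BENCH_DEQUANT_SWEEP="1")
subprocess.run(
    [sys.executable, "fbgemm_gpu/test/quantize/fused_nbit_rowwise_test.py", "-v"],
    env=env,
    check=True,
)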