@@ -14,8 +14,46 @@ namespace fbgemm_gpu {
1414
1515namespace {
1616
// Round-trips a float through the scale/bias storage type (__half or
// __nv_bfloat16) so the quantization math sees exactly the value that will
// be stored in the row's scale/bias slots. Without this round-trip, the
// bias used during quantization could differ from the bias recovered at
// dequantization, introducing systematic error.
template <typename scale_bias_t>
__device__ __forceinline__ float float_to_scale_bias(float val);

template <>
__device__ __forceinline__ float float_to_scale_bias<__half>(float val) {
  return __half2float(__float2half(val));
}

template <>
__device__ __forceinline__ float float_to_scale_bias<__nv_bfloat16>(float val) {
  const __nv_bfloat16 bf = __float2bfloat16(val);
#ifdef USE_ROCM
  // ROCm's bfloat16 type converts via its float conversion operator rather
  // than the __bfloat162float intrinsic.
  return float(bf);
#else
  return __bfloat162float(bf);
#endif
}
35+
// Widens a stored scale/bias value (__half or __nv_bfloat16) back to float
// for use in dequantization arithmetic.
template <typename scale_bias_t>
__device__ __forceinline__ float scale_bias_to_float(scale_bias_t val);

template <>
__device__ __forceinline__ float scale_bias_to_float<__half>(__half val) {
  return __half2float(val);
}

template <>
__device__ __forceinline__ float scale_bias_to_float<__nv_bfloat16>(
    __nv_bfloat16 val) {
#ifdef USE_ROCM
  // ROCm's bfloat16 type converts via its float conversion operator rather
  // than the __bfloat162float intrinsic.
  return float(val);
#else
  return __bfloat162float(val);
#endif
}
54+
1755// FP32/FP16 -> Fused 4/2-bit rowwise kernel
18- template <typename input_t >
56+ template <typename input_t , typename scale_bias_t = __half >
1957__global__ inline void _float_to_fusednbitrowwise_cuda_kernel (
2058 const int bit_rate,
2159 const input_t * __restrict__ input,
@@ -24,23 +62,24 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
2462 std::uint8_t * __restrict__ output) {
2563 const int num_elem_per_byte = 8 / bit_rate;
2664 const int output_columns =
27- (ncols + num_elem_per_byte - 1 ) / num_elem_per_byte + 2 * sizeof (__half);
65+ (ncols + num_elem_per_byte - 1 ) / num_elem_per_byte +
66+ 2 * sizeof (scale_bias_t );
2867
2968 int row = (int )blockIdx .x * blockDim .x + threadIdx .x ;
3069 const auto row_incre = blockDim .x * gridDim .x ;
3170 for (/* row*/ ; row < nrows; row += row_incre) {
3271 const input_t * input_row = input + row * ncols;
3372 std::uint8_t * output_row = output + row * output_columns;
34- __half * output_row_scale_bias = reinterpret_cast <__half *>(
73+ scale_bias_t * output_row_scale_bias = reinterpret_cast <scale_bias_t *>(
3574 output_row + (ncols + num_elem_per_byte - 1 ) / num_elem_per_byte);
3675
3776 float minimum_element = fbgemm_gpu::min (input_row, input_row + ncols);
3877 float maximum_element = fbgemm_gpu::max (input_row, input_row + ncols);
39- minimum_element = __half2float ( __float2half ( minimum_element) );
78+ minimum_element = float_to_scale_bias< scale_bias_t >( minimum_element);
4079 const float range = maximum_element - minimum_element;
4180
42- float scale = __half2float (
43- __float2half ( range == 0 ? 1 .0f : range / ((1 << bit_rate) - 1 ) ));
81+ float scale = float_to_scale_bias< scale_bias_t > (
82+ range == 0 ? 1 .0f : range / ((1 << bit_rate) - 1 ));
4483 if (scale == 0 ) {
4584 // Corner case handling when maximum_element == minimum_element
4685 // Any scale would work because X - minimum_element will be 0 for all X
@@ -52,8 +91,13 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
5291 inverse_scale = 1 .0f ;
5392 }
5493
55- output_row_scale_bias[0 ] = __float2half (scale);
56- output_row_scale_bias[1 ] = __float2half (minimum_element);
94+ if constexpr (std::is_same_v<scale_bias_t , __half>) {
95+ output_row_scale_bias[0 ] = __float2half (scale);
96+ output_row_scale_bias[1 ] = __float2half (minimum_element);
97+ } else {
98+ output_row_scale_bias[0 ] = __float2bfloat16 (scale);
99+ output_row_scale_bias[1 ] = __float2bfloat16 (minimum_element);
100+ }
57101 for (std::size_t col = 0 ; col < ncols; ++col) {
58102 const float X = input_row[col];
59103
@@ -74,30 +118,32 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
74118}
75119
76120// Fused 4/2-bit rowwise -> FP32/FP16 kernel
77- template <typename output_t , bool scale_bias_last>
121+ template <typename output_t , typename scale_bias_t , bool scale_bias_last>
78122__global__ inline void _fusednbitrowwise_to_float_cuda_kernel (
79123 const int bit_rate,
80124 const std::uint8_t * input,
81125 const int nrows,
82126 const int ncols,
83127 output_t * const output) {
84128 const int num_elem_per_byte = 8 / bit_rate;
85- const int output_columns = (ncols - 2 * sizeof (__half)) * num_elem_per_byte;
129+ const int output_columns =
130+ (ncols - 2 * sizeof (scale_bias_t )) * num_elem_per_byte;
86131 int row = (int )blockIdx .y * blockDim .y + threadIdx .y ;
87132 const int col = (int )blockIdx .x * blockDim .x + threadIdx .x ;
88133 const auto row_incre = blockDim .y * gridDim .y ;
89134 for (/* row*/ ; row < nrows; row += row_incre) {
90135 if (row < nrows && col < output_columns) {
91136 const std::uint8_t * input_row = input + row * ncols;
92- const __half* input_row_scale_bias = reinterpret_cast <const __half*>(
93- input_row +
94- (!scale_bias_last
95- ? 0
96- : (output_columns + num_elem_per_byte - 1 ) / num_elem_per_byte));
97- float scale = __half2float (input_row_scale_bias[0 ]);
98- float bias = __half2float (input_row_scale_bias[1 ]);
137+ const scale_bias_t * input_row_scale_bias =
138+ reinterpret_cast <const scale_bias_t *>(
139+ input_row +
140+ (!scale_bias_last ? 0
141+ : (output_columns + num_elem_per_byte - 1 ) /
142+ num_elem_per_byte));
143+ float scale = scale_bias_to_float (input_row_scale_bias[0 ]);
144+ float bias = scale_bias_to_float (input_row_scale_bias[1 ]);
99145 if constexpr (!scale_bias_last) {
100- input_row += 2 * sizeof (__half );
146+ input_row += 2 * sizeof (scale_bias_t );
101147 }
102148 output_t * output_row = output + row * output_columns;
103149
@@ -216,7 +262,71 @@ DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_gpu(
216262 return output;
217263}
218264
219- template <typename output_t >
// Quantizes a 2-D floating-point tensor to fused N-bit rowwise format,
// storing each row's scale and bias (minimum) as `scale_bias_t` at the end
// of the row.
//
// Output row layout:
//   [ceil(ncols / num_elem_per_byte) packed bytes][scale][bias]
// where scale and bias are each sizeof(scale_bias_t) bytes.
//
// NOTE(review): the `input_t` template parameter is not used by this body —
// the input element type is resolved at runtime by
// FBGEMM_DISPATCH_FLOATING_TYPES on `input.scalar_type()`. It is kept for
// signature compatibility with the callers.
//
// @param input    2-D CUDA tensor of float/half values to quantize
// @param bit_rate bits per quantized element; must evenly divide 8
// @return uint8 tensor of shape {nrows, output_columns}
template <typename input_t, typename scale_bias_t = __half>
Tensor _float_to_fusednbitrowwise_bf16sb_gpu_t(
    const Tensor& input,
    const int64_t bit_rate) {
  TENSOR_ON_CUDA_GPU(input);
  TENSOR_NDIM_EQUALS(input, 2);
  CUDA_DEVICE_GUARD(input);

  // Reject bit rates that do not evenly pack into a byte; otherwise
  // `num_elem_per_byte` silently truncates (e.g. 8 / 3 == 2) and the packed
  // layout is corrupted.
  TORCH_CHECK(
      bit_rate == 2 || bit_rate == 4 || bit_rate == 8,
      "bit_rate must be 2, 4, or 8, got ",
      bit_rate);

  const int nrows = input.size(0);
  const int ncols = input.size(1);
  const int num_elem_per_byte = 8 / bit_rate;
  TORCH_CHECK(
      ncols % (2 * num_elem_per_byte) == 0,
      "ncols needs to be multiple of 2 Bytes (bfloat16 type size) to make the address aligned");
  const int output_columns =
      (ncols + num_elem_per_byte - 1) / num_elem_per_byte +
      2 * sizeof(scale_bias_t);

  // Output is always byte-typed: packed quantized values plus the trailing
  // scale/bias pair per row.
  auto output =
      at::empty({nrows, output_columns}, input.options().dtype(at::kByte));

  if (nrows == 0 || ncols == 0) {
    return output;
  }

  constexpr auto threads_per_block = 256;
  // One thread per row; the kernel grid-strides over rows.
  const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block);

  FBGEMM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(),
      "_float_to_fusednbitrowwise_bf16sb_cuda_kernel",
      [&] {
        FBGEMM_LAUNCH_KERNEL(
            (_float_to_fusednbitrowwise_cuda_kernel<scalar_t, scale_bias_t>),
            num_blocks,
            threads_per_block,
            0,
            at::cuda::getCurrentCUDAStream(),
            bit_rate,
            input.data_ptr<scalar_t>(),
            nrows,
            ncols,
            output.data_ptr<std::uint8_t>());
      });

  return output;
}
312+
/// @ingroup quantize-ops-cuda
/// Converts a tensor of `float` or `at::Half` values into a tensor of fused
/// N-bit rowwise values with bf16 scale/bias.
///
/// @param input A tensor of `float` or `at::Half` values
/// @param bit_rate Bits per quantized element (2, 4, or 8)
///
/// @return A new tensor with values from the input tensor converted to
/// fused N-bit rowwise with bf16 scale/bias.
DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_bf16sb_gpu(
    const Tensor& input,
    const int64_t bit_rate) {
  // The `float` template argument is inert: the implementation dispatches on
  // input.scalar_type() at runtime. `__nv_bfloat16` selects the scale/bias
  // storage type.
  return _float_to_fusednbitrowwise_bf16sb_gpu_t<float, __nv_bfloat16>(
      input, bit_rate);
}
328+
329+ template <typename output_t , typename scale_bias_t = __half>
220330Tensor _fusednbitrowwise_to_float_gpu_t (
221331 const Tensor& input,
222332 const int64_t bit_rate,
@@ -228,7 +338,8 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
228338 const int nrows = input.size (0 );
229339 const int ncols = input.size (1 );
230340 const int num_elem_per_byte = 8 / bit_rate;
231- const int output_columns = (ncols - 2 * sizeof (at::Half)) * num_elem_per_byte;
341+ const int output_columns =
342+ (ncols - 2 * sizeof (scale_bias_t )) * num_elem_per_byte;
232343
233344 // Global memory instructions support reading or writing words of size equal
234345 // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
@@ -267,17 +378,20 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
267378 const auto gridDim_y = cuda_calc_block_count (nrows, blockDim .y );
268379 const dim3 gridDim (gridDim_x, gridDim_y);
269380
270- #define DEQUANT_LAUNCH_NBIT (scale_bias_last ) \
271- FBGEMM_LAUNCH_KERNEL ( \
272- (_fusednbitrowwise_to_float_cuda_kernel<scalar_t , scale_bias_last>), \
273- gridDim , \
274- blockDim , \
275- 0 , \
276- at::cuda::getCurrentCUDAStream (), \
277- bit_rate, \
278- input.data_ptr <std::uint8_t >(), \
279- nrows, \
280- ncols, \
381+ #define DEQUANT_LAUNCH_NBIT (scale_bias_last ) \
382+ FBGEMM_LAUNCH_KERNEL ( \
383+ (_fusednbitrowwise_to_float_cuda_kernel< \
384+ scalar_t , \
385+ scale_bias_t , \
386+ scale_bias_last>), \
387+ gridDim , \
388+ blockDim , \
389+ 0 , \
390+ at::cuda::getCurrentCUDAStream (), \
391+ bit_rate, \
392+ input.data_ptr <std::uint8_t >(), \
393+ nrows, \
394+ ncols, \
281395 output.mutable_data_ptr <scalar_t >())
282396
283397 FBGEMM_DISPATCH_FLOATING_TYPES (
@@ -365,6 +479,48 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_single_or_half_precision_gpu(
365479 return output;
366480}
367481
/// @ingroup quantize-ops-cuda
/// Converts a tensor of fused N-bit rowwise values with bf16 scale/bias into
/// a tensor of `float` or `at::Half` or `at::Bf16` values.
///
/// @param input A tensor of fused N-bit rowwise values with bf16 scale/bias
/// @param bit_rate Bits per quantized element (2, 4, or 8)
/// @param output_dtype The target floating point type, specified as integer
///                     representation of `SparseType` enum
/// @param scale_bias_last Whether each input row stores the scale/bias pair
///                        after the packed values (true) or before (false)
///
/// @return A new tensor with values from the input tensor converted to `float`
/// or `at::Half` or `at::Bf16`, depending on `output_dtype`.
///
/// @throw c10::Error if `output_dtype` is not one of (`SparseType::FP32` or
/// `SparseType::FP16` or `SparseType::BF16`).
DLL_PUBLIC at::Tensor _fusednbitrowwise_bf16sb_to_single_or_half_precision_gpu(
    const at::Tensor& input,
    const int64_t bit_rate,
    const int64_t output_dtype,
    const bool scale_bias_last) {
  Tensor output;

  const SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
  switch (output_sparse_dtype) {
    case SparseType::FP32:
      output = _fusednbitrowwise_to_float_gpu_t<float, __nv_bfloat16>(
          input, bit_rate, scale_bias_last);
      break;
    case SparseType::FP16:
      output = _fusednbitrowwise_to_float_gpu_t<at::Half, __nv_bfloat16>(
          input, bit_rate, scale_bias_last);
      break;
    case SparseType::BF16:
      output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16, __nv_bfloat16>(
          input, bit_rate, scale_bias_last);
      break;
    default:
      // Name the offending value instead of a bare TORCH_CHECK(false) so
      // callers can diagnose the failure.
      TORCH_CHECK(
          false,
          "Unsupported output_dtype for SparseType conversion: ",
          output_dtype);
  }

  return output;
}
523+
368524} // namespace fbgemm_gpu
369525
370526FBGEMM_OP_DISPATCH (
@@ -379,6 +535,10 @@ FBGEMM_OP_DISPATCH(
379535 CUDA,
380536 " FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf" ,
381537 fbgemm_gpu::_single_or_half_precision_to_fusednbitrowwise_gpu);
538+ FBGEMM_OP_DISPATCH (
539+ CUDA,
540+ " FloatOrHalfToFusedNBitRowwiseQuantizedSBBFloat16" ,
541+ fbgemm_gpu::_single_or_half_precision_to_fusednbitrowwise_bf16sb_gpu);
382542FBGEMM_OP_DISPATCH (
383543 CUDA,
384544 " FusedNBitRowwiseQuantizedSBHalfToFloat" ,
@@ -391,3 +551,7 @@ FBGEMM_OP_DISPATCH(
391551 CUDA,
392552 " FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf" ,
393553 fbgemm_gpu::_fusednbitrowwise_to_single_or_half_precision_gpu);
554+ FBGEMM_OP_DISPATCH (
555+ CUDA,
556+ " FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf" ,
557+ fbgemm_gpu::_fusednbitrowwise_bf16sb_to_single_or_half_precision_gpu);
0 commit comments