
Commit 3c73db9

jeetkanjani7 authored and facebook-github-bot committed
bf16 scale/bias for INT4 (#5595)
Summary: Add bf16 scale/bias support for INT4/INT2 fused N-bit rowwise quantization in FBGEMM and SilverTorch. Previously, fused N-bit rowwise quantization only supported fp16 scale/bias; storing scale/bias in bf16 avoids precision loss from fp16 truncation during quantization round-trips.

X-link: facebookresearch/FBGEMM#2551
Reviewed By: zhaozhul
Differential Revision: D95859348
1 parent d83f7c8 commit 3c73db9

4 files changed

Lines changed: 680 additions & 31 deletions
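
Background for the diff below: __half (fp16) and __nv_bfloat16 (bf16) are both 2-byte formats, but they split the bits differently (fp16: 5 exponent / 10 mantissa bits; bf16: fp32's 8 exponent bits / 7 mantissa bits). The following is a minimal host-side sketch, not part of this commit, of the float -> scale_bias_t -> float round-trip the kernels perform on scale/bias; it assumes a CUDA 11+ toolkit where these conversion intrinsics are host-callable.

// Illustrative only (not part of this commit): what each 2-byte scale/bias
// format preserves across the round-trip that float_to_scale_bias<>()
// performs in the kernels below.
#include <cstdio>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

int main() {
  // A bias (row minimum) beyond fp16's max finite value (65504): the fp16
  // round-trip overflows to inf, while bf16 keeps fp32's exponent range.
  const float bias = 1.0e6f;
  std::printf("bias %g: fp16 round-trip %g, bf16 round-trip %g\n",
              bias,
              __half2float(__float2half(bias)),
              __bfloat162float(__float2bfloat16(bias)));

  // The trade-off: bf16 has only 7 mantissa bits (vs. 10 for fp16), so
  // mid-range values round more coarsely.
  const float scale = 1.2345678f;
  std::printf("scale %.7f: fp16 %.7f, bf16 %.7f\n",
              scale,
              __half2float(__float2half(scale)),
              __bfloat162float(__float2bfloat16(scale)));
  return 0;
}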

File tree

fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu

Lines changed: 195 additions & 31 deletions
@@ -14,8 +14,46 @@ namespace fbgemm_gpu {
 
 namespace {
 
+// Helper to convert float to scale_bias type (round-trip for quantization)
+template <typename scale_bias_t>
+__device__ __forceinline__ float float_to_scale_bias(float val);
+
+template <>
+__device__ __forceinline__ float float_to_scale_bias<__half>(float val) {
+  return __half2float(__float2half(val));
+}
+
+template <>
+__device__ __forceinline__ float float_to_scale_bias<__nv_bfloat16>(float val) {
+  __nv_bfloat16 bf = __float2bfloat16(val);
+#ifdef USE_ROCM
+  return float(bf);
+#else
+  return __bfloat162float(bf);
+#endif
+}
+
+// Helper to convert scale_bias type to float
+template <typename scale_bias_t>
+__device__ __forceinline__ float scale_bias_to_float(scale_bias_t val);
+
+template <>
+__device__ __forceinline__ float scale_bias_to_float<__half>(__half val) {
+  return __half2float(val);
+}
+
+template <>
+__device__ __forceinline__ float scale_bias_to_float<__nv_bfloat16>(
+    __nv_bfloat16 val) {
+#ifdef USE_ROCM
+  return float(val);
+#else
+  return __bfloat162float(val);
+#endif
+}
+
 // FP32/FP16 -> Fused 4/2-bit rowwise kernel
-template <typename input_t>
+template <typename input_t, typename scale_bias_t = __half>
 __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
     const int bit_rate,
     const input_t* __restrict__ input,
@@ -24,23 +62,24 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
     std::uint8_t* __restrict__ output) {
   const int num_elem_per_byte = 8 / bit_rate;
   const int output_columns =
-      (ncols + num_elem_per_byte - 1) / num_elem_per_byte + 2 * sizeof(__half);
+      (ncols + num_elem_per_byte - 1) / num_elem_per_byte +
+      2 * sizeof(scale_bias_t);
 
   int row = (int)blockIdx.x * blockDim.x + threadIdx.x;
   const auto row_incre = blockDim.x * gridDim.x;
   for (/*row*/; row < nrows; row += row_incre) {
     const input_t* input_row = input + row * ncols;
     std::uint8_t* output_row = output + row * output_columns;
-    __half* output_row_scale_bias = reinterpret_cast<__half*>(
+    scale_bias_t* output_row_scale_bias = reinterpret_cast<scale_bias_t*>(
         output_row + (ncols + num_elem_per_byte - 1) / num_elem_per_byte);
 
     float minimum_element = fbgemm_gpu::min(input_row, input_row + ncols);
     float maximum_element = fbgemm_gpu::max(input_row, input_row + ncols);
-    minimum_element = __half2float(__float2half(minimum_element));
+    minimum_element = float_to_scale_bias<scale_bias_t>(minimum_element);
     const float range = maximum_element - minimum_element;
 
-    float scale = __half2float(
-        __float2half(range == 0 ? 1.0f : range / ((1 << bit_rate) - 1)));
+    float scale = float_to_scale_bias<scale_bias_t>(
+        range == 0 ? 1.0f : range / ((1 << bit_rate) - 1));
     if (scale == 0) {
       // Corner case handling when maximum_element == minimum_element
       // Any scale would work because X - minimum_element will be 0 for all X
@@ -52,8 +91,13 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
       inverse_scale = 1.0f;
     }
 
-    output_row_scale_bias[0] = __float2half(scale);
-    output_row_scale_bias[1] = __float2half(minimum_element);
+    if constexpr (std::is_same_v<scale_bias_t, __half>) {
+      output_row_scale_bias[0] = __float2half(scale);
+      output_row_scale_bias[1] = __float2half(minimum_element);
+    } else {
+      output_row_scale_bias[0] = __float2bfloat16(scale);
+      output_row_scale_bias[1] = __float2bfloat16(minimum_element);
+    }
     for (std::size_t col = 0; col < ncols; ++col) {
       const float X = input_row[col];
 
@@ -74,30 +118,32 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
 }
 
 // Fused 4/2-bit rowwise -> FP32/FP16 kernel
-template <typename output_t, bool scale_bias_last>
+template <typename output_t, typename scale_bias_t, bool scale_bias_last>
 __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
     const int bit_rate,
     const std::uint8_t* input,
     const int nrows,
     const int ncols,
     output_t* const output) {
   const int num_elem_per_byte = 8 / bit_rate;
-  const int output_columns = (ncols - 2 * sizeof(__half)) * num_elem_per_byte;
+  const int output_columns =
+      (ncols - 2 * sizeof(scale_bias_t)) * num_elem_per_byte;
   int row = (int)blockIdx.y * blockDim.y + threadIdx.y;
   const int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
   const auto row_incre = blockDim.y * gridDim.y;
   for (/*row*/; row < nrows; row += row_incre) {
     if (row < nrows && col < output_columns) {
       const std::uint8_t* input_row = input + row * ncols;
-      const __half* input_row_scale_bias = reinterpret_cast<const __half*>(
-          input_row +
-          (!scale_bias_last
-               ? 0
-               : (output_columns + num_elem_per_byte - 1) / num_elem_per_byte));
-      float scale = __half2float(input_row_scale_bias[0]);
-      float bias = __half2float(input_row_scale_bias[1]);
+      const scale_bias_t* input_row_scale_bias =
+          reinterpret_cast<const scale_bias_t*>(
+              input_row +
+              (!scale_bias_last ? 0
+                                : (output_columns + num_elem_per_byte - 1) /
+                                      num_elem_per_byte));
+      float scale = scale_bias_to_float(input_row_scale_bias[0]);
+      float bias = scale_bias_to_float(input_row_scale_bias[1]);
       if constexpr (!scale_bias_last) {
-        input_row += 2 * sizeof(__half);
+        input_row += 2 * sizeof(scale_bias_t);
       }
       output_t* output_row = output + row * output_columns;
 
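Since sizeof(__half) == sizeof(__nv_bfloat16) == 2, the fused row layout computed by both kernels above is byte-identical for either scale/bias type, which is what makes bf16 a drop-in change. A small standalone sketch (illustrative, not from the diff) of the row width formula:

// Illustrative only (not from the diff): the fused row width used by both
// kernels. Because sizeof(__half) == sizeof(__nv_bfloat16) == 2, switching
// the scale/bias type from fp16 to bf16 leaves the byte layout unchanged.
#include <cstdio>

int fused_row_bytes(int ncols, int bit_rate, int sizeof_scale_bias_t) {
  const int num_elem_per_byte = 8 / bit_rate;
  // ceil(ncols / num_elem_per_byte) packed data bytes, then scale and bias.
  return (ncols + num_elem_per_byte - 1) / num_elem_per_byte +
      2 * sizeof_scale_bias_t;
}

int main() {
  std::printf("INT4, ncols=128: %d bytes/row\n", fused_row_bytes(128, 4, 2));
  std::printf("INT2, ncols=128: %d bytes/row\n", fused_row_bytes(128, 2, 2));
  return 0;
}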
@@ -216,7 +262,71 @@ DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_gpu(
   return output;
 }
 
-template <typename output_t>
+template <typename input_t, typename scale_bias_t = __half>
+Tensor _float_to_fusednbitrowwise_bf16sb_gpu_t(
+    const Tensor& input,
+    const int64_t bit_rate) {
+  TENSOR_ON_CUDA_GPU(input);
+  TENSOR_NDIM_EQUALS(input, 2);
+  CUDA_DEVICE_GUARD(input);
+
+  const int nrows = input.size(0);
+  const int ncols = input.size(1);
+  const int num_elem_per_byte = 8 / bit_rate;
+  TORCH_CHECK(
+      ncols % (2 * num_elem_per_byte) == 0,
+      "ncols needs to be multiple of 2 Bytes (bfloat16 type size) to make the address aligned");
+  const int output_columns =
+      (ncols + num_elem_per_byte - 1) / num_elem_per_byte +
+      2 * sizeof(scale_bias_t);
+
+  auto output =
+      at::empty({nrows, output_columns}, input.options().dtype(at::kByte));
+
+  if (nrows == 0 || ncols == 0) {
+    return output;
+  }
+
+  constexpr auto threads_per_block = 256;
+  const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block);
+
+  FBGEMM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(),
+      "_float_to_fusednbitrowwise_bf16sb_cuda_kernel",
+      [&] {
+        FBGEMM_LAUNCH_KERNEL(
+            (_float_to_fusednbitrowwise_cuda_kernel<scalar_t, scale_bias_t>),
+            num_blocks,
+            threads_per_block,
+            0,
+            at::cuda::getCurrentCUDAStream(),
+            bit_rate,
+            input.data_ptr<scalar_t>(),
+            nrows,
+            ncols,
+            output.data_ptr<std::uint8_t>());
+      });
+
+  return output;
+}
+
+/// @ingroup quantize-ops-cuda
+/// Converts a tensor of `float` or `at::Half` values into a tensor of fused
+/// N-bit rowwise values with bf16 scale/bias.
+///
+/// @param input A tensor of `float` or `at::Half` values
+/// @param bit_rate
+///
+/// @return A new tensor with values from the input tensor converted to
+///         fused N-bit rowwise with bf16 scale/bias.
+DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_bf16sb_gpu(
+    const Tensor& input,
+    const int64_t bit_rate) {
+  return _float_to_fusednbitrowwise_bf16sb_gpu_t<float, __nv_bfloat16>(
+      input, bit_rate);
+}
+
+template <typename output_t, typename scale_bias_t = __half>
 Tensor _fusednbitrowwise_to_float_gpu_t(
     const Tensor& input,
     const int64_t bit_rate,
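
A note on the TORCH_CHECK in the new quantize entry point above: the scale/bias pair is accessed through a 2-byte scale_bias_t pointer at byte offset ceil(ncols / num_elem_per_byte) within each row, so that offset must be even. A small sketch (not from the diff) showing which ncols values pass for INT4:

// Illustrative only (not from the diff): why the TORCH_CHECK requires
// ncols % (2 * num_elem_per_byte) == 0. The scale/bias pair is read through
// a 2-byte scale_bias_t*, so its byte offset into the row must be even.
#include <cstdio>
#include <initializer_list>

int main() {
  const int bit_rate = 4;                     // INT4
  const int num_elem_per_byte = 8 / bit_rate; // 2 elements per byte
  for (const int ncols : {128, 130, 129}) {
    const int sb_offset = (ncols + num_elem_per_byte - 1) / num_elem_per_byte;
    std::printf("ncols=%d -> scale/bias at byte offset %d (%s)\n",
                ncols,
                sb_offset,
                sb_offset % 2 == 0 ? "aligned" : "misaligned");
  }
  return 0;
}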
@@ -228,7 +338,8 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
   const int nrows = input.size(0);
   const int ncols = input.size(1);
   const int num_elem_per_byte = 8 / bit_rate;
-  const int output_columns = (ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;
+  const int output_columns =
+      (ncols - 2 * sizeof(scale_bias_t)) * num_elem_per_byte;
 
   // Global memory instructions support reading or writing words of size equal
   // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
@@ -267,17 +378,20 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
   const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
   const dim3 gridDim(gridDim_x, gridDim_y);
 
-#define DEQUANT_LAUNCH_NBIT(scale_bias_last)                                 \
-  FBGEMM_LAUNCH_KERNEL(                                                      \
-      (_fusednbitrowwise_to_float_cuda_kernel<scalar_t, scale_bias_last>),   \
-      gridDim,                                                               \
-      blockDim,                                                              \
-      0,                                                                     \
-      at::cuda::getCurrentCUDAStream(),                                      \
-      bit_rate,                                                              \
-      input.data_ptr<std::uint8_t>(),                                        \
-      nrows,                                                                 \
-      ncols,                                                                 \
+#define DEQUANT_LAUNCH_NBIT(scale_bias_last)          \
+  FBGEMM_LAUNCH_KERNEL(                               \
+      (_fusednbitrowwise_to_float_cuda_kernel<        \
+          scalar_t,                                   \
+          scale_bias_t,                               \
+          scale_bias_last>),                          \
+      gridDim,                                        \
+      blockDim,                                       \
+      0,                                              \
+      at::cuda::getCurrentCUDAStream(),               \
+      bit_rate,                                       \
+      input.data_ptr<std::uint8_t>(),                 \
+      nrows,                                          \
+      ncols,                                          \
       output.mutable_data_ptr<scalar_t>())
 
 FBGEMM_DISPATCH_FLOATING_TYPES(
@@ -365,6 +479,48 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_single_or_half_precision_gpu(
   return output;
 }
 
+/// @ingroup quantize-ops-cuda
+/// Converts a tensor of fused N-bit rowwise values with bf16 scale/bias into
+/// a tensor of `float` or `at::Half` or `at::Bf16` values.
+///
+/// @param input A tensor of fused N-bit rowwise values with bf16 scale/bias
+/// @param bit_rate
+/// @param output_dtype The target floating point type, specified as integer
+///                     representation of `SparseType` enum
+///
+/// @return A new tensor with values from the input tensor converted to `float`
+///         or `at::Half` or `at::Bf16`, depending on `output_dtype`.
+///
+/// @throw c10::Error if `output_dtype` is not one of (`SparseType::FP32` or
+///        `SparseType::FP16` or `SparseType::BF16`).
+DLL_PUBLIC at::Tensor _fusednbitrowwise_bf16sb_to_single_or_half_precision_gpu(
+    const at::Tensor& input,
+    const int64_t bit_rate,
+    const int64_t output_dtype,
+    const bool scale_bias_last) {
+  Tensor output;
+
+  SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
+  switch (output_sparse_dtype) {
+    case SparseType::FP32:
+      output = _fusednbitrowwise_to_float_gpu_t<float, __nv_bfloat16>(
+          input, bit_rate, scale_bias_last);
+      break;
+    case SparseType::FP16:
+      output = _fusednbitrowwise_to_float_gpu_t<at::Half, __nv_bfloat16>(
+          input, bit_rate, scale_bias_last);
+      break;
+    case SparseType::BF16:
+      output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16, __nv_bfloat16>(
+          input, bit_rate, scale_bias_last);
+      break;
+    default:
+      TORCH_CHECK(false);
+  }
+
+  return output;
+}
+
 } // namespace fbgemm_gpu
 
 FBGEMM_OP_DISPATCH(
@@ -379,6 +535,10 @@ FBGEMM_OP_DISPATCH(
     CUDA,
     "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf",
     fbgemm_gpu::_single_or_half_precision_to_fusednbitrowwise_gpu);
+FBGEMM_OP_DISPATCH(
+    CUDA,
+    "FloatOrHalfToFusedNBitRowwiseQuantizedSBBFloat16",
+    fbgemm_gpu::_single_or_half_precision_to_fusednbitrowwise_bf16sb_gpu);
 FBGEMM_OP_DISPATCH(
     CUDA,
     "FusedNBitRowwiseQuantizedSBHalfToFloat",
@@ -391,3 +551,7 @@ FBGEMM_OP_DISPATCH(
     CUDA,
     "FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf",
     fbgemm_gpu::_fusednbitrowwise_to_single_or_half_precision_gpu);
+FBGEMM_OP_DISPATCH(
+    CUDA,
+    "FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf",
+    fbgemm_gpu::_fusednbitrowwise_bf16sb_to_single_or_half_precision_gpu);
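
Taken together, the two registrations above expose a quantize/dequantize pair. A hypothetical round-trip sketch calling the underlying entry points directly; the two signatures come from this diff (forward-declared here to keep the sketch self-contained), while the integer value used for SparseType::FP32 (0) is an assumption about the surrounding FBGEMM sources:

// Hypothetical usage sketch, not part of the commit: quantize a CUDA fp32
// tensor to fused INT4 rows with bf16 scale/bias, then dequantize it back.
#include <ATen/ATen.h>

namespace fbgemm_gpu {
// Forward declarations of the entry points added in this commit.
at::Tensor _single_or_half_precision_to_fusednbitrowwise_bf16sb_gpu(
    const at::Tensor& input,
    const int64_t bit_rate);
at::Tensor _fusednbitrowwise_bf16sb_to_single_or_half_precision_gpu(
    const at::Tensor& input,
    const int64_t bit_rate,
    const int64_t output_dtype,
    const bool scale_bias_last);
} // namespace fbgemm_gpu

at::Tensor roundtrip_int4_bf16sb(const at::Tensor& x) {
  // x: [nrows, ncols] float32 CUDA tensor; ncols must be a multiple of 4
  // for bit_rate 4 (see the TORCH_CHECK in the quantize entry point).
  at::Tensor q =
      fbgemm_gpu::_single_or_half_precision_to_fusednbitrowwise_bf16sb_gpu(
          x, /*bit_rate=*/4);
  // q: [nrows, ncols/2 + 4] uint8 rows: packed nibbles, then bf16 scale/bias.
  // scale_bias_last=true matches the layout the quantize kernel writes.
  return fbgemm_gpu::_fusednbitrowwise_bf16sb_to_single_or_half_precision_gpu(
      q,
      /*bit_rate=*/4,
      /*output_dtype=*/0, // SparseType::FP32 (integer value assumed here)
      /*scale_bias_last=*/true);
}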
