Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 195 additions & 31 deletions fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,46 @@ namespace fbgemm_gpu {

namespace {

// Rounds a float through scale_bias_t precision and back, so quantization
// math operates on exactly the value that will later be stored in the row's
// scale/bias header. Primary template is declared only; use a specialization.
template <typename scale_bias_t>
__device__ __forceinline__ float float_to_scale_bias(float val);

// fp16 round-trip: float -> __half -> float.
template <>
__device__ __forceinline__ float float_to_scale_bias<__half>(float val) {
  const __half narrowed = __float2half(val);
  return __half2float(narrowed);
}

// bf16 round-trip: float -> __nv_bfloat16 -> float. On ROCm the plain
// conversion is used in place of the CUDA __bfloat162float intrinsic.
template <>
__device__ __forceinline__ float float_to_scale_bias<__nv_bfloat16>(float val) {
  const __nv_bfloat16 narrowed = __float2bfloat16(val);
#ifdef USE_ROCM
  return static_cast<float>(narrowed);
#else
  return __bfloat162float(narrowed);
#endif
}

// Widens a stored scale/bias value back to float for dequantization.
// Primary template is declared only; use a specialization.
template <typename scale_bias_t>
__device__ __forceinline__ float scale_bias_to_float(scale_bias_t val);

// fp16 -> float.
template <>
__device__ __forceinline__ float scale_bias_to_float<__half>(__half val) {
  return __half2float(val);
}

// bf16 -> float. On ROCm the plain conversion is used in place of the CUDA
// __bfloat162float intrinsic.
template <>
__device__ __forceinline__ float scale_bias_to_float<__nv_bfloat16>(
    __nv_bfloat16 val) {
#ifdef USE_ROCM
  return static_cast<float>(val);
#else
  return __bfloat162float(val);
#endif
}

// FP32/FP16 -> Fused 4/2-bit rowwise kernel
template <typename input_t>
template <typename input_t, typename scale_bias_t = __half>
__global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
const int bit_rate,
const input_t* __restrict__ input,
Expand All @@ -24,23 +62,24 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
std::uint8_t* __restrict__ output) {
const int num_elem_per_byte = 8 / bit_rate;
const int output_columns =
(ncols + num_elem_per_byte - 1) / num_elem_per_byte + 2 * sizeof(__half);
(ncols + num_elem_per_byte - 1) / num_elem_per_byte +
2 * sizeof(scale_bias_t);

int row = (int)blockIdx.x * blockDim.x + threadIdx.x;
const auto row_incre = blockDim.x * gridDim.x;
for (/*row*/; row < nrows; row += row_incre) {
const input_t* input_row = input + row * ncols;
std::uint8_t* output_row = output + row * output_columns;
__half* output_row_scale_bias = reinterpret_cast<__half*>(
scale_bias_t* output_row_scale_bias = reinterpret_cast<scale_bias_t*>(
output_row + (ncols + num_elem_per_byte - 1) / num_elem_per_byte);

float minimum_element = fbgemm_gpu::min(input_row, input_row + ncols);
float maximum_element = fbgemm_gpu::max(input_row, input_row + ncols);
minimum_element = __half2float(__float2half(minimum_element));
minimum_element = float_to_scale_bias<scale_bias_t>(minimum_element);
const float range = maximum_element - minimum_element;

float scale = __half2float(
__float2half(range == 0 ? 1.0f : range / ((1 << bit_rate) - 1)));
float scale = float_to_scale_bias<scale_bias_t>(
range == 0 ? 1.0f : range / ((1 << bit_rate) - 1));
if (scale == 0) {
// Corner case handling when maximum_element == minimum_element
// Any scale would work because X - minimum_element will be 0 for all X
Expand All @@ -52,8 +91,13 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
inverse_scale = 1.0f;
}

output_row_scale_bias[0] = __float2half(scale);
output_row_scale_bias[1] = __float2half(minimum_element);
if constexpr (std::is_same_v<scale_bias_t, __half>) {
output_row_scale_bias[0] = __float2half(scale);
output_row_scale_bias[1] = __float2half(minimum_element);
} else {
output_row_scale_bias[0] = __float2bfloat16(scale);
output_row_scale_bias[1] = __float2bfloat16(minimum_element);
}
for (std::size_t col = 0; col < ncols; ++col) {
const float X = input_row[col];

Expand All @@ -74,30 +118,32 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
}

// Fused 4/2-bit rowwise -> FP32/FP16 kernel
template <typename output_t, bool scale_bias_last>
template <typename output_t, typename scale_bias_t, bool scale_bias_last>
__global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
const int bit_rate,
const std::uint8_t* input,
const int nrows,
const int ncols,
output_t* const output) {
const int num_elem_per_byte = 8 / bit_rate;
const int output_columns = (ncols - 2 * sizeof(__half)) * num_elem_per_byte;
const int output_columns =
(ncols - 2 * sizeof(scale_bias_t)) * num_elem_per_byte;
int row = (int)blockIdx.y * blockDim.y + threadIdx.y;
const int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
const auto row_incre = blockDim.y * gridDim.y;
for (/*row*/; row < nrows; row += row_incre) {
if (row < nrows && col < output_columns) {
const std::uint8_t* input_row = input + row * ncols;
const __half* input_row_scale_bias = reinterpret_cast<const __half*>(
input_row +
(!scale_bias_last
? 0
: (output_columns + num_elem_per_byte - 1) / num_elem_per_byte));
float scale = __half2float(input_row_scale_bias[0]);
float bias = __half2float(input_row_scale_bias[1]);
const scale_bias_t* input_row_scale_bias =
reinterpret_cast<const scale_bias_t*>(
input_row +
(!scale_bias_last ? 0
: (output_columns + num_elem_per_byte - 1) /
num_elem_per_byte));
float scale = scale_bias_to_float(input_row_scale_bias[0]);
float bias = scale_bias_to_float(input_row_scale_bias[1]);
if constexpr (!scale_bias_last) {
input_row += 2 * sizeof(__half);
input_row += 2 * sizeof(scale_bias_t);
}
output_t* output_row = output + row * output_columns;

Expand Down Expand Up @@ -216,7 +262,71 @@ DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_gpu(
return output;
}

template <typename output_t>
// Host-side driver that quantizes a 2D floating-point tensor to fused N-bit
// rowwise format, with each row's scale and bias stored in scale_bias_t
// precision at the end of the row.
//
// Row layout produced: ceil(ncols / num_elem_per_byte) bytes of packed
// N-bit values, followed by 2 * sizeof(scale_bias_t) bytes of [scale, bias].
//
// input:    2D tensor on a CUDA device. The element type is resolved at
//           runtime by FBGEMM_DISPATCH_FLOATING_TYPES; NOTE(review): the
//           input_t template parameter is not referenced in the body.
// bit_rate: bits per quantized element (num_elem_per_byte = 8 / bit_rate).
//           NOTE(review): assumes bit_rate evenly divides 8 — not validated
//           here; confirm callers only pass supported rates.
// Returns:  a new uint8 tensor of shape [nrows, output_columns]; empty rows
//           or columns yield an empty (but correctly shaped) output.
template <typename input_t, typename scale_bias_t = __half>
Tensor _float_to_fusednbitrowwise_bf16sb_gpu_t(
    const Tensor& input,
    const int64_t bit_rate) {
  TENSOR_ON_CUDA_GPU(input);
  TENSOR_NDIM_EQUALS(input, 2);
  CUDA_DEVICE_GUARD(input);

  const int nrows = input.size(0);
  const int ncols = input.size(1);
  const int num_elem_per_byte = 8 / bit_rate;
  // The packed data region must be a multiple of 2 bytes so that the trailing
  // scale/bias pair (2-byte elements) is aligned when reinterpret_cast'ed.
  TORCH_CHECK(
      ncols % (2 * num_elem_per_byte) == 0,
      "ncols needs to be multiple of 2 Bytes (bfloat16 type size) to make the address aligned");
  const int output_columns =
      (ncols + num_elem_per_byte - 1) / num_elem_per_byte +
      2 * sizeof(scale_bias_t);

  // Output is byte-typed; the kernel writes packed values + scale/bias.
  auto output =
      at::empty({nrows, output_columns}, input.options().dtype(at::kByte));

  // Nothing to quantize; skip the launch (grid size of 0 is invalid).
  if (nrows == 0 || ncols == 0) {
    return output;
  }

  // One thread per row; the kernel strides over rows if the grid is smaller
  // than nrows.
  constexpr auto threads_per_block = 256;
  const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block);

  FBGEMM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(),
      "_float_to_fusednbitrowwise_bf16sb_cuda_kernel",
      [&] {
        FBGEMM_LAUNCH_KERNEL(
            (_float_to_fusednbitrowwise_cuda_kernel<scalar_t, scale_bias_t>),
            num_blocks,
            threads_per_block,
            0,
            at::cuda::getCurrentCUDAStream(),
            bit_rate,
            input.data_ptr<scalar_t>(),
            nrows,
            ncols,
            output.data_ptr<std::uint8_t>());
      });

  return output;
}

/// @ingroup quantize-ops-cuda
/// Converts a tensor of `float` or `at::Half` values into a tensor of fused
/// N-bit rowwise values with bf16 scale/bias. The input element type is
/// dispatched at runtime inside the implementation, so the `float` template
/// argument below does not restrict the accepted input dtype.
///
/// @param input A tensor of `float` or `at::Half` values
/// @param bit_rate Number of bits per quantized element
///
/// @return A new tensor with values from the input tensor converted to
/// fused N-bit rowwise with bf16 scale/bias.
DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_bf16sb_gpu(
    const Tensor& input,
    const int64_t bit_rate) {
  return _float_to_fusednbitrowwise_bf16sb_gpu_t<float, __nv_bfloat16>(
      input, bit_rate);
}

template <typename output_t, typename scale_bias_t = __half>
Tensor _fusednbitrowwise_to_float_gpu_t(
const Tensor& input,
const int64_t bit_rate,
Expand All @@ -228,7 +338,8 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
const int nrows = input.size(0);
const int ncols = input.size(1);
const int num_elem_per_byte = 8 / bit_rate;
const int output_columns = (ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;
const int output_columns =
(ncols - 2 * sizeof(scale_bias_t)) * num_elem_per_byte;

// Global memory instructions support reading or writing words of size equal
// to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
Expand Down Expand Up @@ -267,17 +378,20 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
const dim3 gridDim(gridDim_x, gridDim_y);

#define DEQUANT_LAUNCH_NBIT(scale_bias_last) \
FBGEMM_LAUNCH_KERNEL( \
(_fusednbitrowwise_to_float_cuda_kernel<scalar_t, scale_bias_last>), \
gridDim, \
blockDim, \
0, \
at::cuda::getCurrentCUDAStream(), \
bit_rate, \
input.data_ptr<std::uint8_t>(), \
nrows, \
ncols, \
#define DEQUANT_LAUNCH_NBIT(scale_bias_last) \
FBGEMM_LAUNCH_KERNEL( \
(_fusednbitrowwise_to_float_cuda_kernel< \
scalar_t, \
scale_bias_t, \
scale_bias_last>), \
gridDim, \
blockDim, \
0, \
at::cuda::getCurrentCUDAStream(), \
bit_rate, \
input.data_ptr<std::uint8_t>(), \
nrows, \
ncols, \
output.mutable_data_ptr<scalar_t>())

FBGEMM_DISPATCH_FLOATING_TYPES(
Expand Down Expand Up @@ -365,6 +479,48 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_single_or_half_precision_gpu(
return output;
}

/// @ingroup quantize-ops-cuda
/// Converts a tensor of fused N-bit rowwise values with bf16 scale/bias into
/// a tensor of `float`, `at::Half`, or `at::BFloat16` values.
///
/// @param input A tensor of fused N-bit rowwise values with bf16 scale/bias
/// @param bit_rate Number of bits per quantized element
/// @param output_dtype The target floating point type, specified as integer
///                     representation of `SparseType` enum
/// @param scale_bias_last Whether each row stores the scale/bias pair after
///                        the quantized payload (true) or before it (false)
///
/// @return A new tensor with values from the input tensor converted to `float`
/// or `at::Half` or `at::BFloat16`, depending on `output_dtype`.
///
/// @throw c10::Error if `output_dtype` is not one of (`SparseType::FP32` or
/// `SparseType::FP16` or `SparseType::BF16`).
DLL_PUBLIC at::Tensor _fusednbitrowwise_bf16sb_to_single_or_half_precision_gpu(
    const at::Tensor& input,
    const int64_t bit_rate,
    const int64_t output_dtype,
    const bool scale_bias_last) {
  Tensor output;

  const SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
  switch (output_sparse_dtype) {
    case SparseType::FP32:
      output = _fusednbitrowwise_to_float_gpu_t<float, __nv_bfloat16>(
          input, bit_rate, scale_bias_last);
      break;
    case SparseType::FP16:
      output = _fusednbitrowwise_to_float_gpu_t<at::Half, __nv_bfloat16>(
          input, bit_rate, scale_bias_last);
      break;
    case SparseType::BF16:
      output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16, __nv_bfloat16>(
          input, bit_rate, scale_bias_last);
      break;
    default:
      // Name the offending dtype instead of failing with a bare assertion.
      TORCH_CHECK(
          false,
          "Unsupported output dtype for fused N-bit rowwise dequantization: ",
          output_dtype);
  }

  return output;
}

} // namespace fbgemm_gpu

FBGEMM_OP_DISPATCH(
Expand All @@ -379,6 +535,10 @@ FBGEMM_OP_DISPATCH(
CUDA,
"FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf",
fbgemm_gpu::_single_or_half_precision_to_fusednbitrowwise_gpu);
FBGEMM_OP_DISPATCH(
CUDA,
"FloatOrHalfToFusedNBitRowwiseQuantizedSBBFloat16",
fbgemm_gpu::_single_or_half_precision_to_fusednbitrowwise_bf16sb_gpu);
FBGEMM_OP_DISPATCH(
CUDA,
"FusedNBitRowwiseQuantizedSBHalfToFloat",
Expand All @@ -391,3 +551,7 @@ FBGEMM_OP_DISPATCH(
CUDA,
"FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf",
fbgemm_gpu::_fusednbitrowwise_to_single_or_half_precision_gpu);
FBGEMM_OP_DISPATCH(
CUDA,
"FusedNBitRowwiseQuantizedSBBFloat16ToFloatOrHalf",
fbgemm_gpu::_fusednbitrowwise_bf16sb_to_single_or_half_precision_gpu);
Loading
Loading