123 changes: 115 additions & 8 deletions fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -417,22 +417,129 @@ Tensor fused8bitrowwise_to_float_or_half_cpu(

  return output;
}
/// @ingroup quantize-data-cpu
///
Tensor float_to_FP8rowwise_cpu(const Tensor& input, bool forward) {
  TENSOR_ON_CPU(input);
  TORCH_CHECK(
      input.dim() >= 2,
      "Tensor 'input' must have >= 2 dimension(s). Found ",
      input.ndimension());

  const auto input_sizes = input.sizes();
  const auto last_dim = input_sizes.size() - 1;
  const int64_t nrows = c10::size_to_dim_(last_dim, input_sizes);
  const int64_t ncols = static_cast<int64_t>(input_sizes[last_dim]);

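  // hfp8 parameters: the forward variant uses 4 exponent bits with bias 15,
  // the backward variant 5 exponent bits with bias 31; max_pos is the largest
  // magnitude representable after scaling (a reading of the constants below,
  // matching the GPU FP8 rowwise path).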
  const int ebit = forward ? 4 : 5;
  const int bias = forward ? 15 : 31;
  const float max_pos = forward ? 0.9375f : 0.875f;
  constexpr float kEpsilon = 1e-20f;

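  // Output row layout: ncols fp8 payload bytes padded up to a multiple of 4,
  // followed by one float scale and one float bias (the trailing
  // 2 * sizeof(float) bytes).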
  const int64_t ncols_aligned = (ncols + 3) / 4 * 4;
  const int64_t output_columns = ncols_aligned + 2 * sizeof(float);

  auto output_dims = input_sizes.vec();
  output_dims[last_dim] = output_columns;

  if (nrows == 0 || ncols == 0) {
    return at::zeros(output_dims, input.options().dtype(at::kByte));
  }

  auto output = at::empty(output_dims, input.options().dtype(at::kByte));

  const float* input_data = input.const_data_ptr<float>();
  uint8_t* output_data = output.mutable_data_ptr<uint8_t>();

  for (auto row : c10::irange(nrows)) {
    const float* input_row = input_data + row * ncols;
    uint8_t* output_row = output_data + row * output_columns;
    float* output_row_scale_bias =
        reinterpret_cast<float*>(output_row + ncols_aligned);

    float maximum_element = kEpsilon;
    for (auto col : c10::irange(ncols)) {
      maximum_element = std::max(maximum_element, std::abs(input_row[col]));
    }
    const float scale = max_pos / (kEpsilon + maximum_element);
    output_row_scale_bias[0] = scale;
    // Initialize the unused bias slot to make the output deterministic for
    // PT2 compliance.
    output_row_scale_bias[1] = 0.0f;
    // Zero the alignment padding bytes for the same reason; at::empty leaves
    // them uninitialized otherwise.
    for (int64_t col = ncols; col < ncols_aligned; ++col) {
      output_row[col] = 0;
    }

    for (auto col : c10::irange(ncols)) {
      output_row[col] =
          float_to_hfp8(input_row[col] * scale, ebit, bias, max_pos);
    }
  }

  return output;
}
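// A minimal round-trip sketch (comments only; output_dtype 0 is assumed to
// map to SparseType::FP32, as elsewhere in fbgemm_gpu):
//
//   Tensor x = at::randn({4, 7});
//   Tensor q = float_to_FP8rowwise_cpu(x, /*forward=*/true);
//   Tensor y = FP8rowwise_to_float_cpu(q, /*forward=*/true, /*output_dtype=*/0);
//   // q has 8 + 2 * sizeof(float) = 16 columns; y has 8 columns (7 padded
//   // up to a multiple of 4), so callers slice back to the original width.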

/// @ingroup quantize-data-cpu
///
Tensor FP8rowwise_to_float_cpu(
    const Tensor& input,
    bool forward,
    const int64_t output_dtype) {
  TENSOR_ON_CPU(input);
  TORCH_CHECK(
      input.dim() >= 2,
      "Tensor 'input' must have >= 2 dimension(s). Found ",
      input.ndimension());

  const auto input_sizes = input.sizes();
  const auto last_dim = input_sizes.size() - 1;
  const int64_t nrows = c10::size_to_dim_(last_dim, input_sizes);
  const int64_t ncols = static_cast<int64_t>(input_sizes[last_dim]);

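  // Same hfp8 parameters as in float_to_FP8rowwise_cpu above.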
  const int ebit = forward ? 4 : 5;
  const int bias = forward ? 15 : 31;
  const auto output_sparse_dtype = static_cast<SparseType>(output_dtype);
  TORCH_CHECK(
      output_sparse_dtype == SparseType::FP32 ||
          output_sparse_dtype == SparseType::FP16 ||
          output_sparse_dtype == SparseType::BF16,
      "Unsupported output dtype: ",
      output_dtype);
  const auto out_dtype = output_sparse_dtype == SparseType::FP16
      ? at::kHalf
      : (output_sparse_dtype == SparseType::BF16 ? at::kBFloat16 : at::kFloat);

  if (nrows == 0 || ncols == 0) {
    auto output_dims = input_sizes.vec();
    output_dims[last_dim] = 0;
    return at::zeros(output_dims, input.options().dtype(out_dtype));
  }

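  // output_columns equals the padded payload width written by the quantizer
  // (ncols_aligned), so the dequantized row may carry alignment padding
  // columns beyond the original number of elements.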
  const int64_t output_columns = ncols - 2 * sizeof(float);
  auto output_dims = input_sizes.vec();
  output_dims[last_dim] = output_columns;

  // Always dequantize into float32, then convert to the target dtype at the
  // end.
  auto output = at::empty(output_dims, input.options().dtype(at::kFloat));

  const uint8_t* input_data = input.const_data_ptr<uint8_t>();
  float* output_data = output.mutable_data_ptr<float>();

  for (auto row : c10::irange(nrows)) {
    const uint8_t* input_row = input_data + row * ncols;
    float* output_row = output_data + row * output_columns;
    const float* input_row_scale_bias =
        reinterpret_cast<const float*>(input_row + output_columns);
    const float scale = input_row_scale_bias[0];

    for (auto col : c10::irange(output_columns)) {
      output_row[col] = hfp8_to_float(input_row[col], ebit, bias) / scale;
    }
  }

  // output is already float32, so this is a no-op for SparseType::FP32.
  return output.to(out_dtype);
}

/// @ingroup quantize-data-cpu