From 6508869c169baac5d0933007bc848dccb54ab367 Mon Sep 17 00:00:00 2001
From: Dallas Jacobsen
Date: Fri, 1 May 2026 13:54:00 -0700
Subject: [PATCH] Add CPU support in fbgemm for FloatToFP8RowwiseQuantized and
 FP8RowwiseQuantizedToFloat (#5644)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2592

Add fp8 support on CPU for `fbgemm::FloatToFP8RowwiseQuantized` and
`fbgemm::FP8RowwiseQuantizedToFloat`, verified against the GPU and MTIA
reference kernels with the GLOBE eager accuracy test.

Quantization (`float_to_FP8rowwise_cpu`), with a layout sketch below:
- The `std::abs` + `std::max` reduction is equivalent to
  `MAX(max_elem, -min_elem)` from the MTIA ref kernel, line 74.
- `at::empty` is used for the output (not `at::zeros`): the padding bytes
  `[K, K_aligned)` are left uninitialized, matching the GPU kernel
  (`quantize_fp8_rowwise.cu:223`) and the MTIA kernel behavior.
- The scale zero-pad is initialized to 0.0f for PT2 compliance (matches GPU
  line 52).
- The empty-tensor early return with `at::zeros` matches the GPU kernel
  (lines 217-221).
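For concreteness, a minimal standalone sketch of the row layout math (the
width `K = 250` is an invented example, not taken from this diff):

```
// Sketch only: mirrors the arithmetic in float_to_FP8rowwise_cpu.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t K = 250;                      // invented example width
  const int64_t K_aligned = (K + 3) / 4 * 4;  // payload padded to a multiple of 4 -> 252
  const int64_t output_columns =
      K_aligned + 2 * static_cast<int64_t>(sizeof(float));  // 252 + 8 = 260
  // Per-row byte layout of the quantized output:
  //   [0, K)                      fp8 payload
  //   [K, K_aligned)              alignment padding, left uninitialized (at::empty)
  //   [K_aligned, K_aligned + 4)  float scale
  //   [K_aligned + 4, K_aligned + 8)  float zero pad, written as 0.0f
  std::printf(
      "K=%lld K_aligned=%lld output_columns=%lld\n",
      static_cast<long long>(K),
      static_cast<long long>(K_aligned),
      static_cast<long long>(output_columns));
  return 0;
}
```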
Dequantization (`FP8rowwise_to_float_cpu`):
- `output_columns = ncols - 2 * sizeof(float)` equals `ncols_aligned`, the
  full aligned width, NOT the original K. This matches the GPU kernel
  behavior (`quantize_fp8_rowwise.cu:170`); see the round-trip sketch below.
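A hedged round-trip sketch against the new signatures; the bare declarations,
the omitted namespace, and the `output_dtype=0` -> `SparseType::FP32` mapping
are assumptions for illustration, and the shapes are invented:

```
// Sketch only: assumes the fbgemm_gpu CPU build is linked in; the
// declarations below mirror the signatures added in this diff.
#include <ATen/ATen.h>

at::Tensor float_to_FP8rowwise_cpu(const at::Tensor& input, bool forward);
at::Tensor FP8rowwise_to_float_cpu(
    const at::Tensor& input,
    bool forward,
    const int64_t output_dtype);

int main() {
  at::Tensor input = at::randn({4, 250});  // invented shape, K = 250

  // forward=true selects ebit=4 / bias=15 / max_pos=0.9375f
  at::Tensor q = float_to_FP8rowwise_cpu(input, /*forward=*/true);
  // q.size(1) == 260: 252 payload bytes + 4-byte scale + 4-byte zero pad

  // output_dtype=0 is assumed to map to SparseType::FP32
  at::Tensor deq =
      FP8rowwise_to_float_cpu(q, /*forward=*/true, /*output_dtype=*/0);
  // deq.size(1) == 252, i.e. K_aligned and NOT the original K = 250,
  // so callers slice the padding columns off themselves.
  at::Tensor out = deq.narrow(/*dim=*/1, /*start=*/0, /*length=*/250);
  return 0;
}
```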
Differential Revision: D100724285
---
 .../src/quantize_ops/quantize_ops_cpu.cpp | 123 ++++++++++++++++--
 1 file changed, 115 insertions(+), 8 deletions(-)

diff --git a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
index be47106d43..c804dedc51 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
+++ b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -417,22 +417,129 @@ Tensor fused8bitrowwise_to_float_or_half_cpu(
   return output;
 }
 
-// dummy cpu code for gpu fp8_rowwise conversions
 /// @ingroup quantize-data-cpu
 ///
-Tensor float_to_FP8rowwise_cpu(const Tensor& input, bool /*forward*/) {
-  TORCH_CHECK(false, "fp8 is not supported by CPU");
-  return input;
+Tensor float_to_FP8rowwise_cpu(const Tensor& input, bool forward) {
+  TENSOR_ON_CPU(input);
+  TORCH_CHECK(
+      input.dim() >= 2,
+      "Tensor 'input' must have >= 2 dimension(s). Found ",
+      input.ndimension());
+
+  const auto input_sizes = input.sizes();
+  const auto last_dim = input_sizes.size() - 1;
+  const int64_t nrows = c10::size_to_dim_(last_dim, input_sizes);
+  const int64_t ncols = static_cast<int64_t>(input_sizes[last_dim]);
+
+  const int ebit = forward ? 4 : 5;
+  const int bias = forward ? 15 : 31;
+  const float max_pos = forward ? 0.9375f : 0.875f;
+  constexpr float kEpsilon = 1e-20f;
+
+  const int64_t ncols_aligned = (ncols + 3) / 4 * 4;
+  const int32_t output_columns = ncols_aligned + 2 * sizeof(float);
+
+  auto output_dims = input_sizes.vec();
+  output_dims[last_dim] = output_columns;
+
+  if (nrows == 0 || ncols == 0) {
+    return at::zeros(output_dims, input.options().dtype(at::kByte));
+  }
+
+  auto output = at::empty(output_dims, input.options().dtype(at::kByte));
+
+  const float* input_data = input.const_data_ptr<float>();
+  uint8_t* output_data = output.mutable_data_ptr<uint8_t>();
+
+  for (auto row : c10::irange(nrows)) {
+    const float* input_row = input_data + row * ncols;
+    uint8_t* output_row = output_data + row * output_columns;
+    float* output_row_scale_bias =
+        reinterpret_cast<float*>(output_row + ncols_aligned);
+
+    float maximum_element = kEpsilon;
+    for (auto col : c10::irange(ncols)) {
+      maximum_element = std::max(maximum_element, std::abs(input_row[col]));
+    }
+    const float scale = max_pos / (kEpsilon + maximum_element);
+    output_row_scale_bias[0] = scale;
+    // Initialize padding to make output deterministic for PT2 compliance
+    output_row_scale_bias[1] = 0.0f;
+
+    for (auto col : c10::irange(ncols)) {
+      output_row[col] =
+          float_to_hfp8(input_row[col] * scale, ebit, bias, max_pos);
+    }
+  }
+
+  return output;
 }
 
 /// @ingroup quantize-data-cpu
 ///
 Tensor FP8rowwise_to_float_cpu(
     const Tensor& input,
-    bool /*forward*/,
-    const int64_t /*output_dtype*/) {
-  TORCH_CHECK(false, "fp8 is not supported by CPU");
-  return input;
+    bool forward,
+    const int64_t output_dtype) {
+  TENSOR_ON_CPU(input);
+  TORCH_CHECK(
+      input.dim() >= 2,
+      "Tensor 'input' must have >= 2 dimension(s). Found ",
+      input.ndimension());
+
+  const auto input_sizes = input.sizes();
+  const auto last_dim = input_sizes.size() - 1;
+  const int64_t nrows = c10::size_to_dim_(last_dim, input_sizes);
+  const int32_t ncols = input_sizes[last_dim];
+
+  const int ebit = forward ? 4 : 5;
+  const int bias = forward ? 15 : 31;
+  const auto output_sparse_dtype = static_cast<SparseType>(output_dtype);
+  TORCH_CHECK(
+      output_sparse_dtype == SparseType::FP32 ||
+          output_sparse_dtype == SparseType::FP16 ||
+          output_sparse_dtype == SparseType::BF16,
+      "Unsupported output dtype: ",
+      output_dtype);
+  const auto out_dtype = output_sparse_dtype == SparseType::FP16
+      ? at::kHalf
+      : (output_sparse_dtype == SparseType::BF16 ? at::kBFloat16 : at::kFloat);
+
+  if (nrows == 0 || ncols == 0) {
+    auto output_dims = input_sizes.vec();
+    output_dims[last_dim] = 0;
+    return at::zeros(output_dims, input.options().dtype(out_dtype));
+  }
+
+  const int32_t output_columns = ncols - 2 * sizeof(float);
+  auto output_dims = input_sizes.vec();
+  output_dims[last_dim] = output_columns;
+
+  // Always dequantize into float32, then convert to target dtype at the end.
+  auto output = at::empty(output_dims, input.options().dtype(at::kFloat));
+
+  const uint8_t* input_data = input.const_data_ptr<uint8_t>();
+  float* output_data = output.mutable_data_ptr<float>();
+
+  for (auto row : c10::irange(nrows)) {
+    const uint8_t* input_row = input_data + row * ncols;
+    float* output_row = output_data + row * output_columns;
+    const float* input_row_scale_bias =
+        reinterpret_cast<const float*>(input_row + output_columns);
+    const float scale = input_row_scale_bias[0];
+
+    for (auto col : c10::irange(output_columns)) {
+      output_row[col] = hfp8_to_float(input_row[col], ebit, bias) / scale;
+    }
+  }
+
+  if (output_sparse_dtype == SparseType::FP16) {
+    return output.to(at::kHalf);
+  }
+  if (output_sparse_dtype == SparseType::BF16) {
+    return output.to(at::kBFloat16);
+  }
+  return output;
 }
 
 /// @ingroup quantize-data-cpu