diff --git a/include/fbgemm/QuantUtils.h b/include/fbgemm/QuantUtils.h index 08e1935c71..7300cd21be 100644 --- a/include/fbgemm/QuantUtils.h +++ b/include/fbgemm/QuantUtils.h @@ -290,7 +290,8 @@ FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf( * * @param bit_rate can be 2, 4, or 8 */ -template +template + requires(!is_uint16_t_of_type_bf16 || !std::is_same_v) FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( int bit_rate, const uint8_t* input, @@ -325,6 +326,7 @@ FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat( * the corresponding quantize version only supports 8-bit. */ template + requires(!is_uint16_t_of_type_bf16 || !std::is_same_v) FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( const uint8_t* input, size_t input_rows, @@ -361,6 +363,7 @@ FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef( * This should not be called directly except in testing. */ template + requires(!is_uint16_t_of_type_bf16 || !std::is_same_v) FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( int bit_rate, const uint8_t* input, @@ -374,6 +377,7 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( * This should not be called directly except in testing. 
*/ template + requires(!is_uint16_t_of_type_bf16 || !std::is_same_v) FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef( const uint8_t* input, size_t input_rows, diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h index 0e95649c7c..8e56189065 100644 --- a/include/fbgemm/QuantUtilsAvx2.h +++ b/include/fbgemm/QuantUtilsAvx2.h @@ -165,7 +165,8 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2( std::uint8_t* output, const InputType* rowwise_min_max = nullptr); -template +template + requires(!is_bf16 || !std::is_same_v) void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2( const std::uint8_t* input, size_t input_rows, @@ -175,7 +176,9 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2( template < typename OutputType, bool scale_bias_last = true, - bool quant_padding_float_type = true> + bool quant_padding_float_type = true, + bool is_bf16 = false> + requires(!is_bf16 || !std::is_same_v) void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2( const std::uint8_t* input, size_t input_rows, diff --git a/include/fbgemm/QuantUtilsNeon.h b/include/fbgemm/QuantUtilsNeon.h index ab55403f38..add246929c 100644 --- a/include/fbgemm/QuantUtilsNeon.h +++ b/include/fbgemm/QuantUtilsNeon.h @@ -43,7 +43,7 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon( int input_columns, std::uint8_t* output); -template +template void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon( const std::uint8_t* input, size_t input_rows, diff --git a/src/Bf16ConvertAvx2.h b/src/Bf16ConvertAvx2.h new file mode 100644 index 0000000000..3c2c13ddcb --- /dev/null +++ b/src/Bf16ConvertAvx2.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#ifdef __AVX2__ +#include + +namespace fbgemm::internal { + +// fp32→bf16, nearest (ties away from 0): val + 0x8000, take high 16 bits. +inline __m256i cvt_fp32_to_bf16x8(__m256i val) { + return _mm256_srli_epi32( + _mm256_add_epi32(val, _mm256_set1_epi32(0x8000)), 16); +} + +// Convert 2x8 fp32 to 16 packed bf16 +inline __m256i cvt_fp32x16_bf16x16(__m256 a, __m256 b) { + __m256i y0 = cvt_fp32_to_bf16x8(_mm256_castps_si256(a)); + __m256i y1 = cvt_fp32_to_bf16x8(_mm256_castps_si256(b)); + return _mm256_permute4x64_epi64(_mm256_packus_epi32(y0, y1), 0xd8); +} + +// Convert 8 fp32 to 8 packed bf16 (128-bit result) +inline __m128i cvt_fp32x8_bf16x8(__m256 src) { + __m256i r = cvt_fp32_to_bf16x8(_mm256_castps_si256(src)); + return _mm_packus_epi32( + _mm256_castsi256_si128(r), _mm256_extracti128_si256(r, 1)); +} + +} // namespace fbgemm::internal + +#endif diff --git a/src/FbgemmBfloat16ConvertAvx2.cc b/src/FbgemmBfloat16ConvertAvx2.cc index b044dd4460..7b34bc9d1b 100644 --- a/src/FbgemmBfloat16ConvertAvx2.cc +++ b/src/FbgemmBfloat16ConvertAvx2.cc @@ -11,30 +11,18 @@ #include #endif #define FBGEMM_EXPORTS +#include "./Bf16ConvertAvx2.h" #include "fbgemm/FbgemmConvert.h" namespace fbgemm { namespace { -inline __m256i QuantizeBfloat16Avx2(const __m256& x0, const __m256& x1) { - // Add 2^15 and right shift 16 to do round-nearest - __m256i y0 = _mm256_srli_epi32( - _mm256_add_epi32(_mm256_castps_si256(x0), _mm256_set1_epi32(1 << 15)), - 16); - __m256i y1 = _mm256_srli_epi32( - _mm256_add_epi32(_mm256_castps_si256(x1), _mm256_set1_epi32(1 << 15)), - 16); - // AVX2 doesn't have _mm256_cvtepi32_epi16 so we need this instruction - // sequence. 
- return _mm256_permute4x64_epi64(_mm256_packus_epi32(y0, y1), 0xd8); -} - inline void FloatToBfloat16KernelAvx2(const float* src, bfloat16* dst) { // Two float m256i -> One bfloat16 m256i const __m256 src_reg0 = _mm256_loadu_ps(src); const __m256 src_reg1 = _mm256_loadu_ps(src + 8); - __m256i dst_reg = QuantizeBfloat16Avx2(src_reg0, src_reg1); + __m256i dst_reg = internal::cvt_fp32x16_bf16x16(src_reg0, src_reg1); _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), dst_reg); } diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc index a1a640b2a8..b1e5519916 100644 --- a/src/QuantUtils.cc +++ b/src/QuantUtils.cc @@ -745,6 +745,7 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat( } template + requires(!is_uint16_t_of_type_bf16 || !std::is_same_v) void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( int bit_rate, const uint8_t* input, @@ -792,7 +793,8 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( } } -template +template + requires(!is_uint16_t_of_type_bf16 || !std::is_same_v) void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( int bit_rate, const uint8_t* input, @@ -800,54 +802,64 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( int input_columns, OutputType* output, bool scale_bias_last [[maybe_unused]]) { + auto ref_fallback = [&]() { + FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef< + OutputType, + is_uint16_t_of_type_bf16>( + bit_rate, input, input_rows, input_columns, output); + }; + #if HAVE_SVE switch (bit_rate) { case 2: - FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon( + FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon< + OutputType, 2, is_uint16_t_of_type_bf16>( input, input_rows, input_columns, output); break; case 4: - FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon( + FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon< + OutputType, 4, is_uint16_t_of_type_bf16>( input, input_rows, input_columns, output); break; case 8: - FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon( + FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon< + OutputType, 
8, is_uint16_t_of_type_bf16>( input, input_rows, input_columns, output); break; default: - FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( - bit_rate, input, input_rows, input_columns, output); + ref_fallback(); } #else - if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 switch (bit_rate) { case 2: - FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2( + FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2< + OutputType, 2, is_uint16_t_of_type_bf16>( input, input_rows, input_columns, output); break; case 4: - FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2( + FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2< + OutputType, 4, is_uint16_t_of_type_bf16>( input, input_rows, input_columns, output); break; case 8: - FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2( + FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2< + OutputType, 8, is_uint16_t_of_type_bf16>( input, input_rows, input_columns, output); break; default: - FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( - bit_rate, input, input_rows, input_columns, output); + ref_fallback(); } + return; #endif - } else { - FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( - bit_rate, input, input_rows, input_columns, output); } + ref_fallback(); #endif } template + requires(!is_uint16_t_of_type_bf16 || !std::is_same_v) void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef( const std::uint8_t* input, size_t input_rows, @@ -897,6 +909,7 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef( } template + requires(!is_uint16_t_of_type_bf16 || !std::is_same_v) void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( const std::uint8_t* input, size_t input_rows, @@ -911,7 +924,6 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 if (is_uint16_t_of_type_bf16 && fbgemmHasAvx512Bf16Support()) { -#ifdef FBGEMM_FBCODE #define DEQUANT_LAUNCH_AVX512_BF16(scale_bias_last, quant_padding_float_type) 
\ Fused8BitRowwiseQuantizedSBFloatToBfloat16Avx512< \ scale_bias_last, \ @@ -933,32 +945,31 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( } #undef DEQUANT_LAUNCH_AVX512_BF16 return; -#endif - } else if (!is_uint16_t_of_type_bf16) { + } -#define DEQUANT_LAUNCH_AVX2(scale_bias_last, quant_padding_float_type) \ - Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2< \ - OutputType, \ - scale_bias_last, \ - quant_padding_float_type>(input, input_rows, input_columns, output); +#define DEQUANT_LAUNCH_AVX2(scale_bias_last, quant_padding_float_type) \ + Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2< \ + OutputType, \ + scale_bias_last, \ + quant_padding_float_type, \ + is_uint16_t_of_type_bf16>(input, input_rows, input_columns, output); - if (scale_bias_last) { - if (quant_padding_float_type) { - DEQUANT_LAUNCH_AVX2(true, true); - } else { - DEQUANT_LAUNCH_AVX2(true, false); - } + if (scale_bias_last) { + if (quant_padding_float_type) { + DEQUANT_LAUNCH_AVX2(true, true); } else { - if (quant_padding_float_type) { - DEQUANT_LAUNCH_AVX2(false, true); - } else { - DEQUANT_LAUNCH_AVX2(false, false); - } + DEQUANT_LAUNCH_AVX2(true, false); + } + } else { + if (quant_padding_float_type) { + DEQUANT_LAUNCH_AVX2(false, true); + } else { + DEQUANT_LAUNCH_AVX2(false, false); } + } #undef DEQUANT_LAUNCH_AVX2 - return; - } + return; #endif } Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef< @@ -997,14 +1008,7 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( type* output, \ bool scale_bias_last); \ template FBGEMM_API void \ - FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( \ - int bit_rate, \ - const uint8_t* input, \ - size_t input_rows, \ - int input_columns, \ - type* output, \ - bool scale_bias_last); \ - template FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( \ + FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( \ int bit_rate, \ const uint8_t* input, \ size_t input_rows, \ @@ -1033,23 +1037,7 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( 
const bool scale_bias_last, \ const bool quant_padding_float_type); \ template FBGEMM_API void \ - Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef( \ - const uint8_t* input, \ - size_t input_rows, \ - int input_columns, \ - type* output, \ - const bool scale_bias_last, \ - const bool quant_padding_float_type); \ - template FBGEMM_API void \ Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( \ - const uint8_t* input, \ - size_t input_rows, \ - int input_columns, \ - type* output, \ - const bool scale_bias_last, \ - const bool quant_padding_float_type); \ - template FBGEMM_API void \ - Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( \ const uint8_t* input, \ size_t input_rows, \ int input_columns, \ @@ -1064,4 +1052,38 @@ INSTANTIATE_QuantizationFunctions(float16) #undef INSTANTIATE_QuantizationFunctions +// bf16 variants only apply to float16 (uint16_t) output type +template FBGEMM_API void +FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( + int bit_rate, + const uint8_t* input, + size_t input_rows, + int input_columns, + float16* output, + bool scale_bias_last); +template FBGEMM_API void +FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( + int bit_rate, + const uint8_t* input, + size_t input_rows, + int input_columns, + float16* output, + bool scale_bias_last); +template FBGEMM_API void +Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef( + const uint8_t* input, + size_t input_rows, + int input_columns, + float16* output, + const bool scale_bias_last, + const bool quant_padding_float_type); +template FBGEMM_API void +Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( + const uint8_t* input, + size_t input_rows, + int input_columns, + float16* output, + const bool scale_bias_last, + const bool quant_padding_float_type); + } // namespace fbgemm diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index 07b25b9311..702d49e048 100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -23,10 +23,12 @@ #include "fbgemm/FloatConversion.h" #include "fbgemm/Types.h" 
#include "fbgemm/UtilsAvx2.h" +#include "./Bf16ConvertAvx2.h" // @manual namespace fbgemm { using namespace std; + //////////////////////////////////////////////////////////////////////////////// // Utility functions @@ -1949,7 +1951,8 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2( } // for each row } -template +template + requires(!is_bf16 || !std::is_same_v) void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2( const std::uint8_t* input, size_t input_rows, @@ -2083,6 +2086,14 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2( _mm256_storeu_ps(output_row_float + col + VLEN, vinq1); _mm256_storeu_ps(output_row_float + col + 2 * VLEN, vinq2); _mm256_storeu_ps(output_row_float + col + 3 * VLEN, vinq3); + } else if constexpr (is_bf16) { + __m256i packed01 = internal::cvt_fp32x16_bf16x16(vinq0, vinq1); + __m256i packed23 = internal::cvt_fp32x16_bf16x16(vinq2, vinq3); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(output_row + col), packed01); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(output_row + col + 2 * VLEN), + packed23); } else { _mm_storeu_si128( reinterpret_cast<__m128i*>(output_row + col), @@ -2147,6 +2158,25 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2( output_row_float + col + 2 * VLEN, vmask_store2, vinq2); _mm256_maskstore_ps( output_row_float + col + 3 * VLEN, vmask_store3, vinq3); + } else if constexpr (is_bf16) { + __m256i packed01 = internal::cvt_fp32x16_bf16x16(vinq0, vinq1); + __m256i packed23 = internal::cvt_fp32x16_bf16x16(vinq2, vinq3); + _mm_maskstore_epi32( + reinterpret_cast(output_row + col), + _mm256_castsi256_si128(vmask_store0), + _mm256_castsi256_si128(packed01)); + _mm_maskstore_epi32( + reinterpret_cast(output_row + col + VLEN), + _mm256_castsi256_si128(vmask_store1), + _mm256_extracti128_si256(packed01, 1)); + _mm_maskstore_epi32( + reinterpret_cast(output_row + col + 2 * VLEN), + _mm256_castsi256_si128(vmask_store2), + _mm256_castsi256_si128(packed23)); + _mm_maskstore_epi32( + 
reinterpret_cast(output_row + col + 3 * VLEN), + _mm256_castsi256_si128(vmask_store3), + _mm256_extracti128_si256(packed23, 1)); } else { _mm_maskstore_epi32( reinterpret_cast(output_row + col), @@ -2179,6 +2209,8 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2( static_cast(double(scale) * quantized + double(bias)); if constexpr (std::is_same_v) { output_row[col] = output_value; + } else if constexpr (is_bf16) { + output_row[col] = cpu_float2bfloat16(output_value); } else { output_row[col] = cpu_float2half_rn(output_value); } @@ -2190,7 +2222,9 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2( template < typename OutputType, bool scale_bias_last, - bool quant_padding_float_type> + bool quant_padding_float_type, + bool is_bf16> + requires(!is_bf16 || !std::is_same_v) void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2( const std::uint8_t* input, size_t input_rows, @@ -2230,6 +2264,10 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2( if constexpr (std::is_same_v) { float* output_row_float = reinterpret_cast(output_row); _mm256_storeu_ps(output_row_float + col, dequantzed_v); + } else if constexpr (is_bf16) { + _mm_storeu_si128( + reinterpret_cast<__m128i*>(output_row + col), + internal::cvt_fp32x8_bf16x8(dequantzed_v)); } else { _mm_storeu_si128( reinterpret_cast<__m128i*>(output_row + col), @@ -2243,6 +2281,8 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2( static_cast(double(input_row[col]) * scale + double(bias)); if constexpr (std::is_same_v) { output_row[col] = output_value; + } else if constexpr (is_bf16) { + output_row[col] = cpu_float2bfloat16(output_value); } else { output_row[col] = cpu_float2half_rn(output_value); } @@ -2275,6 +2315,21 @@ INSTANTIATE_QuantizationAvx2FunctionsNBits(float16, 8) // clang-format on #undef INSTANTIATE_QuantizationAvx2FunctionsNBits +#define INSTANTIATE_DequantNBitBf16Avx2(bit_rate) \ + template void \ + FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2(\ + const std::uint8_t* input, \ + 
size_t input_rows, \ + int input_columns, \ + float16* output); + +// clang-format off +INSTANTIATE_DequantNBitBf16Avx2(2) +INSTANTIATE_DequantNBitBf16Avx2(4) +INSTANTIATE_DequantNBitBf16Avx2(8) +// clang-format on +#undef INSTANTIATE_DequantNBitBf16Avx2 + #define INSTANTIATE_QuantizationAvx2Functions8Bits(type) \ template void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2( \ const type* input, \ @@ -2289,26 +2344,31 @@ INSTANTIATE_QuantizationAvx2Functions8Bits(float16) // clang-format on #undef INSTANTIATE_QuantizationAvx2Functions8Bits -#define INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2( \ - type, scale_bias_last, quant_padding_float_type) \ +#define INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2( \ + type, scale_bias_last, quant_padding_float_type, is_bf16_out) \ template void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2< \ type, \ scale_bias_last, \ - quant_padding_float_type>( \ + quant_padding_float_type, \ + is_bf16_out>( \ const std::uint8_t* input, \ size_t input_rows, \ int input_columns, \ type* output); - // clang-format off -INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float, true, true) -INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float, true, false) -INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float, false, true) -INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float, false, false) -INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float16, true, true) -INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float16, true, false) -INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float16, false, true) -INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float16, false, false) +// clang-format off +INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float, true, true, false) +INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float, true, false, false) 
+INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float, false, true, false) +INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float, false, false, false) +INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float16, true, true, false) +INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float16, true, false, false) +INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float16, false, true, false) +INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float16, false, false, false) +INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float16, true, true, true) +INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float16, true, false, true) +INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float16, false, true, true) +INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(float16, false, false, true) // clang-format on #undef INSTANTIATE_Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2 diff --git a/src/QuantUtilsNeon.cc b/src/QuantUtilsNeon.cc index 102b628861..f4e574c4d0 100644 --- a/src/QuantUtilsNeon.cc +++ b/src/QuantUtilsNeon.cc @@ -587,12 +587,22 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( } // for each row } -template +// fp32→bf16; matches Bf16ConvertAvx2.h (val + 0x8000, take high 16 bits). 
+static inline uint16x4_t cvt_fp32x4_to_bf16x4(float32x4_t v) { + const uint32x4_t u = vreinterpretq_u32_f32(v); + const uint32x4_t rounded = vaddq_u32(u, vdupq_n_u32(0x8000)); + return vshrn_n_u32(rounded, 16); +} + +template void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon( const std::uint8_t* input, size_t input_rows, int input_columns, OutputType* output) { + static_assert( + !IS_BF16_OUT || std::is_same_v, + "IS_BF16_OUT requires float16 output type (bf16 is stored as uint16_t)."); svbool_t allTruePred = svptrue_b8(); constexpr size_t kNumElemsPerIter = 8; constexpr size_t kNumBytesPerIter = BIT_RATE; @@ -668,6 +678,11 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon( if constexpr (std::is_same_v) { vst1q_f32(output, svget_neonq(in_v_0_f)); vst1q_f32(output + 4, svget_neonq(in_v_1_f)); + } else if constexpr (IS_BF16_OUT) { + const uint16x4_t bf_lo = cvt_fp32x4_to_bf16x4(svget_neonq(in_v_0_f)); + const uint16x4_t bf_hi = cvt_fp32x4_to_bf16x4(svget_neonq(in_v_1_f)); + vst1q_u16( + reinterpret_cast(output), vcombine_u16(bf_lo, bf_hi)); } else { float16x4_t dequantzed_v_half_low = vcvt_f16_f32(svget_neonq(in_v_0_f)); float16x4_t dequantzed_v_half_high = @@ -716,6 +731,13 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon( if constexpr (std::is_same_v) { svst1_f32(lastPredA, output, in_v_0_f); svst1_f32(lastPredB, output + 4, in_v_1_f); + } else if constexpr (IS_BF16_OUT) { + const uint16x4_t bf_lo = cvt_fp32x4_to_bf16x4(svget_neonq(in_v_0_f)); + const uint16x4_t bf_hi = cvt_fp32x4_to_bf16x4(svget_neonq(in_v_1_f)); + svst1_u16( + lastPredC, + reinterpret_cast(output), + svset_neonq_u16(svundef_u16(), vcombine_u16(bf_lo, bf_hi))); } else { float16x4_t dequantzed_v_half_low_low = vcvt_f16_f32(svget_neonq(in_v_0_f)); @@ -776,6 +798,21 @@ INSTANTIATE_QuantizationNeonFunctionsNBits(float16, 8) // clang-format on #undef INSTANTIATE_QuantizationNeonFunctionsNBits +#define INSTANTIATE_DequantNBitBf16Neon(bit_rate) \ + template void \ + 
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon(\ + const std::uint8_t* input, \ + size_t input_rows, \ + int input_columns, \ + float16* output); + +// clang-format off +INSTANTIATE_DequantNBitBf16Neon(2) +INSTANTIATE_DequantNBitBf16Neon(4) +INSTANTIATE_DequantNBitBf16Neon(8) +// clang-format on +#undef INSTANTIATE_DequantNBitBf16Neon + #endif // HAVE_SVE } // namespace fbgemm diff --git a/test/QuantUtilsTest.cc b/test/QuantUtilsTest.cc index f7feecaf2d..82346924f4 100644 --- a/test/QuantUtilsTest.cc +++ b/test/QuantUtilsTest.cc @@ -751,6 +751,19 @@ TEST_P(EmbeddingQuantizeTest, embeddingHalfTest) { FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( bit_rate, outVecRef.data(), rows, out_cols, dequantOutHalfTest.data()); EXPECT_EQ(dequantOutHalfRef, dequantOutHalfTest); + + // ref (double) vs SIMD (fp32 FMA) can differ by ~1 fp32 ULP; allow ~2 bf16 ULPs. + vector dequantBf16Ref(rows * cols), dequantBf16Test(rows * cols); + FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( + bit_rate, outVecRef.data(), rows, out_cols, dequantBf16Ref.data()); + FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( + bit_rate, outVecRef.data(), rows, out_cols, dequantBf16Test.data()); + for (int i = 0; i < rows * cols; ++i) { + float r = cpu_bf162float(static_cast(dequantBf16Ref[i])); + float t = cpu_bf162float(static_cast(dequantBf16Test[i])); + EXPECT_NEAR(r, t, std::max(1e-3f, 1.6e-2f * std::abs(r))); + } + EXPECT_NE(dequantBf16Ref, dequantOutHalfRef); } // Scale and bias are of type float @@ -824,6 +837,19 @@ TEST_P(EmbeddingQuantizeSBFloatTest, embeddingFloatTest) { Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( outVecRef.data(), rows, out_cols, dequantOutHalfTest.data()); EXPECT_EQ(dequantOutHalfRef, dequantOutHalfTest); + + // ref (double) vs SIMD (fp32 FMA) can differ by ~1 fp32 ULP; allow ~2 bf16 ULPs. 
+ vector dequantBf16Ref(rows * cols), dequantBf16Test(rows * cols); + Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef( + outVecRef.data(), rows, out_cols, dequantBf16Ref.data()); + Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( + outVecRef.data(), rows, out_cols, dequantBf16Test.data()); + for (int i = 0; i < rows * cols; ++i) { + float r = cpu_bf162float(static_cast(dequantBf16Ref[i])); + float t = cpu_bf162float(static_cast(dequantBf16Test[i])); + EXPECT_NEAR(r, t, std::max(1e-3f, 1.6e-2f * std::abs(r))); + } + EXPECT_NE(dequantBf16Ref, dequantOutHalfRef); } TEST_P(