Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion include/fbgemm/QuantUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
*
* @param bit_rate can be 2, 4, or 8
*/
template <typename OutputType>
template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
requires(!is_uint16_t_of_type_bf16 || !std::is_same_v<OutputType, float>)
FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
int bit_rate,
const uint8_t* input,
Expand Down Expand Up @@ -325,6 +326,7 @@ FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
* the corresponding quantize version only supports 8-bit.
*/
template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
requires(!is_uint16_t_of_type_bf16 || !std::is_same_v<OutputType, float>)
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
const uint8_t* input,
size_t input_rows,
Expand Down Expand Up @@ -361,6 +363,7 @@ FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(
* This should not be called directly except in testing.
*/
template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
requires(!is_uint16_t_of_type_bf16 || !std::is_same_v<OutputType, float>)
FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
int bit_rate,
const uint8_t* input,
Expand All @@ -374,6 +377,7 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
* This should not be called directly except in testing.
*/
template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
requires(!is_uint16_t_of_type_bf16 || !std::is_same_v<OutputType, float>)
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
const uint8_t* input,
size_t input_rows,
Expand Down
7 changes: 5 additions & 2 deletions include/fbgemm/QuantUtilsAvx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2(
std::uint8_t* output,
const InputType* rowwise_min_max = nullptr);

template <typename OutputType, int BIT_RATE>
template <typename OutputType, int BIT_RATE, bool is_bf16 = false>
requires(!is_bf16 || !std::is_same_v<OutputType, float>)
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2(
const std::uint8_t* input,
size_t input_rows,
Expand All @@ -175,7 +176,9 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2(
template <
typename OutputType,
bool scale_bias_last = true,
bool quant_padding_float_type = true>
bool quant_padding_float_type = true,
bool is_bf16 = false>
requires(!is_bf16 || !std::is_same_v<OutputType, float>)
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(
const std::uint8_t* input,
size_t input_rows,
Expand Down
2 changes: 1 addition & 1 deletion include/fbgemm/QuantUtilsNeon.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon(
int input_columns,
std::uint8_t* output);

template <typename OutputType, int BIT_RATE>
template <typename OutputType, int BIT_RATE, bool IS_BF16_OUT = false>
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon(
const std::uint8_t* input,
size_t input_rows,
Expand Down
38 changes: 38 additions & 0 deletions src/Bf16ConvertAvx2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#ifdef __AVX2__
#include <immintrin.h>

namespace fbgemm::internal {

// fp32 -> bf16 for 8 lanes held in a __m256i (fp32 bit patterns).
// Adding 0x8000 to each 32-bit value makes the subsequent truncation to the
// high 16 bits round to nearest (ties in the dropped mantissa bits round
// upward in the bit pattern). Each 32-bit lane of the result holds one bf16
// value in its low 16 bits.
inline __m256i cvt_fp32_to_bf16x8(__m256i val) {
  const __m256i rounding_bias = _mm256_set1_epi32(0x8000);
  return _mm256_srli_epi32(_mm256_add_epi32(val, rounding_bias), 16);
}

// Pack two vectors of 8 fp32 values into one vector of 16 bf16 values.
// AVX2 has no _mm256_cvtepi32_epi16, so the two rounded halves are combined
// with packus and the 64-bit lanes reordered (0xd8 = [0, 2, 1, 3]) to undo
// the in-lane interleaving packus produces.
inline __m256i cvt_fp32x16_bf16x16(__m256 a, __m256 b) {
  const __m256i lo = cvt_fp32_to_bf16x8(_mm256_castps_si256(a));
  const __m256i hi = cvt_fp32_to_bf16x8(_mm256_castps_si256(b));
  return _mm256_permute4x64_epi64(_mm256_packus_epi32(lo, hi), 0xd8);
}

// Convert 8 fp32 values to 8 bf16 values packed into a 128-bit register.
inline __m128i cvt_fp32x8_bf16x8(__m256 src) {
  const __m256i rounded = cvt_fp32_to_bf16x8(_mm256_castps_si256(src));
  return _mm_packus_epi32(
      _mm256_castsi256_si128(rounded), _mm256_extracti128_si256(rounded, 1));
}

} // namespace fbgemm::internal

#endif
16 changes: 2 additions & 14 deletions src/FbgemmBfloat16ConvertAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,18 @@
#include <immintrin.h>
#endif
#define FBGEMM_EXPORTS
#include "./Bf16ConvertAvx2.h"
#include "fbgemm/FbgemmConvert.h"

namespace fbgemm {

namespace {

inline __m256i QuantizeBfloat16Avx2(const __m256& x0, const __m256& x1) {
// Add 2^15 and right shift 16 to do round-nearest
__m256i y0 = _mm256_srli_epi32(
_mm256_add_epi32(_mm256_castps_si256(x0), _mm256_set1_epi32(1 << 15)),
16);
__m256i y1 = _mm256_srli_epi32(
_mm256_add_epi32(_mm256_castps_si256(x1), _mm256_set1_epi32(1 << 15)),
16);
// AVX2 doesn't have _mm256_cvtepi32_epi16 so we need this instruction
// sequence.
return _mm256_permute4x64_epi64(_mm256_packus_epi32(y0, y1), 0xd8);
}

inline void FloatToBfloat16KernelAvx2(const float* src, bfloat16* dst) {
// Two float m256i -> One bfloat16 m256i
const __m256 src_reg0 = _mm256_loadu_ps(src);
const __m256 src_reg1 = _mm256_loadu_ps(src + 8);
__m256i dst_reg = QuantizeBfloat16Avx2(src_reg0, src_reg1);
__m256i dst_reg = internal::cvt_fp32x16_bf16x16(src_reg0, src_reg1);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), dst_reg);
}

Expand Down
142 changes: 82 additions & 60 deletions src/QuantUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,7 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
}

template <typename OutputType, bool is_uint16_t_of_type_bf16>
requires(!is_uint16_t_of_type_bf16 || !std::is_same_v<OutputType, float>)
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
int bit_rate,
const uint8_t* input,
Expand Down Expand Up @@ -792,62 +793,73 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
}
}

template <typename OutputType>
template <typename OutputType, bool is_uint16_t_of_type_bf16>
requires(!is_uint16_t_of_type_bf16 || !std::is_same_v<OutputType, float>)
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
int bit_rate,
const uint8_t* input,
size_t input_rows,
int input_columns,
OutputType* output,
bool scale_bias_last [[maybe_unused]]) {
auto ref_fallback = [&]() {
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<
Comment thread
cyyever marked this conversation as resolved.
OutputType,
is_uint16_t_of_type_bf16>(
bit_rate, input, input_rows, input_columns, output);
};

#if HAVE_SVE
switch (bit_rate) {
case 2:
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon<OutputType, 2>(
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon<
OutputType, 2, is_uint16_t_of_type_bf16>(
input, input_rows, input_columns, output);
break;
case 4:
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon<OutputType, 4>(
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon<
OutputType, 4, is_uint16_t_of_type_bf16>(
input, input_rows, input_columns, output);
break;
case 8:
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon<OutputType, 8>(
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfNeon<
OutputType, 8, is_uint16_t_of_type_bf16>(
input, input_rows, input_columns, output);
break;
default:
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<OutputType>(
bit_rate, input, input_rows, input_columns, output);
ref_fallback();
}
#else

if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
switch (bit_rate) {
case 2:
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 2>(
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<
OutputType, 2, is_uint16_t_of_type_bf16>(
input, input_rows, input_columns, output);
break;
case 4:
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 4>(
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<
OutputType, 4, is_uint16_t_of_type_bf16>(
input, input_rows, input_columns, output);
break;
case 8:
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 8>(
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<
OutputType, 8, is_uint16_t_of_type_bf16>(
input, input_rows, input_columns, output);
break;
default:
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<OutputType>(
bit_rate, input, input_rows, input_columns, output);
ref_fallback();
}
return;
#endif
} else {
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<OutputType>(
bit_rate, input, input_rows, input_columns, output);
}
ref_fallback();
#endif
}

template <typename OutputType, bool is_uint16_t_of_type_bf16>
requires(!is_uint16_t_of_type_bf16 || !std::is_same_v<OutputType, float>)
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
const std::uint8_t* input,
size_t input_rows,
Expand Down Expand Up @@ -897,6 +909,7 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
}

template <typename OutputType, bool is_uint16_t_of_type_bf16>
requires(!is_uint16_t_of_type_bf16 || !std::is_same_v<OutputType, float>)
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
const std::uint8_t* input,
size_t input_rows,
Expand All @@ -911,7 +924,6 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
if (is_uint16_t_of_type_bf16 && fbgemmHasAvx512Bf16Support()) {
#ifdef FBGEMM_FBCODE
#define DEQUANT_LAUNCH_AVX512_BF16(scale_bias_last, quant_padding_float_type) \
Fused8BitRowwiseQuantizedSBFloatToBfloat16Avx512< \
scale_bias_last, \
Expand All @@ -933,32 +945,31 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
}
#undef DEQUANT_LAUNCH_AVX512_BF16
return;
#endif
} else if (!is_uint16_t_of_type_bf16) {
}

#define DEQUANT_LAUNCH_AVX2(scale_bias_last, quant_padding_float_type) \
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2< \
OutputType, \
scale_bias_last, \
quant_padding_float_type>(input, input_rows, input_columns, output);
#define DEQUANT_LAUNCH_AVX2(scale_bias_last, quant_padding_float_type) \
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2< \
OutputType, \
scale_bias_last, \
quant_padding_float_type, \
is_uint16_t_of_type_bf16>(input, input_rows, input_columns, output);

if (scale_bias_last) {
if (quant_padding_float_type) {
DEQUANT_LAUNCH_AVX2(true, true);
} else {
DEQUANT_LAUNCH_AVX2(true, false);
}
if (scale_bias_last) {
if (quant_padding_float_type) {
DEQUANT_LAUNCH_AVX2(true, true);
} else {
if (quant_padding_float_type) {
DEQUANT_LAUNCH_AVX2(false, true);
} else {
DEQUANT_LAUNCH_AVX2(false, false);
}
DEQUANT_LAUNCH_AVX2(true, false);
}
} else {
if (quant_padding_float_type) {
DEQUANT_LAUNCH_AVX2(false, true);
} else {
DEQUANT_LAUNCH_AVX2(false, false);
}
}
#undef DEQUANT_LAUNCH_AVX2

return;
}
return;
#endif
}
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<
Expand Down Expand Up @@ -997,14 +1008,7 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
type* output, \
bool scale_bias_last); \
template FBGEMM_API void \
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<type, true>( \
int bit_rate, \
const uint8_t* input, \
size_t input_rows, \
int input_columns, \
type* output, \
bool scale_bias_last); \
template FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf<type>( \
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf<type, false>( \
int bit_rate, \
const uint8_t* input, \
size_t input_rows, \
Expand Down Expand Up @@ -1033,23 +1037,7 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
const bool scale_bias_last, \
const bool quant_padding_float_type); \
template FBGEMM_API void \
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<type, true>( \
const uint8_t* input, \
size_t input_rows, \
int input_columns, \
type* output, \
const bool scale_bias_last, \
const bool quant_padding_float_type); \
template FBGEMM_API void \
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<type, false>( \
const uint8_t* input, \
size_t input_rows, \
int input_columns, \
type* output, \
const bool scale_bias_last, \
const bool quant_padding_float_type); \
template FBGEMM_API void \
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<type, true>( \
const uint8_t* input, \
size_t input_rows, \
int input_columns, \
Expand All @@ -1064,4 +1052,38 @@ INSTANTIATE_QuantizationFunctions(float16)

#undef INSTANTIATE_QuantizationFunctions

// bf16 variants only apply to float16 (uint16_t) output type.
// The `requires` clause on these templates rejects OutputType = float when
// is_uint16_t_of_type_bf16 is true, so the <float16, true> form is the only
// bf16 instantiation; the uint16_t output presumably carries bfloat16 bit
// patterns rather than IEEE fp16 in that case — confirm against callers.
// N-bit (2/4/8) rowwise dequantize, reference path (testing only).
template FBGEMM_API void
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<float16, true>(
    int bit_rate,
    const uint8_t* input,
    size_t input_rows,
    int input_columns,
    float16* output,
    bool scale_bias_last);
// N-bit rowwise dequantize, CPU-dispatching entry point.
template FBGEMM_API void
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf<float16, true>(
    int bit_rate,
    const uint8_t* input,
    size_t input_rows,
    int input_columns,
    float16* output,
    bool scale_bias_last);
// 8-bit rowwise dequantize, reference path (testing only).
template FBGEMM_API void
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<float16, true>(
    const uint8_t* input,
    size_t input_rows,
    int input_columns,
    float16* output,
    const bool scale_bias_last,
    const bool quant_padding_float_type);
// 8-bit rowwise dequantize, CPU-dispatching entry point.
template FBGEMM_API void
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<float16, true>(
    const uint8_t* input,
    size_t input_rows,
    int input_columns,
    float16* output,
    const bool scale_bias_last,
    const bool quant_padding_float_type);

} // namespace fbgemm
Loading
Loading