diff --git a/cpp/daal/src/algorithms/service_kernel_math.h b/cpp/daal/src/algorithms/service_kernel_math.h index a3b8b4ace67..4171af7b051 100644 --- a/cpp/daal/src/algorithms/service_kernel_math.h +++ b/cpp/daal/src/algorithms/service_kernel_math.h @@ -24,6 +24,7 @@ #ifndef __SERVICE_KERNEL_MATH_H__ #define __SERVICE_KERNEL_MATH_H__ +#include #include #include "services/daal_defines.h" @@ -46,6 +47,7 @@ #include "src/externals/service_math.h" #include "src/services/service_profiler.h" +#include #if defined(DAAL_INTEL_CPP_COMPILER) #include "immintrin.h" #endif @@ -122,6 +124,94 @@ class PairwiseDistances virtual services::Status finalize(const size_t n, FPType * a) = 0; }; +// --------------------------------------------------------------------------- +// BF16GemmDispatcher +// +// Encapsulates the decision of whether to use AMX-BF16 for a float32 GEMM, +// and performs the conversion + dispatch when eligible. +// +// Design principles: +// - Hardware availability is already folded into daal_get_float32_matmul_precision() +// at initialisation time (see env_detect_precision.cpp). Call-sites only +// query the effective precision; no hardware checks here. +// - The effective precision is cached as a translation-unit-level static bool +// so the hot path is a single branch on a compile-time-constant-like value. +// - Matrix-size threshold (kMinDim) is owned by the dispatcher, not the caller. +// - Primary template handles all non-float types with zero overhead. +// --------------------------------------------------------------------------- + +#ifndef DAAL_REF + +/// Minimum GEMM dimension for which AMX-BF16 is profitable. +/// BF16 conversion overhead outweighs the GEMM benefit for small tiles. +static constexpr DAAL_INT kBF16MinDim = 64; + +/// Whether AMX-BF16 GEMM is usable in this process. +/// Combines hardware capability (checked once at init) with the user's precision hint. +/// This is the only place both conditions are AND-ed; call-sites check nothing else. 
+static const bool g_use_bf16_gemm =
+    daal_has_amx_bf16() && (daal_get_float32_matmul_precision() == daal::internal::Float32MatmulPrecision::allow_bf16);
+
+/// Primary template: use standard BLAS for all non-float types.
+template <typename FPType, CpuType cpu>
+struct BF16GemmDispatcher
+{
+    /// Computes out = B * A^T (col-major, M=nRowsB, N=nRowsA, K=nColsA).
+    static void compute(const FPType * a, const FPType * b, DAAL_INT m, DAAL_INT n, DAAL_INT k, FPType * out)
+    {
+        const char transa = 't', transb = 'n';
+        const FPType alpha = FPType(1), beta = FPType(0);
+        BlasInst<FPType, cpu>::xxgemm(&transa, &transb, &m, &n, &k, &alpha, b, &k, a, &k, &beta, out, &m);
+    }
+};
+
+/// float specialization: use AMX-BF16 GEMM when precision hint allows it
+/// and all dimensions are >= kBF16MinDim; fall back to sgemm otherwise.
+/// BF16 conversion uses memcpy to avoid strict-aliasing UB (C++17 [basic.lval]).
+template <CpuType cpu>
+struct BF16GemmDispatcher<float, cpu>
+{
+    static void compute(const float * a, const float * b, DAAL_INT m, DAAL_INT n, DAAL_INT k, float * out)
+    {
+        if (g_use_bf16_gemm && m >= kBF16MinDim && n >= kBF16MinDim && k >= kBF16MinDim)
+        {
+            const size_t szA = static_cast<size_t>(n) * static_cast<size_t>(k);
+            const size_t szB = static_cast<size_t>(m) * static_cast<size_t>(k);
+            MKL_BF16 * a16 = static_cast<MKL_BF16 *>(mkl_malloc(szA * sizeof(MKL_BF16), 64));
+            MKL_BF16 * b16 = static_cast<MKL_BF16 *>(mkl_malloc(szB * sizeof(MKL_BF16), 64));
+            if (a16 && b16)
+            {
+                // BF16 = upper 16 bits of IEEE-754 float32; use memcpy to read bits safely.
+                for (size_t i = 0; i < szA; ++i)
+                {
+                    uint32_t bits;
+                    memcpy(&bits, &a[i], sizeof(bits));
+                    a16[i] = static_cast<MKL_BF16>(bits >> 16);
+                }
+                for (size_t i = 0; i < szB; ++i)
+                {
+                    uint32_t bits;
+                    memcpy(&bits, &b[i], sizeof(bits));
+                    b16[i] = static_cast<MKL_BF16>(bits >> 16);
+                }
+                // b16/a16 hold row-major data with leading dimension k, exactly like the
+                // sgemm fallback below (b: transa='t', lda=k; a: transb='n', ldb=k).
+                // In col-major cblas terms both buffers are k-leading, so mirror the
+                // fallback: C(m x n, col-major) = B16^T * A16 with lda = ldb = k.
+                cblas_gemm_bf16bf16f32(CblasColMajor, CblasTrans, CblasNoTrans, m, n, k, 1.0f, b16, k, a16, k, 0.0f, out, m);
+                mkl_free(a16);
+                mkl_free(b16);
+                return;
+            }
+            if (a16) mkl_free(a16);
+            if (b16) mkl_free(b16);
+        }
+        // Fallback: standard sgemm
+        const char transa = 't', transb = 'n';
+        const float alpha = 1.0f, beta = 0.0f;
+        BlasInst<float, cpu>::xxgemm(&transa, &transb, &m, &n, &k, &alpha, b, &k, a, &k, &beta, out, &m);
+    }
+};
+
+#endif // !DAAL_REF
+
 // compute: sum(A^2, 2) + sum(B^2, 2) -2*A*B'
 template <typename FPType, CpuType cpu>
 class EuclideanDistances : public PairwiseDistances<FPType>
@@ -131,7 +221,7 @@ class EuclideanDistances : public PairwiseDistances<FPType>
         : _a(a), _b(b), _squared(squared), _isSqrtNorm(isSqrtNorm)
     {}
 
-    ~EuclideanDistances() override {}
+    virtual ~EuclideanDistances() override {}
 
     PairwiseDistanceType getType() override { return PairwiseDistanceType::euclidean; }
 
@@ -292,22 +382,20 @@ class EuclideanDistances : public PairwiseDistances<FPType>
         return safeStat.detach();
     }
-
-    // compute (A x B')
+    // compute (A x B') -- EuclideanDistances inner GEMM
+    // GEMM call: out = B * A^T (col-major: M=nRowsB, N=nRowsA, K=nColsA)
+    // Dispatch to AMX-BF16 or standard BLAS is handled by BF16GemmDispatcher.
     void computeABt(const FPType * const a, const FPType * const b, const size_t nRowsA, const size_t nColsA, const size_t nRowsB, FPType * const out)
     {
-        const char transa  = 't';
-        const char transb  = 'n';
-        const DAAL_INT _m  = nRowsB;
-        const DAAL_INT _n  = nRowsA;
-        const DAAL_INT _k  = nColsA;
-        const FPType alpha = 1.0;
-        const DAAL_INT lda = nColsA;
-        const DAAL_INT ldy = nColsA;
-        const FPType beta  = 0.0;
-        const DAAL_INT ldaty = nRowsB;
-
-        BlasInst<FPType, cpu>::xxgemm(&transa, &transb, &_m, &_n, &_k, &alpha, b, &lda, a, &ldy, &beta, out, &ldaty);
+#ifndef DAAL_REF
+        BF16GemmDispatcher<FPType, cpu>::compute(a, b, static_cast<DAAL_INT>(nRowsB), static_cast<DAAL_INT>(nRowsA), static_cast<DAAL_INT>(nColsA),
+                                                 out);
+#else
+        const char transa = 't', transb = 'n';
+        const DAAL_INT m = nRowsB, n = nRowsA, k = nColsA;
+        const FPType alpha = FPType(1), beta = FPType(0);
+        BlasInst<FPType, cpu>::xxgemm(&transa, &transb, &m, &n, &k, &alpha, b, &k, a, &k, &beta, out, &m);
+#endif
     }
 
     const NumericTable & _a;
@@ -329,7 +417,7 @@ class CosineDistances : public EuclideanDistances<FPType, cpu>
 public:
     CosineDistances(const NumericTable & a, const NumericTable & b) : super(a, b, true, true) {}
 
-    ~CosineDistances() override {}
+    virtual ~CosineDistances() override {}
 
     PairwiseDistanceType getType() override { return PairwiseDistanceType::cosine; }
 
@@ -372,7 +460,7 @@ class MinkowskiDistances : public PairwiseDistances<FPType>
         : _a(a), _b(b), _powered(powered), _p(p)
     {}
 
-    ~MinkowskiDistances() override {}
+    virtual ~MinkowskiDistances() override {}
 
    PairwiseDistanceType getType() override { return PairwiseDistanceType::minkowski; }
 
@@ -471,7 +559,7 @@ class ChebyshevDistances : public PairwiseDistances<FPType>
 public:
     ChebyshevDistances(const NumericTable & a, const NumericTable & b) : _a(a), _b(b) {}
 
-    ~ChebyshevDistances() override {}
+    virtual ~ChebyshevDistances() override {}
 
     PairwiseDistanceType getType() override { return PairwiseDistanceType::chebyshev; }
 
diff --git a/cpp/daal/src/services/compiler/generic/env_detect_features.cpp
b/cpp/daal/src/services/compiler/generic/env_detect_features.cpp index 39ad9f27cb1..9ff4bd96744 100644 --- a/cpp/daal/src/services/compiler/generic/env_detect_features.cpp +++ b/cpp/daal/src/services/compiler/generic/env_detect_features.cpp @@ -300,6 +300,18 @@ DAAL_EXPORT int daal_enabled_cpu_detect() result |= feature; \ } +/// Check if AMX-BF16 is available: CPUID(7,0).EDX[22] + XCR0[17:18] tile state enabled by OS. +static int check_amx_bf16_features() +{ + /* CPUID.(EAX=07H, ECX=0H):EDX.AMX-BF16[bit 22]==1 */ + if (!check_cpuid(7, 0, 3, (1 << 22))) + { + return 0; + } + /* XCR0[17:18] - XTILECFG and XTILEDATA must be set */ + return check_xgetbv_xcr0_ymm(0x60000); +} + DAAL_UINT64 __daal_internal_serv_cpu_feature_detect() { DAAL_UINT64 result = daal::internal::CpuFeature::unknown; @@ -313,6 +325,10 @@ DAAL_UINT64 __daal_internal_serv_cpu_feature_detect() DAAL_TEST_CPU_FEATURE(result, 7, 1, 0, 5, daal::internal::CpuFeature::avx512_bf16); DAAL_TEST_CPU_FEATURE(result, 7, 0, 2, 11, daal::internal::CpuFeature::avx512_vnni); } + if (check_amx_bf16_features()) + { + result |= daal::internal::CpuFeature::amx_bf16; + } DAAL_TEST_CPU_FEATURE(result, 1, 0, 2, 7, daal::internal::CpuFeature::sstep); DAAL_TEST_CPU_FEATURE(result, 6, 0, 0, 1, daal::internal::CpuFeature::tb); DAAL_TEST_CPU_FEATURE(result, 6, 0, 0, 14, daal::internal::CpuFeature::tb3); diff --git a/cpp/daal/src/services/compiler/generic/env_detect_precision.cpp b/cpp/daal/src/services/compiler/generic/env_detect_precision.cpp new file mode 100644 index 00000000000..3c2317234a0 --- /dev/null +++ b/cpp/daal/src/services/compiler/generic/env_detect_precision.cpp @@ -0,0 +1,108 @@ +/* file: env_detect_precision.cpp */ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+//++
+//  Float32 matmul precision control.
+//
+//  Inspired by torch.set_float32_matmul_precision / jax_default_matmul_precision.
+//
+//  Precision levels:
+//    strict (default) -- always compute in float32. Full IEEE-754, bit-exact.
+//    allow_bf16       -- allow BF16 kernels where oneDAL determines the accuracy
+//                        impact is acceptable. The library, not the user, decides
+//                        which operations are eligible. Currently: float32
+//                        Euclidean GEMM, dims >= 64, hardware with AMX-BF16.
+//
+//  Hardware availability is checked ONCE at process initialisation (g_hw_amx_bf16).
+//  The setter stores the user's requested level as-is; BF16GemmDispatcher
+//  gates on both hardware capability and the stored precision level.
+//  This way the getter is a plain atomic load with no HW queries.
+//--
+*/
+
+#include "src/services/cpu_type.h"
+#include "src/services/service_defines.h"
+
+#include <atomic>
+#include <cstdlib> /* std::getenv */
+#include <cstring> /* std::strcmp */
+
+namespace
+{
+
+/// Hardware capability flag — evaluated ONCE at process start.
+/// True iff AMX-BF16 is present and OS-enabled.
+/// AMX-BF16 is x86-only; always false on ARM/RISCV64.
+/// daal_serv_cpu_feature_detect() returns a DAAL_UINT64 bitmask;
+/// amx_bf16 corresponds to bit 5 (see daal::internal::CpuFeature::amx_bf16).
+static bool detect_hw_amx_bf16()
+{
+#if defined(TARGET_X86_64) || defined(__x86_64__) || defined(_M_X64)
+    // AMX-BF16 bit = (1ULL << 5) in daal::internal::CpuFeature (cpu_type.h).
+    // Use the numeric value directly to avoid dependency on TARGET_X86_64 guard
+    // in cpu_type.h when the macro is not set by the build system.
+    static const DAAL_UINT64 amx_bf16_mask = (1ULL << 5);
+    return (daal_serv_cpu_feature_detect() & amx_bf16_mask) != 0;
+#else
+    return false;
+#endif
+}
+static const bool g_hw_amx_bf16 = detect_hw_amx_bf16();
+
+/// Parse ONEDAL_FLOAT32_MATMUL_PRECISION env var.
+/// Recognised values: "ALLOW_BF16" / "allow_bf16". Everything else → strict.
+static daal::internal::Float32MatmulPrecision parse_env_precision()
+{
+    const char * val = std::getenv("ONEDAL_FLOAT32_MATMUL_PRECISION");
+    if (val && (std::strcmp(val, "ALLOW_BF16") == 0 || std::strcmp(val, "allow_bf16") == 0))
+    {
+        return daal::internal::Float32MatmulPrecision::allow_bf16;
+    }
+    return daal::internal::Float32MatmulPrecision::strict;
+}
+
+/// Stored precision level (what the user requested).
+/// Default: read from env at startup; can be overridden at runtime.
+/// BF16GemmDispatcher combines this with g_hw_amx_bf16 to make the
+/// final dispatch decision.
+static std::atomic<int> g_float32_matmul_precision { static_cast<int>(parse_env_precision()) };
+
+} // anonymous namespace
+
+/// Return whether AMX-BF16 hardware is available in this process.
+/// Result is constant for the lifetime of the process.
+DAAL_EXPORT bool daal_has_amx_bf16()
+{
+    return g_hw_amx_bf16;
+}
+
+/// Return the currently requested float32 matmul precision.
+/// This is a plain atomic load — no hardware queries.
+DAAL_EXPORT daal::internal::Float32MatmulPrecision daal_get_float32_matmul_precision()
+{
+    return static_cast<daal::internal::Float32MatmulPrecision>(g_float32_matmul_precision.load(std::memory_order_relaxed));
+}
+
+/// Set the float32 matmul precision hint.
+/// Stores the requested level directly; does NOT re-check hardware.
+/// If 'allow_bf16' is requested but hardware lacks AMX-BF16, BF16GemmDispatcher
+/// will fall through to sgemm transparently via its own g_use_bf16_gemm gate.
+DAAL_EXPORT void daal_set_float32_matmul_precision(daal::internal::Float32MatmulPrecision p) +{ + g_float32_matmul_precision.store(static_cast(p), std::memory_order_relaxed); +} diff --git a/cpp/daal/src/services/cpu_type.h b/cpp/daal/src/services/cpu_type.h index 13563e21eba..bd4de595e3a 100644 --- a/cpp/daal/src/services/cpu_type.h +++ b/cpp/daal/src/services/cpu_type.h @@ -62,8 +62,10 @@ enum CpuFeature avx512_bf16 = (1ULL << 2), /*!< AVX-512 bfloat16 */ avx512_vnni = (1ULL << 3), /*!< AVX-512 Vector Neural Network Instructions (VNNI) */ tb3 = (1ULL << 4), /*!< Intel(R) Turbo Boost Max 3.0 */ + amx_bf16 = (1ULL << 5), /*!< Intel(R) Advanced Matrix Extensions bfloat16 (AMX-BF16) */ #endif }; + } // namespace internal } // namespace daal #endif diff --git a/cpp/daal/src/services/service_defines.h b/cpp/daal/src/services/service_defines.h index ce93bcf5377..2a001ad3f48 100644 --- a/cpp/daal/src/services/service_defines.h +++ b/cpp/daal/src/services/service_defines.h @@ -33,6 +33,53 @@ DAAL_EXPORT int __daal_serv_cpu_detect(int); DAAL_EXPORT int daal_enabled_cpu_detect(); DAAL_EXPORT DAAL_UINT64 daal_serv_cpu_feature_detect(); +/* --------------------------------------------------------------------------- + * Float32MatmulPrecision — INTERNAL, experimental API. + * + * Precision hint for internal float32 matrix multiplications. + * This is a *hint*, not a hard override: the library decides which operations + * are eligible for reduced precision. Modelled after + * torch.set_float32_matmul_precision / jax_default_matmul_precision. + * + * Levels: + * + * strict (default) + * Always compute in the input dtype (float32 → sgemm, float64 → dgemm). + * Full IEEE-754, bit-exact and reproducible across hardware generations. + * + * allow_bf16 + * Allow oneDAL to use BF16 kernels in operations where the library has + * determined the accuracy impact is bounded and acceptable for the algorithm. + * Currently: float32 Euclidean GEMM on AMX-BF16 hardware, dims >= 64. 
+ * Output dtype remains float32; only internal GEMM accumulation uses BF16. + * Falls back to sgemm silently on hardware without AMX-BF16. + * + * Activation: + * env var: ONEDAL_FLOAT32_MATMUL_PRECISION=ALLOW_BF16 (before loading oneDAL) + * programmatic: daal_set_float32_matmul_precision(daal::internal::Float32MatmulPrecision::allow_bf16) + * + * NOTE: This API is internal and subject to change without notice. + * Do not expose in public headers until the design is finalised. + * ----------------------------------------------------------------------- */ +namespace daal +{ +namespace internal +{ +enum Float32MatmulPrecision +{ + strict = 0, /*!< Full float32, IEEE-754, bit-exact (default). */ + allow_bf16 = 1, /*!< Allow BF16 kernels where oneDAL deems accuracy impact acceptable. + * Output dtype is unchanged; only internal compute uses BF16. */ +}; +} // namespace internal +} // namespace daal + +/* Hardware capability — constant for the process lifetime. */ +DAAL_EXPORT bool daal_has_amx_bf16(); +/* Precision getter/setter — no HW queries; operate on the stored hint only. */ +DAAL_EXPORT daal::internal::Float32MatmulPrecision daal_get_float32_matmul_precision(); +DAAL_EXPORT void daal_set_float32_matmul_precision(daal::internal::Float32MatmulPrecision p); + void run_cpuid(uint32_t eax, uint32_t ecx, uint32_t * abcd); DAAL_EXPORT bool daal_check_is_intel_cpu();