Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 106 additions & 18 deletions cpp/daal/src/algorithms/service_kernel_math.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef __SERVICE_KERNEL_MATH_H__
#define __SERVICE_KERNEL_MATH_H__

#include <cstring>
#include <type_traits>

#include "services/daal_defines.h"
Expand All @@ -46,6 +47,7 @@
#include "src/externals/service_math.h"
#include "src/services/service_profiler.h"

#include <mutex>
#if defined(DAAL_INTEL_CPP_COMPILER)
#include "immintrin.h"
#endif
Expand Down Expand Up @@ -122,6 +124,94 @@
virtual services::Status finalize(const size_t n, FPType * a) = 0;
};

// ---------------------------------------------------------------------------
// BF16GemmDispatcher
//
// Encapsulates the decision of whether to use AMX-BF16 for a float32 GEMM,
// and performs the conversion + dispatch when eligible.
//
// Design principles:
// - Hardware availability is already folded into daal_get_float32_matmul_precision()
// at initialisation time (see env_detect_precision.cpp). Call-sites only
// query the effective precision; no hardware checks here.
// - The effective precision is cached as a translation-unit-level static bool
// so the hot path is a single branch on a compile-time-constant-like value.
// - Matrix-size threshold (kMinDim) is owned by the dispatcher, not the caller.
// - Primary template handles all non-float types with zero overhead.
// ---------------------------------------------------------------------------

#ifndef DAAL_REF

/// Minimum GEMM dimension for which AMX-BF16 is profitable.
/// BF16 conversion overhead outweighs the GEMM benefit for small tiles.
static constexpr DAAL_INT kBF16MinDim = 64;

/// Whether AMX-BF16 GEMM is usable in this process.
/// Combines hardware capability (checked once at init) with the user's precision hint.
/// This is the only place both conditions are AND-ed; call-sites check nothing else.
/// NOTE(review): this is evaluated ONCE at static initialisation, so a later call
/// to daal_set_float32_matmul_precision() is NOT reflected here; also, as a
/// header-level `static`, each translation unit gets its own copy whose init
/// order relative to the atomics in env_detect_precision.cpp is unspecified —
/// confirm this caching is intended, or evaluate at call time instead.
static const bool g_use_bf16_gemm =
daal_has_amx_bf16() && (daal_get_float32_matmul_precision() == daal::internal::Float32MatmulPrecision::allow_bf16);

/// Primary template: non-float element types always go through standard BLAS;
/// there is no BF16 path for them, so this adds zero dispatch overhead.
template <typename FPType, CpuType cpu>
struct BF16GemmDispatcher
{
    /// Computes out = B * A^T (col-major: M = nRowsB, N = nRowsA, K = nColsA).
    static void compute(const FPType * a, const FPType * b, DAAL_INT m, DAAL_INT n, DAAL_INT k, FPType * out)
    {
        const char opB = 't'; // first operand (b) is transposed
        const char opA = 'n'; // second operand (a) is used as stored
        const FPType one  = FPType(1);
        const FPType zero = FPType(0);
        BlasInst<FPType, cpu>::xxgemm(&opB, &opA, &m, &n, &k, &one, b, &k, a, &k, &zero, out, &m);
    }
};

/// float specialization: use AMX-BF16 GEMM when the precision hint allows it
/// and all dimensions are >= kBF16MinDim; fall back to sgemm otherwise.
/// BF16 conversion uses memcpy to avoid strict-aliasing UB (C++17 [basic.lval]).
template <CpuType cpu>
struct BF16GemmDispatcher<float, cpu>
{
    /// Computes out = B * A^T (col-major: M = nRowsB, N = nRowsA, K = nColsA).
    /// a: row-major N x K, b: row-major M x K, out: col-major M x N.
    static void compute(const float * a, const float * b, DAAL_INT m, DAAL_INT n, DAAL_INT k, float * out)
    {
        // Read the precision hint at call time (a relaxed atomic load) instead of
        // a static-init-time cache: this way runtime calls to
        // daal_set_float32_matmul_precision() take effect immediately, and no
        // cross-TU static-initialisation order is relied upon.
        const bool useBF16 =
            daal_has_amx_bf16() && (daal_get_float32_matmul_precision() == daal::internal::Float32MatmulPrecision::allow_bf16);
        if (useBF16 && m >= kBF16MinDim && n >= kBF16MinDim && k >= kBF16MinDim)
        {
            const size_t szA = static_cast<size_t>(n) * static_cast<size_t>(k);
            const size_t szB = static_cast<size_t>(m) * static_cast<size_t>(k);
            MKL_BF16 * a16 = static_cast<MKL_BF16 *>(mkl_malloc(szA * sizeof(MKL_BF16), 64));
            MKL_BF16 * b16 = static_cast<MKL_BF16 *>(mkl_malloc(szB * sizeof(MKL_BF16), 64));
            if (a16 && b16)
            {
                // BF16 = upper 16 bits of IEEE-754 float32; use memcpy to read bits safely.
                for (size_t i = 0; i < szA; ++i)
                {
                    uint32_t bits;
                    memcpy(&bits, &a[i], sizeof(bits));
                    a16[i] = static_cast<MKL_BF16>(bits >> 16);
                }
                for (size_t i = 0; i < szB; ++i)
                {
                    uint32_t bits;
                    memcpy(&bits, &b[i], sizeof(bits));
                    b16[i] = static_cast<MKL_BF16>(bits >> 16);
                }
                // Row-major b (M x K) and a (N x K) are, viewed col-major, K x M and
                // K x N with leading dimension K. To compute the same product as the
                // sgemm fallback below (C = b^T * a, ldc = m), pass:
                //   Trans   for b16 with lda = k,
                //   NoTrans for a16 with ldb = k.
                // (The previous NoTrans/Trans with lda = m / ldb = n interpreted the
                // row-major inputs as col-major, yielding a different product.)
                cblas_gemm_bf16bf16f32(CblasColMajor, CblasTrans, CblasNoTrans, m, n, k, 1.0f, b16, k, a16, k, 0.0f, out, m);
                mkl_free(a16);
                mkl_free(b16);
                return;
            }
            // Partial allocation failure: release whichever buffer was obtained,
            // then fall through to the full-precision path.
            if (a16) mkl_free(a16);
            if (b16) mkl_free(b16);
        }
        // Fallback: standard sgemm with the same operand layout as the BF16 path.
        const char transa = 't', transb = 'n';
        const float alpha = 1.0f, beta = 0.0f;
        BlasInst<float, cpu>::xxgemm(&transa, &transb, &m, &n, &k, &alpha, b, &k, a, &k, &beta, out, &m);
    }
};

#endif // !DAAL_REF

// compute: sum(A^2, 2) + sum(B^2, 2) -2*A*B'
template <typename FPType, CpuType cpu>
class EuclideanDistances : public PairwiseDistances<FPType, cpu>
Expand All @@ -131,11 +221,11 @@
: _a(a), _b(b), _squared(squared), _isSqrtNorm(isSqrtNorm)
{}

~EuclideanDistances() override {}
virtual ~EuclideanDistances() override {}

PairwiseDistanceType getType() override { return PairwiseDistanceType::euclidean; }

services::Status init() override

Check notice on line 228 in cpp/daal/src/algorithms/service_kernel_math.h

View check run for this annotation

codefactor.io / CodeFactor

cpp/daal/src/algorithms/service_kernel_math.h#L228

"virtual" is redundant since function is already declared as "override". (readability/inheritance)
{
services::Status s;

Expand Down Expand Up @@ -292,22 +382,20 @@

return safeStat.detach();
}

// compute (A x B')
// compute (A x B') -- EuclideanDistances inner GEMM
// GEMM call: out = B * A^T (col-major: M=nRowsB, N=nRowsA, K=nColsA)
// Dispatch to AMX-BF16 or standard BLAS is handled by BF16GemmDispatcher.
void computeABt(const FPType * const a, const FPType * const b, const size_t nRowsA, const size_t nColsA, const size_t nRowsB, FPType * const out)
{
const char transa = 't';
const char transb = 'n';
const DAAL_INT _m = nRowsB;
const DAAL_INT _n = nRowsA;
const DAAL_INT _k = nColsA;
const FPType alpha = 1.0;
const DAAL_INT lda = nColsA;
const DAAL_INT ldy = nColsA;
const FPType beta = 0.0;
const DAAL_INT ldaty = nRowsB;

BlasInst<FPType, cpu>::xxgemm(&transa, &transb, &_m, &_n, &_k, &alpha, b, &lda, a, &ldy, &beta, out, &ldaty);
#ifndef DAAL_REF
BF16GemmDispatcher<FPType, cpu>::compute(a, b, static_cast<DAAL_INT>(nRowsB), static_cast<DAAL_INT>(nRowsA), static_cast<DAAL_INT>(nColsA),
out);
#else
const char transa = 't', transb = 'n';
const DAAL_INT m = nRowsB, n = nRowsA, k = nColsA;
const FPType alpha = FPType(1), beta = FPType(0);
BlasInst<FPType, cpu>::xxgemm(&transa, &transb, &m, &n, &k, &alpha, b, &k, a, &k, &beta, out, &m);
#endif
}

const NumericTable & _a;
Expand All @@ -329,7 +417,7 @@
public:
CosineDistances(const NumericTable & a, const NumericTable & b) : super(a, b, true, true) {}

~CosineDistances() override {}
virtual ~CosineDistances() override {}

PairwiseDistanceType getType() override { return PairwiseDistanceType::cosine; }

Expand All @@ -338,7 +426,7 @@
// output: Row-major matrix of size { aSize x bSize }
services::Status computeBatch(const FPType * const a, const FPType * const b, size_t aOffset, size_t aSize, size_t bOffset, size_t bSize,
FPType * const res) override
{

Check notice on line 429 in cpp/daal/src/algorithms/service_kernel_math.h

View check run for this annotation

codefactor.io / CodeFactor

cpp/daal/src/algorithms/service_kernel_math.h#L429

"virtual" is redundant since function is already declared as "override". (readability/inheritance)
const size_t nRowsA = aSize;
const size_t nColsA = super::_a.getNumberOfColumns();
const size_t nRowsB = bSize;
Expand Down Expand Up @@ -372,7 +460,7 @@
: _a(a), _b(b), _powered(powered), _p(p)
{}

~MinkowskiDistances() override {}
virtual ~MinkowskiDistances() override {}

PairwiseDistanceType getType() override { return PairwiseDistanceType::minkowski; }

Expand All @@ -381,7 +469,7 @@
services::Status s;
return s;
}

Check notice on line 472 in cpp/daal/src/algorithms/service_kernel_math.h

View check run for this annotation

codefactor.io / CodeFactor

cpp/daal/src/algorithms/service_kernel_math.h#L472

"virtual" is redundant since function is already declared as "override". (readability/inheritance)
services::Status computeBatch(const FPType * const a, const FPType * const b, size_t aOffset, size_t aSize, size_t bOffset, size_t bSize,
FPType * const res) override
{
Expand Down Expand Up @@ -471,7 +559,7 @@
public:
ChebyshevDistances(const NumericTable & a, const NumericTable & b) : _a(a), _b(b) {}

~ChebyshevDistances() override {}
virtual ~ChebyshevDistances() override {}

PairwiseDistanceType getType() override { return PairwiseDistanceType::chebyshev; }

Expand All @@ -480,7 +568,7 @@
services::Status s;
return s;
}

Check notice on line 571 in cpp/daal/src/algorithms/service_kernel_math.h

View check run for this annotation

codefactor.io / CodeFactor

cpp/daal/src/algorithms/service_kernel_math.h#L571

"virtual" is redundant since function is already declared as "override". (readability/inheritance)
services::Status computeBatch(const FPType * const a, const FPType * const b, size_t aOffset, size_t aSize, size_t bOffset, size_t bSize,
FPType * const res) override
{
Expand Down
16 changes: 16 additions & 0 deletions cpp/daal/src/services/compiler/generic/env_detect_features.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,18 @@ DAAL_EXPORT int daal_enabled_cpu_detect()
result |= feature; \
}

/// Detect AMX-BF16 availability.
/// Requires both CPUID.(EAX=07H, ECX=0H):EDX.AMX-BF16[bit 22] == 1 and
/// XCR0[17:18] (XTILECFG/XTILEDATA tile state) enabled by the OS.
static int check_amx_bf16_features()
{
    const int has_cpuid_bit = check_cpuid(7, 0, 3, (1 << 22));
    /* Capability bit alone is not enough: the OS must also enable tile state. */
    return has_cpuid_bit ? check_xgetbv_xcr0_ymm(0x60000) : 0;
}

DAAL_UINT64 __daal_internal_serv_cpu_feature_detect()
{
DAAL_UINT64 result = daal::internal::CpuFeature::unknown;
Expand All @@ -313,6 +325,10 @@ DAAL_UINT64 __daal_internal_serv_cpu_feature_detect()
DAAL_TEST_CPU_FEATURE(result, 7, 1, 0, 5, daal::internal::CpuFeature::avx512_bf16);
DAAL_TEST_CPU_FEATURE(result, 7, 0, 2, 11, daal::internal::CpuFeature::avx512_vnni);
}
if (check_amx_bf16_features())
{
result |= daal::internal::CpuFeature::amx_bf16;
}
DAAL_TEST_CPU_FEATURE(result, 1, 0, 2, 7, daal::internal::CpuFeature::sstep);
DAAL_TEST_CPU_FEATURE(result, 6, 0, 0, 1, daal::internal::CpuFeature::tb);
DAAL_TEST_CPU_FEATURE(result, 6, 0, 0, 14, daal::internal::CpuFeature::tb3);
Expand Down
108 changes: 108 additions & 0 deletions cpp/daal/src/services/compiler/generic/env_detect_precision.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/* file: env_detect_precision.cpp */
/*******************************************************************************
* Copyright contributors to the oneDAL project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/*
//++
// Float32 matmul precision control.
//
// Inspired by torch.set_float32_matmul_precision / jax_default_matmul_precision.
//
// Precision levels:
// strict (default) -- always compute in float32. Full IEEE-754, bit-exact.
// allow_bf16 -- allow BF16 kernels where oneDAL determines the accuracy
// impact is acceptable. The library, not the user, decides
// which operations are eligible. Currently: float32
// Euclidean GEMM, dims >= 64, hardware with AMX-BF16.
//
// Hardware availability is checked ONCE at process initialisation (g_hw_amx_bf16).
// The setter stores the user's requested level as-is; BF16GemmDispatcher
// gates on both the hardware flag and the stored precision level.
// This way the getter is a plain atomic load with no HW queries.
//--
*/

#include "src/services/cpu_type.h"
#include "src/services/service_defines.h"

#include <atomic>
#include <cstdlib> /* std::getenv */
#include <cstring> /* std::strcmp */

namespace
{

/// Hardware capability flag — evaluated ONCE at process start.
/// True iff AMX-BF16 is present and OS-enabled.
/// AMX-BF16 is x86-only; always false on ARM/RISCV64.
/// daal_serv_cpu_feature_detect() returns a DAAL_UINT64 bitmask;
/// amx_bf16 corresponds to bit 5 (see daal::internal::CpuFeature::amx_bf16).
static bool detect_hw_amx_bf16()
{
#if defined(TARGET_X86_64) || defined(__x86_64__) || defined(_M_X64)
// AMX-BF16 bit = (1ULL << 5) in daal::internal::CpuFeature (cpu_type.h).
// Use the numeric value directly to avoid dependency on TARGET_X86_64 guard
// in cpu_type.h when the macro is not set by the build system.
static const DAAL_UINT64 amx_bf16_mask = (1ULL << 5);
return (daal_serv_cpu_feature_detect() & amx_bf16_mask) != 0;
#else
return false;
#endif
}
static const bool g_hw_amx_bf16 = detect_hw_amx_bf16();

/// Parse the ONEDAL_FLOAT32_MATMUL_PRECISION environment variable.
/// Only "ALLOW_BF16" / "allow_bf16" select BF16; any other value (or unset) maps to strict.
static daal::internal::Float32MatmulPrecision parse_env_precision()
{
    using daal::internal::Float32MatmulPrecision;
    const char * requested = std::getenv("ONEDAL_FLOAT32_MATMUL_PRECISION");
    if (!requested)
    {
        return Float32MatmulPrecision::strict;
    }
    const bool wants_bf16 = (std::strcmp(requested, "ALLOW_BF16") == 0) || (std::strcmp(requested, "allow_bf16") == 0);
    return wants_bf16 ? Float32MatmulPrecision::allow_bf16 : Float32MatmulPrecision::strict;
}

/// Stored precision level (what the user requested).
/// Default: read from env at startup; can be overridden at runtime.
/// BF16GemmDispatcher combines this with g_hw_amx_bf16 to make the
/// final dispatch decision.
static std::atomic<int> g_float32_matmul_precision { static_cast<int>(parse_env_precision()) };

} // anonymous namespace

/// Return whether AMX-BF16 hardware is available in this process.
/// Result is constant for the lifetime of the process: it is evaluated once,
/// at static initialisation, via detect_hw_amx_bf16() above.
DAAL_EXPORT bool daal_has_amx_bf16()
{
    return g_hw_amx_bf16;
}

/// Return the currently requested float32 matmul precision.
/// Reads the stored hint with a relaxed atomic load; performs no hardware queries.
DAAL_EXPORT daal::internal::Float32MatmulPrecision daal_get_float32_matmul_precision()
{
    const int stored = g_float32_matmul_precision.load(std::memory_order_relaxed);
    return static_cast<daal::internal::Float32MatmulPrecision>(stored);
}

/// Set the float32 matmul precision hint.
/// Stores the requested level directly; does NOT re-check hardware.
/// If allow_bf16 is requested but hardware lacks AMX-BF16, BF16GemmDispatcher
/// will fall through to sgemm transparently via its own dispatch gate.
DAAL_EXPORT void daal_set_float32_matmul_precision(daal::internal::Float32MatmulPrecision p)
{
    g_float32_matmul_precision.store(static_cast<int>(p), std::memory_order_relaxed);
}
2 changes: 2 additions & 0 deletions cpp/daal/src/services/cpu_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ enum CpuFeature
avx512_bf16 = (1ULL << 2), /*!< AVX-512 bfloat16 */
avx512_vnni = (1ULL << 3), /*!< AVX-512 Vector Neural Network Instructions (VNNI) */
tb3 = (1ULL << 4), /*!< Intel(R) Turbo Boost Max 3.0 */
amx_bf16 = (1ULL << 5), /*!< Intel(R) Advanced Matrix Extensions bfloat16 (AMX-BF16) */
#endif
};

} // namespace internal
} // namespace daal
#endif
47 changes: 47 additions & 0 deletions cpp/daal/src/services/service_defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,53 @@ DAAL_EXPORT int __daal_serv_cpu_detect(int);
DAAL_EXPORT int daal_enabled_cpu_detect();
DAAL_EXPORT DAAL_UINT64 daal_serv_cpu_feature_detect();

/* ---------------------------------------------------------------------------
* Float32MatmulPrecision — INTERNAL, experimental API.
*
* Precision hint for internal float32 matrix multiplications.
* This is a *hint*, not a hard override: the library decides which operations
* are eligible for reduced precision. Modelled after
* torch.set_float32_matmul_precision / jax_default_matmul_precision.
*
* Levels:
*
* strict (default)
* Always compute in the input dtype (float32 → sgemm, float64 → dgemm).
* Full IEEE-754, bit-exact and reproducible across hardware generations.
*
* allow_bf16
* Allow oneDAL to use BF16 kernels in operations where the library has
* determined the accuracy impact is bounded and acceptable for the algorithm.
* Currently: float32 Euclidean GEMM on AMX-BF16 hardware, dims >= 64.
* Output dtype remains float32; only internal GEMM accumulation uses BF16.
* Falls back to sgemm silently on hardware without AMX-BF16.
*
* Activation:
* env var: ONEDAL_FLOAT32_MATMUL_PRECISION=ALLOW_BF16 (before loading oneDAL)
* programmatic: daal_set_float32_matmul_precision(daal::internal::Float32MatmulPrecision::allow_bf16)
*
* NOTE: This API is internal and subject to change without notice.
* Do not expose in public headers until the design is finalised.
* ----------------------------------------------------------------------- */
namespace daal
{
namespace internal
{
/// Precision hint for internal float32 matrix multiplications (see the block
/// comment above). Numeric values are stable: they are round-tripped through
/// an atomic<int> in env_detect_precision.cpp.
enum Float32MatmulPrecision
{
    strict = 0, /*!< Full float32, IEEE-754, bit-exact (default). */
    allow_bf16 = 1, /*!< Allow BF16 kernels where oneDAL deems accuracy impact acceptable.
 * Output dtype is unchanged; only internal compute uses BF16. */
};
} // namespace internal
} // namespace daal

/* Hardware capability — constant for the process lifetime. */
DAAL_EXPORT bool daal_has_amx_bf16();
/* Precision getter/setter — no HW queries; operate on the stored hint only. */
DAAL_EXPORT daal::internal::Float32MatmulPrecision daal_get_float32_matmul_precision();
DAAL_EXPORT void daal_set_float32_matmul_precision(daal::internal::Float32MatmulPrecision p);

void run_cpuid(uint32_t eax, uint32_t ecx, uint32_t * abcd);
DAAL_EXPORT bool daal_check_is_intel_cpu();

Expand Down
Loading