perf: AMX BF16 dispatch in EuclideanDistances::computeABt #3564

napetrov wants to merge 15 commits into uxlfoundation:main
Conversation
Force-pushed from a4d8723 to 5cb316e
/intelci: run
```cpp
// compute (A x B')
#ifndef DAAL_REF
// AMX-BF16 capability check (MKL builds only; cross-platform)
static bool knn_has_amx_bf16()
```
It would be better to extend the function we already have for CPU feature detection, if needed:
https://github.com/uxlfoundation/oneDAL/blob/main/cpp/daal/src/services/compiler/generic/env_detect_features.cpp#L303
Add AMX BF16 fast path for kNN brute-force and DBSCAN distance computation.

When AMX BF16 hardware is available and all GEMM dimensions >= 64, converts float32 operands to BF16 and uses cblas_gemm_bf16bf16f32 instead of sgemm.

Benchmarks on Xeon 6975P-C (sklearnex, float32, cpu):
- KNeighborsClassifier: 2.7-3.4x speedup
- KNeighborsRegressor: 2.9-3.3x speedup
- DBSCAN: 3.5-4.0x speedup

Accuracy impact vs float32 baseline:
- Classifier: delta < 1.2% (within dataset noise)
- Regressor R2: delta < 0.001

Runtime detection via CPUID (AMX_BF16 bit) + XGETBV XCR0[18:17]. Falls back to sgemm if AMX unavailable or dims < 64.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
- Add daal::CpuFeature::amx_bf16 to the CpuFeature enum in cpu_type.h
- Add check_amx_bf16_features() to env_detect_features.cpp using the existing DAAL CPUID/XGETBV infrastructure (CPUID(7,0).EDX[22] + XCR0[17:18]); remove standalone knn_has_amx_bf16() from service_kernel_math.h
- Replace std::is_same dispatch with template specialization: add computeABtImpl primary template (sgemm/dgemm fallback) and a float specialization (AMX-BF16 path, guarded by DAAL_REF)
- Replace union-based float→BF16 conversion with memcpy to avoid strict-aliasing UB (C++17 [basic.lval])

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
Introduces a precision-hint API modelled after torch.set_float32_matmul_precision
/ jax_default_matmul_precision so users can opt in to reduced-precision
arithmetic where the library determines it is safe.
enum daal::Float32MatmulPrecision { highest (default), high }
daal_get_float32_matmul_precision() -- query effective precision
daal_set_float32_matmul_precision(p) -- set at runtime
Env var: ONEDAL_FLOAT32_MATMUL_PRECISION=HIGH
* Hardware availability (AMX-BF16) is folded into the effective precision
at initialisation time in env_detect_precision.cpp. Call-sites never
query hardware themselves; they only check daal_get_float32_matmul_precision().
* BF16GemmDispatcher<FPType, cpu> encapsulates all dispatch logic:
- primary template: standard sgemm/dgemm, zero overhead for non-float types
- float specialization: AMX-BF16 path when precision=='high' and all dims >=64
- matrix-size threshold (kBF16MinDim=64) owned by the dispatcher, not callers
* g_use_bf16_gemm is a TU-level static bool evaluated once at startup,
so the hot path is a single branch on a constant-like value.
* computeABt() in EuclideanDistances is now a one-liner delegating to
BF16GemmDispatcher; no template ambiguity, no DAAL_REF ifdefs in caller.
* Default is Float32MatmulPrecision::highest (no precision loss without
explicit opt-in); consistent with PyTorch/JAX defaults.
Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
Previously compute_effective_precision() was called in both the setter
and the getter, folding HW availability into the stored value.
New design:
g_hw_amx_bf16 -- bool, computed ONCE at process init via
daal_serv_cpu_feature_detect(); never recomputed.
daal_has_amx_bf16() -- returns g_hw_amx_bf16 (constant for process lifetime)
g_float32_matmul_precision -- stores the user's requested level as-is;
setter writes directly, no HW re-check.
daal_get_float32_matmul_precision() -- plain atomic load, zero HW queries.
BF16GemmDispatcher gates on both:
g_use_bf16_gemm = daal_has_amx_bf16() && (get_precision() == high)
This is evaluated once at TU init; the hot path sees a single bool.
Setter behaviour: daal_set_float32_matmul_precision(high) on a machine
without AMX-BF16 stores 'high' but g_use_bf16_gemm remains false, so
BF16GemmDispatcher silently falls through to sgemm. No HW check in setter.
Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
Add a comprehensive doc comment for the Float32MatmulPrecision enum:
- Explains 'highest' vs 'high' semantics with concrete examples
- Documents current eligible operations for 'high' (Euclidean GEMM, AMX-BF16)
- Notes fallback behaviour on non-AMX hardware
- Documents env var + programmatic API
- Reserves 'medium' for future 2xBF16 pass

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
The API is not yet stable; keep it internal until the design is finalised and reviewed.

Changes:
- Remove Float32MatmulPrecision from public cpu_type.h
- Define daal::internal::Float32MatmulPrecision in service_defines.h (internal header, not installed) with a clear 'INTERNAL, experimental' notice
- Update env_detect_precision.cpp and service_kernel_math.h to use daal::internal::Float32MatmulPrecision
- daal_has_amx_bf16 / daal_get_float32_matmul_precision / daal_set_float32_matmul_precision remain DAAL_EXPORT so internal oneDAL code and sklearnex can reach them, but they are not part of the public API surface

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
Correct the ordering: highest > high > medium by decreasing precision, matching torch.set_float32_matmul_precision semantics exactly.

Add 'medium' as a reserved enum value (not yet implemented; falls back to 'high') for forward compatibility.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
Replace PyTorch-style highest/high/medium naming with explicit names:
- strict -- always float32, IEEE-754, bit-exact (was: highest)
- allow_bf16 -- permit BF16 kernels where oneDAL deems safe (was: high)

Rationale: 'high'/'highest' are ambiguous (performance vs precision). 'strict' and 'allow_bf16' express intent directly with no ambiguity. Extensible: allow_fp16 and allow_tf32 follow the same pattern naturally.

Env var updated: ONEDAL_FLOAT32_MATMUL_PRECISION=ALLOW_BF16

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
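The env-var opt-in described in this commit could be parsed along these lines. This is a hedged sketch: `precision_from_env` is an illustrative name, and the real parser lives in env_detect_precision.cpp and may differ.

```cpp
#include <cstdlib>
#include <cstring>

enum class Float32MatmulPrecision { strict, allow_bf16 };

// Read ONEDAL_FLOAT32_MATMUL_PRECISION; anything other than ALLOW_BF16
// (including an unset variable) falls back to the strict default.
static Float32MatmulPrecision precision_from_env()
{
    const char * v = std::getenv("ONEDAL_FLOAT32_MATMUL_PRECISION");
    if (v && std::strcmp(v, "ALLOW_BF16") == 0) return Float32MatmulPrecision::allow_bf16;
    return Float32MatmulPrecision::strict; // default: bit-exact float32
}
```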
…on.cpp

sed substitution previously broke daal::CpuFeature::amx_bf16 → daal::amx_bf16. amx_bf16 is a member of enum daal::CpuFeature, not of the daal namespace directly.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
Force-pushed from bef38d1 to 0a6c6c2
The enum was defined in both cpu_type.h (with highest/high values) and service_defines.h (with strict/allow_bf16 values). When both headers were included in the same translation unit the compiler emitted:

    error: redefinition of 'Float32MatmulPrecision'

Float32MatmulPrecision is an internal API and belongs exclusively in service_defines.h (internal header, not installed). Remove the definition and its associated doc comment from cpu_type.h entirely.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
- env_detect_precision.cpp: fix include path from 'services/cpu_type.h' to 'src/services/cpu_type.h' (internal header, not on the public include path)
- Apply clang-format to service_kernel_math.h and env_detect_precision.cpp to pass FormatterChecks CI

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
…on.cpp

CpuFeature is defined in the daal::internal namespace (cpu_type.h), not in the daal namespace directly. Use daal::internal::CpuFeature::amx_bf16.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
AMX-BF16 is x86-only; CpuFeature::amx_bf16 is only defined when TARGET_X86_64 is set. Add a #if defined(TARGET_X86_64) guard around the hardware detection, falling back to false on ARM/RISCV64.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
…uard

CpuFeature::amx_bf16 is defined inside #if defined(TARGET_X86_64) in cpu_type.h. Without that macro set (e.g. non-Intel CI runners, make builds without an explicit TARGET define), the enum member does not exist. Replace the raw static initialiser with a detect_hw_amx_bf16() wrapper that guards the CpuFeature reference with #if defined(TARGET_X86_64) and returns false on non-x86 targets.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
TARGET_X86_64 is a oneDAL build-system macro not set in all build configurations (e.g. plain make without an explicit TARGET define). Using CpuFeature::amx_bf16 directly fails when TARGET_X86_64 is not defined, since the enum member is gated behind that macro in cpu_type.h.

Two changes:
1. Guard the x86 detection path with the portable compiler-provided __x86_64__ / _M_X64 in addition to TARGET_X86_64.
2. Replace CpuFeature::amx_bf16 with its numeric value (1ULL << 5) to avoid the header-guard dependency entirely. The comment documents the correspondence.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
/intelci: run
Pull request overview
Adds an internal, opt-in float32 matmul precision hint and uses it to dispatch Euclidean distance GEMM to an AMX-BF16 fast path (BF16 operands, float32 output) when eligible.
Changes:
- Introduces daal::internal::Float32MatmulPrecision and exported getter/setter + AMX-BF16 capability query.
- Extends CPU feature detection with an amx_bf16 capability bit (CPUID + XCR0 OS-tile-state checks).
- Refactors EuclideanDistances::computeABt() to route GEMM via a new BF16GemmDispatcher that can convert FP32→BF16 and call BF16 GEMM.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| cpp/daal/src/services/service_defines.h | Adds internal precision-hint enum and exported getter/setter + capability API. |
| cpp/daal/src/services/cpu_type.h | Adds CpuFeature::amx_bf16 flag bit. |
| cpp/daal/src/services/compiler/generic/env_detect_precision.cpp | Implements env parsing + stored precision hint + AMX-BF16 capability query. |
| cpp/daal/src/services/compiler/generic/env_detect_features.cpp | Detects AMX-BF16 feature (CPUID + XCR0 tile-state enabled). |
| cpp/daal/src/algorithms/service_kernel_math.h | Adds BF16 GEMM dispatcher and rewires Euclidean distance GEMM through it. |
```cpp
 * Allow oneDAL to use BF16 kernels in operations where the library has
 * determined the accuracy impact is bounded and acceptable for the algorithm.
 * Currently: float32 Euclidean GEMM on AMX-BF16 hardware, dims >= 64.
 * Output dtype remains float32; only internal GEMM accumulation uses BF16.
 * Falls back to sgemm silently on hardware without AMX-BF16.
```
The comment claims that only GEMM accumulation uses BF16, but the implementation uses BF16 inputs with float32 accumulation/output (e.g., cblas_gemm_bf16bf16f32 accumulates in FP32). Please update the documentation to reflect that BF16 is used for operand storage/multiply while accumulation/output remain FP32.
```cpp
#include "src/externals/service_math.h"
#include "src/services/service_profiler.h"

#include <mutex>
```
#include <mutex> appears to be unused in this header (no mutex symbols referenced). Please remove it to avoid unnecessary compile-time overhead in a widely included header.
Suggested change: delete the `#include <mutex>` line.
```cpp
// Design principles:
// - Hardware availability is already folded into daal_get_float32_matmul_precision()
//   at initialisation time (see env_detect_precision.cpp). Call-sites only
//   query the effective precision; no hardware checks here.
```
The dispatcher documentation says hardware availability is “folded into daal_get_float32_matmul_precision()”, but the getter returns only the user hint; hardware capability is checked separately via daal_has_amx_bf16(). Please correct the comment to match the actual gating logic to avoid misleading future changes.
```cpp
/// Whether AMX-BF16 GEMM is usable in this process.
/// Combines hardware capability (checked once at init) with the user's precision hint.
/// This is the only place both conditions are AND-ed; call-sites check nothing else.
static const bool g_use_bf16_gemm =
    daal_has_amx_bf16() && (daal_get_float32_matmul_precision() == daal::internal::Float32MatmulPrecision::allow_bf16);
```
g_use_bf16_gemm is a TU-level static const bool computed during static initialization, so calling daal_set_float32_matmul_precision(...) at runtime will not affect dispatch in this TU (programmatic opt-in/out won’t work reliably). Consider making the cached flag an atomic/volatile that is updated by the setter, or check the atomic precision hint at call time (still allowing a fast path when it’s strict).
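The fix suggested here could look like the following sketch (illustrative, not the actual patch; names mirror the PR): the setter recomputes a cached atomic gate, so a runtime precision change is honoured while the hot path still reads a single bool.

```cpp
#include <atomic>

enum class Float32MatmulPrecision { strict, allow_bf16 };

static bool g_hw_amx_bf16 = false; // would be set once at init by CPUID detection
static std::atomic<Float32MatmulPrecision> g_precision { Float32MatmulPrecision::strict };
static std::atomic<bool> g_use_bf16_gemm { false };

void set_float32_matmul_precision(Float32MatmulPrecision p)
{
    g_precision.store(p, std::memory_order_relaxed);
    // Recompute the gate here instead of once at TU static initialization,
    // so programmatic opt-in/out after startup takes effect.
    g_use_bf16_gemm.store(g_hw_amx_bf16 && p == Float32MatmulPrecision::allow_bf16,
                          std::memory_order_relaxed);
}

bool use_bf16_gemm()
{
    return g_use_bf16_gemm.load(std::memory_order_relaxed);
}
```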
```cpp
// col-major: C = B16 * A16^T => cblas NoTrans B, Trans A
cblas_gemm_bf16bf16f32(CblasColMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, b16, m, a16, n, 0.0f, out, m);
```
The BF16 GEMM call uses CblasColMajor with b16 treated as an m x k column-major matrix (lda=m) and a16 as n x k (lda=n), but the source data blocks are row-major and the existing float32 path relies on interpreting them as transposed column-major (lda=k with trans='t'). As written, the BF16 path is likely multiplying incorrectly and will produce wrong results. Please align the transposition/leading-dimension logic with the existing xxgemm call so the BF16 and float32 paths are mathematically equivalent.
Suggested change, replacing:

```cpp
// col-major: C = B16 * A16^T => cblas NoTrans B, Trans A
cblas_gemm_bf16bf16f32(CblasColMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, b16, m, a16, n, 0.0f, out, m);
```

with:

```cpp
// Match the float32 path:
//   xxgemm('t', 'n', ..., b, lda=k, a, lda=k, ..., out, ldc=m)
// The source blocks are row-major, so interpret b16 as a k x m
// column-major matrix and transpose it, and interpret a16 as a
// k x n column-major matrix without transposition.
cblas_gemm_bf16bf16f32(CblasColMajor, CblasTrans, CblasNoTrans, m, n, k, 1.0f, b16, k, a16, k, 0.0f, out, m);
```
This is wrong because xxgemm uses Fortran memory layout, which is column-major, while cblas_gemm_bf16bf16f32 here uses C memory layout, which is row-major.
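As background for the layout debate above, the row-major/column-major duality can be checked with a tiny indexing sketch (illustrative, not oneDAL code): the same buffer read as column-major with the index order swapped is the transpose of its row-major interpretation, which is why the correct transpose flags and leading dimensions depend entirely on which convention the BLAS call is told to assume.

```cpp
#include <cstddef>

// Element (i, j) of a row-major matrix with the given number of columns.
static float at_rowmajor(const float * buf, std::size_t cols, std::size_t i, std::size_t j)
{
    return buf[i * cols + j];
}

// Element (i, j) of a column-major matrix with the given leading dimension.
static float at_colmajor(const float * buf, std::size_t ld, std::size_t i, std::size_t j)
{
    return buf[j * ld + i];
}
```

For an m x k row-major block, `at_rowmajor(buf, k, i, j) == at_colmajor(buf, k, j, i)`: the identical memory, declared column-major with leading dimension k, is the k x m transpose.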
```cpp
// BF16 = upper 16 bits of IEEE-754 float32; use memcpy to read bits safely.
for (size_t i = 0; i < szA; ++i)
{
    uint32_t bits;
    memcpy(&bits, &a[i], sizeof(bits));
    a16[i] = static_cast<MKL_BF16>(bits >> 16);
}
for (size_t i = 0; i < szB; ++i)
{
    uint32_t bits;
    memcpy(&bits, &b[i], sizeof(bits));
    b16[i] = static_cast<MKL_BF16>(bits >> 16);
}
```
The BF16 conversion truncates float32 to the upper 16 bits, which is not the standard round-to-nearest-even BF16 conversion and can introduce extra bias/error vs real BF16 hardware behavior. Consider using an MKL-provided FP32→BF16 conversion routine (or implement correct rounding) so the AMX BF16 path matches expected BF16 semantics.
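A correct round-to-nearest-even conversion, as this comment requests, is only a few lines. This is a hedged sketch (the function name is illustrative); MKL and AVX-512 BF16 hardware (e.g. `_mm512_cvtneps_pbh`) provide equivalent conversions.

```cpp
#include <cstdint>
#include <cstring>

// float32 -> bfloat16 with round-to-nearest-even, instead of plain truncation.
static std::uint16_t f32_to_bf16_rne(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u)
    {
        // NaN: truncate and force the quiet bit so the payload is not lost.
        return static_cast<std::uint16_t>((bits >> 16) | 0x0040u);
    }
    // Adding 0x7FFF, plus 1 when the kept LSB is set, implements ties-to-even.
    const std::uint32_t lsb = (bits >> 16) & 1u;
    bits += 0x7FFFu + lsb;
    return static_cast<std::uint16_t>(bits >> 16);
}
```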
```cpp
if (g_use_bf16_gemm && m >= kBF16MinDim && n >= kBF16MinDim && k >= kBF16MinDim)
{
    const size_t szA = static_cast<size_t>(n) * static_cast<size_t>(k);
    const size_t szB = static_cast<size_t>(m) * static_cast<size_t>(k);
    MKL_BF16 * a16 = static_cast<MKL_BF16 *>(mkl_malloc(szA * sizeof(MKL_BF16), 64));
    MKL_BF16 * b16 = static_cast<MKL_BF16 *>(mkl_malloc(szB * sizeof(MKL_BF16), 64));
    if (a16 && b16)
    {
        // BF16 = upper 16 bits of IEEE-754 float32; use memcpy to read bits safely.
        for (size_t i = 0; i < szA; ++i)
        {
            uint32_t bits;
            memcpy(&bits, &a[i], sizeof(bits));
            a16[i] = static_cast<MKL_BF16>(bits >> 16);
        }
        for (size_t i = 0; i < szB; ++i)
        {
            uint32_t bits;
            memcpy(&bits, &b[i], sizeof(bits));
            b16[i] = static_cast<MKL_BF16>(bits >> 16);
        }
        // col-major: C = B16 * A16^T => cblas NoTrans B, Trans A
        cblas_gemm_bf16bf16f32(CblasColMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, b16, m, a16, n, 0.0f, out, m);
        mkl_free(a16);
        mkl_free(b16);
        return;
    }
    if (a16) mkl_free(a16);
    if (b16) mkl_free(b16);
}
```
This path allocates and frees BF16 buffers (mkl_malloc/mkl_free) on every GEMM call and reconverts both A and B each time. In computeFull(), B is reused across many A blocks, so repeatedly converting B can dominate runtime and reduce the benefit of AMX. Consider caching/reusing converted buffers (e.g., per-thread scratch or converting B once per outer call) to avoid repeated allocation and conversion overhead.
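The caching suggested here could be sketched as a small scratch holder that converts an operand once and reuses it while the source pointer and size are unchanged. Everything below is illustrative (not oneDAL API), and it keeps the PR's truncating conversion for brevity.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Caches the BF16 conversion of one float buffer, so a B operand reused
// across many A blocks is converted once per outer call, not per GEMM.
struct Bf16Cache
{
    std::vector<std::uint16_t> data;
    const float * src = nullptr;
    std::size_t n = 0;

    const std::uint16_t * get(const float * p, std::size_t count)
    {
        if (p == src && count == n) return data.data(); // already converted
        data.resize(count);
        for (std::size_t i = 0; i < count; ++i)
        {
            std::uint32_t bits;
            std::memcpy(&bits, &p[i], sizeof(bits));
            data[i] = static_cast<std::uint16_t>(bits >> 16); // truncation, as in the PR
        }
        src = p;
        n = count;
        return data.data();
    }
};
```

A per-thread instance (or one owned by the outer computeFull() loop) would remove both the repeated allocation and the repeated conversion of B.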
```cpp
/// If 'high' is requested but hardware lacks AMX-BF16, BF16GemmDispatcher
/// will fall through to sgemm transparently via its own g_use_bf16_gemm gate.
```

The setter comment refers to requesting "high", but the only non-default level in this API is allow_bf16. Please update the comment to match the actual enum values/behavior (and note that dispatch should remain correct when the precision hint is changed at runtime).

Suggested change:

```cpp
/// If allow_bf16 is requested but hardware lacks AMX-BF16, BF16GemmDispatcher
/// will fall through to sgemm transparently via its own g_use_bf16_gemm gate.
/// Changing this hint at runtime remains correct because dispatch combines the
/// stored level with the cached hardware capability check on each use.
```
david-cortes-intel left a comment:
This would require benchmarks to assess whether it leads to improved runtimes. Less precision could potentially result in requiring more iterations for convergence.
```cpp
{
    const size_t szA = static_cast<size_t>(n) * static_cast<size_t>(k);
    const size_t szB = static_cast<size_t>(m) * static_cast<size_t>(k);
    MKL_BF16 * a16 = static_cast<MKL_BF16 *>(mkl_malloc(szA * sizeof(MKL_BF16), 64));
```
We have smart pointer classes to use. Would prevent potential errors from this sort of memory management.
Yes, it's better to use TArray or other classes defined here:
https://github.com/uxlfoundation/oneDAL/blob/main/cpp/daal/src/services/service_arrays.h#L109
```cpp
{
    uint32_t bits;
    memcpy(&bits, &a[i], sizeof(bits));
    a16[i] = static_cast<MKL_BF16>(bits >> 16);
```
I believe MKL itself should have utilities for this sort of thing.
It looks like no. Intrinsics like _mm512_cvtneps_pbh might be faster, but I am Ok to have it like this for now.
Shouldn't it have #pragma omp simd then?
```cpp
static int check_amx_bf16_features()
{
    /* CPUID.(EAX=07H, ECX=0H):EDX.AMX-BF16[bit 22]==1 */
    if (!check_cpuid(7, 0, 3, (1 << 22)))
```
Would this segfault or otherwise break if executed on AMD? Is there perhaps some check that tells whether it's an intel CPU?
It is Ok, as it is checked at line 318 below.
@Vika-F Is it possible for the call to check_cpuid to segfault?
Actually I see now that it does check for intel CPU.
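For reference, the full runtime check discussed in this thread (feature bit plus OS tile-state enablement) can be sketched with GCC/Clang builtins. This is a hedged, standalone illustration; oneDAL's real implementation lives in env_detect_features.cpp and additionally verifies the CPU vendor, as noted above.

```cpp
#include <cstdint>

#if defined(__x86_64__)
#include <cpuid.h>

static bool has_amx_bf16()
{
    unsigned int eax, ebx, ecx, edx;
    // CPUID.(EAX=07H, ECX=0H):EDX.AMX-BF16[bit 22]
    if (!__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx) || !(edx & (1u << 22))) return false;
    // CPUID.(EAX=01H):ECX.OSXSAVE[bit 27] must be set before XGETBV is legal.
    if (!__get_cpuid(1, &eax, &ebx, &ecx, &edx) || !(ecx & (1u << 27))) return false;
    // XCR0[17] (XTILECFG) and XCR0[18] (XTILEDATA): OS has enabled AMX tile state.
    std::uint32_t lo, hi;
    __asm__("xgetbv" : "=a"(lo), "=d"(hi) : "c"(0));
    const std::uint64_t xcr0 = (static_cast<std::uint64_t>(hi) << 32) | lo;
    return (xcr0 & 0x60000u) == 0x60000u;
}
#else
static bool has_amx_bf16() { return false; } // non-x86 targets
#endif
```

An unsupported CPUID leaf reports zeros rather than faulting, so the check degrades gracefully on CPUs (including AMD) that lack the leaf or the bits.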
```cpp
static bool detect_hw_amx_bf16()
{
#if defined(TARGET_X86_64) || defined(__x86_64__) || defined(_M_X64)
    // AMX-BF16 bit = (1ULL << 5) in daal::internal::CpuFeature (cpu_type.h).
```
Sounds like it should be moved to that file instead.
CC @Vika-F for comments.
Yes, it's better to move detect_hw_amx_bf16() and g_hw_amx_bf16 into env_detect_features.cpp.
I am also not sure that the functions in this file really need DAAL_EXPORT. Because they are only used internally in DAAL (i.e. libonedal_core.so).
Regarding AMX-BF16 bit = (1ULL << 5).
I think it's better not to use magic constants here, and use:
static const DAAL_UINT64 amx_bf16_mask = static_cast<DAAL_UINT64>(daal::internal::amx_bf16);
```cpp
for (size_t i = 0; i < szA; ++i)
{
    uint32_t bits;
    memcpy(&bits, &a[i], sizeof(bits));
```
@Vika-F Is it a problem to have stdlib calls like this in a header?
memcpy is deprecated due to security reasons. daal_memcpy_s should be used:
https://github.com/uxlfoundation/oneDAL/blob/main/cpp/daal/include/services/daal_memory.h#L71
```cpp
#ifndef __SERVICE_KERNEL_MATH_H__
#define __SERVICE_KERNEL_MATH_H__

#include <cstring>
```
This is including a C++ header, but then calling the functions from it without std::.
```diff
 {}

-~EuclideanDistances() override {}
+virtual ~EuclideanDistances() override {}
```
Please remove redundant "virtual" here and in other similar places.
AMX BF16 fast path for float32 Euclidean distance GEMM
Summary
Adds an opt-in AMX-BF16 fast path for the float32 Euclidean distance computation
used by KNN (brute-force) and DBSCAN. Internally, the A×B' matrix multiply in
EuclideanDistances::computeABt converts float32 inputs to BF16 and dispatches
to Intel AMX tile instructions. Output remains float32.
Activation (opt-in; the default remains strict / full float32):

```
ONEDAL_FLOAT32_MATMUL_PRECISION=ALLOW_BF16
```

Or programmatically:

```cpp
daal_set_float32_matmul_precision(daal::internal::Float32MatmulPrecision::allow_bf16);
```

Design

- Float32MatmulPrecision API (internal, daal::internal namespace, not installed in public headers): strict (default) and allow_bf16.
- Hardware check: performed once at process init (g_hw_amx_bf16). No HW queries in the getter/setter/hot path.
- BF16GemmDispatcher<float, cpu>: partial specialization separates the dispatch struct from EuclideanDistances. The TU-init g_use_bf16_gemm bool combines HW capability and user hint; the hot path branches on a single static bool (branch-predictor friendly).
- Eligibility gate: both GEMM dimensions (rows of A and rows of B) must be >= 64. Below this threshold, BF16 tiling overhead outweighs the benefit and the accuracy impact grows.
Benchmark results
Hardware: Intel Xeon 6975P-C (Granite Rapids), KVM, AMX-BF16 enabled.
Method: scikit-learn_bench (sklbench), n_jobs set via --parameters.

DBSCAN euclidean (feat ≥ 64):

KNN classification inference (feat ≥ 128):
No regressions observed: KNN training, KNN regressor (feat=10), DBSCAN (feat=3) all show ~1.00x.
Files changed

- cpp/daal/src/services/service_defines.h
- cpp/daal/src/services/compiler/generic/env_detect_precision.cpp
- cpp/daal/src/algorithms/service_kernel_math.h
- cpp/daal/src/services/cpu_type.h