perf: AMX BF16 dispatch in EuclideanDistances::computeABt#3564
Draft
napetrov wants to merge 15 commits into uxlfoundation:main from
Conversation
Force-pushed from a4d8723 to 5cb316e
napetrov (Contributor, Author) commented:

/intelci: run
Vika-F reviewed Mar 17, 2026
```cpp
// compute (A x B')
#ifndef DAAL_REF
// AMX-BF16 capability check (MKL builds only; cross-platform)
static bool knn_has_amx_bf16()
```
Contributor
It would be better to extend the function we already have for CPU features detection, if needed:
https://github.com/uxlfoundation/oneDAL/blob/main/cpp/daal/src/services/compiler/generic/env_detect_features.cpp#L303
Add AMX BF16 fast path for kNN brute-force and DBSCAN distance computation.

When AMX BF16 hardware is available and all GEMM dimensions >= 64, converts float32 operands to BF16 and uses cblas_gemm_bf16bf16f32 instead of sgemm.

Benchmarks on Xeon 6975P-C (sklearnex, float32, cpu):
- KNeighborsClassifier: 2.7-3.4x speedup
- KNeighborsRegressor: 2.9-3.3x speedup
- DBSCAN: 3.5-4.0x speedup

Accuracy impact vs float32 baseline:
- Classifier: delta < 1.2% (within dataset noise)
- Regressor R2: delta < 0.001

Runtime detection via CPUID (AMX_BF16 bit) + XGETBV XCR0[18:17]. Falls back to sgemm if AMX is unavailable or dims < 64.

Signed-off-by: Nikolay Petrov <[email protected]>
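The runtime detection described above reduces to two bit tests. The sketch below factors them into a pure function over the raw CPUID/XGETBV register values so the logic is checkable in isolation; the function name and structure are illustrative, not oneDAL's actual code.

```cpp
#include <cstdint>

// AMX-BF16 support is CPUID(EAX=7,ECX=0).EDX bit 22; the OS must also have
// enabled AMX tile state, i.e. XCR0 bits 17 (XTILECFG) and 18 (XTILEDATA),
// readable via XGETBV. Both conditions must hold before issuing tile code.
static bool amx_bf16_usable(std::uint32_t cpuid7_edx, std::uint64_t xcr0)
{
    const bool cpu_has_amx_bf16   = (cpuid7_edx >> 22) & 1u;
    const std::uint64_t tile_bits = (1ull << 17) | (1ull << 18);
    const bool os_enabled         = (xcr0 & tile_bits) == tile_bits;
    return cpu_has_amx_bf16 && os_enabled;
}
```

Keeping the register reads (CPUID/XGETBV intrinsics) separate from the bit logic makes the fallback path trivially testable on non-x86 hosts.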
- Add daal::CpuFeature::amx_bf16 to the CpuFeature enum in cpu_type.h
- Add check_amx_bf16_features() to env_detect_features.cpp using the existing DAAL CPUID/XGETBV infrastructure (CPUID(7,0).EDX[22] + XCR0[17:18]); remove the standalone knn_has_amx_bf16() from service_kernel_math.h
- Replace std::is_same dispatch with template specialization: add a computeABtImpl primary template (sgemm/dgemm fallback) and a float specialization (AMX-BF16 path, guarded by DAAL_REF)
- Replace the union-based float→BF16 conversion with memcpy to avoid strict-aliasing UB (C++17 3.10/[basic.lval])

Signed-off-by: Nikolay Petrov <[email protected]>
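The memcpy-based conversion mentioned in the last bullet can be sketched as below. This is a minimal sketch showing plain truncation of the low 16 mantissa bits (the simplest BF16 conversion; whether the real kernel truncates or rounds to nearest-even is not specified here), with illustrative names:

```cpp
#include <cstdint>
#include <cstring>

// float32 -> BF16 by dropping the low 16 bits. memcpy through an integer
// avoids the strict-aliasing UB that a union- or reinterpret_cast-based
// type pun would risk; compilers optimize it to a plain register move.
static inline std::uint16_t float_to_bf16_trunc(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return static_cast<std::uint16_t>(bits >> 16); // sign + exponent + top 7 mantissa bits
}

// BF16 -> float32: widen by appending 16 zero mantissa bits (exact).
static inline float bf16_to_float(std::uint16_t b)
{
    const std::uint32_t bits = static_cast<std::uint32_t>(b) << 16;
    float x;
    std::memcpy(&x, &bits, sizeof(x));
    return x;
}
```

Values whose mantissa fits in 7 bits (powers of two, small integers) round-trip exactly; everything else loses low-order mantissa bits, which is the accuracy trade-off the commit's benchmark deltas quantify.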
Introduces a precision-hint API modelled after torch.set_float32_matmul_precision
/ jax_default_matmul_precision so users can opt in to reduced-precision
arithmetic where the library determines it is safe.
enum daal::Float32MatmulPrecision { highest (default), high }
daal_get_float32_matmul_precision() -- query effective precision
daal_set_float32_matmul_precision(p) -- set at runtime
Env var: ONEDAL_FLOAT32_MATMUL_PRECISION=HIGH
* Hardware availability (AMX-BF16) is folded into the effective precision
at initialisation time in env_detect_precision.cpp. Call-sites never
query hardware themselves; they only check daal_get_float32_matmul_precision().
* BF16GemmDispatcher<FPType, cpu> encapsulates all dispatch logic:
- primary template: standard sgemm/dgemm, zero overhead for non-float types
- float specialization: AMX-BF16 path when precision=='high' and all dims >=64
- matrix-size threshold (kBF16MinDim=64) owned by the dispatcher, not callers
* g_use_bf16_gemm is a TU-level static bool evaluated once at startup,
so the hot path is a single branch on a constant-like value.
* computeABt() in EuclideanDistances is now a one-liner delegating to
BF16GemmDispatcher; no template ambiguity, no DAAL_REF ifdefs in caller.
* Default is Float32MatmulPrecision::highest (no precision loss without
explicit opt-in); consistent with PyTorch/JAX defaults.
Signed-off-by: Nikolay Petrov <[email protected]>
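The dispatcher shape described above can be sketched as follows. BF16GemmDispatcher and kBF16MinDim mirror the names in the commit message, but the GemmPath enum and the bodies are illustrative placeholders, not oneDAL's implementation:

```cpp
#include <cstddef>

// Which GEMM kernel the dispatcher would select (placeholder for the sketch).
enum class GemmPath { standard_gemm, sgemm, bf16_gemm };

// Primary template: non-float types always take the standard GEMM path,
// so there is zero dispatch overhead for double.
template <typename FPType>
struct BF16GemmDispatcher
{
    static GemmPath select(std::size_t, std::size_t) { return GemmPath::standard_gemm; }
};

// float specialization: AMX-BF16 path only when the precision hint / HW
// gate is on and both dimensions meet the threshold.
template <>
struct BF16GemmDispatcher<float>
{
    static constexpr std::size_t kBF16MinDim = 64; // threshold owned by the dispatcher
    static bool use_bf16;                          // stands in for the TU-init g_use_bf16_gemm

    static GemmPath select(std::size_t rowsA, std::size_t rowsB)
    {
        if (use_bf16 && rowsA >= kBF16MinDim && rowsB >= kBF16MinDim) return GemmPath::bf16_gemm;
        return GemmPath::sgemm;
    }
};
bool BF16GemmDispatcher<float>::use_bf16 = false;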
Previously compute_effective_precision() was called in both the setter
and the getter, folding HW availability into the stored value.
New design:
g_hw_amx_bf16 -- bool, computed ONCE at process init via
daal_serv_cpu_feature_detect(); never recomputed.
daal_has_amx_bf16() -- returns g_hw_amx_bf16 (constant for process lifetime)
g_float32_matmul_precision -- stores the user's requested level as-is;
setter writes directly, no HW re-check.
daal_get_float32_matmul_precision() -- plain atomic load, zero HW queries.
BF16GemmDispatcher gates on both:
g_use_bf16_gemm = daal_has_amx_bf16() && (get_precision() == high)
This is evaluated once at TU init; the hot path sees a single bool.
Setter behaviour: daal_set_float32_matmul_precision(high) on a machine
without AMX-BF16 stores 'high' but g_use_bf16_gemm remains false, so
BF16GemmDispatcher silently falls through to sgemm. No HW check in setter.
Signed-off-by: Nikolay Petrov <[email protected]>
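The scheme above can be sketched in a few lines, using the highest/high naming from this commit. g_hw_amx_bf16 is a plain bool here; in the described design it is computed once at process init via daal_serv_cpu_feature_detect(). Everything else is illustrative:

```cpp
#include <atomic>

enum class Float32MatmulPrecision { highest, high };

static bool g_hw_amx_bf16 = false; // set once at init; never recomputed
static std::atomic<Float32MatmulPrecision> g_float32_matmul_precision{ Float32MatmulPrecision::highest };

// Setter stores the user's request as-is: no hardware re-check here.
void set_float32_matmul_precision(Float32MatmulPrecision p)
{
    g_float32_matmul_precision.store(p, std::memory_order_relaxed);
}

// Getter is a plain atomic load: zero hardware queries.
Float32MatmulPrecision get_float32_matmul_precision()
{
    return g_float32_matmul_precision.load(std::memory_order_relaxed);
}

// Dispatcher gate: both HW capability and the user hint must be on.
bool use_bf16_gemm()
{
    return g_hw_amx_bf16 && get_float32_matmul_precision() == Float32MatmulPrecision::high;
}
```

This reproduces the setter behaviour the commit calls out: requesting 'high' on a machine without AMX-BF16 stores 'high', but the gate stays false and the sgemm fallback is silent.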
Add comprehensive doc comment for the Float32MatmulPrecision enum:
- Explains 'highest' vs 'high' semantics with concrete examples
- Documents current eligible operations for 'high' (Euclidean GEMM, AMX-BF16)
- Notes fallback behaviour on non-AMX hardware
- Documents env var + programmatic API
- Reserves 'medium' for a future 2xBF16 pass

Signed-off-by: Nikolay Petrov <[email protected]>
The API is not yet stable; keep it internal until the design is finalised and reviewed.

Changes:
- Remove Float32MatmulPrecision from public cpu_type.h
- Define daal::internal::Float32MatmulPrecision in service_defines.h (internal header, not installed) with a clear 'INTERNAL, experimental' notice
- Update env_detect_precision.cpp and service_kernel_math.h to use daal::internal::Float32MatmulPrecision
- daal_has_amx_bf16 / daal_get_float32_matmul_precision / daal_set_float32_matmul_precision remain DAAL_EXPORT so internal oneDAL code and sklearnex can reach them, but they are not part of the public API surface

Signed-off-by: Nikolay Petrov <[email protected]>
Correct the ordering: highest > high > medium by decreasing precision, matching torch.set_float32_matmul_precision semantics exactly. Add 'medium' as a reserved enum value (not yet implemented; falls back to 'high') for forward compatibility.

Signed-off-by: Nikolay Petrov <[email protected]>
Replace the PyTorch-style highest/high/medium naming with explicit names:
- strict -- always float32, IEEE-754, bit-exact (was: highest)
- allow_bf16 -- permit BF16 kernels where oneDAL deems them safe (was: high)

Rationale: 'high'/'highest' are ambiguous (performance vs precision). 'strict' and 'allow_bf16' express intent directly with no ambiguity, and the pattern extends naturally to allow_fp16 and allow_tf32.

Env var updated: ONEDAL_FLOAT32_MATMUL_PRECISION=ALLOW_BF16

Signed-off-by: Nikolay Petrov <[email protected]>
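Mapping the renamed env var onto the new enum values is straightforward; the sketch below is illustrative (the function name is hypothetical, not oneDAL's parser):

```cpp
#include <cstdlib>
#include <cstring>

enum class Float32MatmulPrecision { strict, allow_bf16 };

// Read ONEDAL_FLOAT32_MATMUL_PRECISION; anything other than ALLOW_BF16
// (including the var being unset) falls back to the strict default.
Float32MatmulPrecision precision_from_env()
{
    const char * v = std::getenv("ONEDAL_FLOAT32_MATMUL_PRECISION");
    if (v && std::strcmp(v, "ALLOW_BF16") == 0) return Float32MatmulPrecision::allow_bf16;
    return Float32MatmulPrecision::strict; // default: bit-exact float32
}
```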
…on.cpp

A sed substitution previously broke daal::CpuFeature::amx_bf16 → daal::amx_bf16. amx_bf16 is a member of enum daal::CpuFeature, not of the daal namespace directly.

Signed-off-by: Nikolay Petrov <[email protected]>
Force-pushed from bef38d1 to 0a6c6c2
The enum was defined in both cpu_type.h (with highest/high values) and service_defines.h (with strict/allow_bf16 values). When both headers were included in the same translation unit the compiler emitted:

error: redefinition of 'Float32MatmulPrecision'

Float32MatmulPrecision is an internal API and belongs exclusively in service_defines.h (internal header, not installed). Remove the definition and its associated doc comment from cpu_type.h entirely.

Signed-off-by: Nikolay Petrov <[email protected]>
- env_detect_precision.cpp: fix include path from 'services/cpu_type.h' to 'src/services/cpu_type.h' (internal header, not on the public include path)
- Apply clang-format to service_kernel_math.h and env_detect_precision.cpp to pass FormatterChecks CI

Signed-off-by: Nikolay Petrov <[email protected]>
…on.cpp

CpuFeature is defined in the daal::internal namespace (cpu_type.h), not in the daal namespace directly. Use daal::internal::CpuFeature::amx_bf16.

Signed-off-by: Nikolay Petrov <[email protected]>
AMX-BF16 is x86-only; CpuFeature::amx_bf16 is only defined when TARGET_X86_64 is set. Add #if defined(TARGET_X86_64) guard around the hardware detection, falling back to false on ARM/RISCV64. Signed-off-by: Nikolay Petrov <[email protected]>
…uard

CpuFeature::amx_bf16 is defined inside #if defined(TARGET_X86_64) in cpu_type.h. Without that macro set (e.g. non-Intel CI runners, make builds without an explicit TARGET define), the enum member does not exist. Replace the raw static initialiser with a detect_hw_amx_bf16() wrapper that guards the CpuFeature reference with #if defined(TARGET_X86_64) and returns false on non-x86 targets.

Signed-off-by: Nikolay Petrov <[email protected]>
TARGET_X86_64 is a oneDAL build-system macro not set in all build configurations (e.g. a plain make without an explicit TARGET define). Using CpuFeature::amx_bf16 directly fails when TARGET_X86_64 is not defined, since the enum member is gated behind that macro in cpu_type.h.

Two changes:
1. Guard the x86 detection path with the portable compiler-provided __x86_64__ / _M_X64 in addition to TARGET_X86_64.
2. Replace CpuFeature::amx_bf16 with its numeric value (1ULL << 5) to avoid the header-guard dependency entirely. A comment documents the correspondence.

Signed-off-by: Nikolay Petrov <[email protected]>
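The guarded wrapper described above can be sketched as follows. The stub feature-mask query stands in for oneDAL's daal_serv_cpu_feature_detect(); the bit position mirrors the commit's note that CpuFeature::amx_bf16 has numeric value 1ULL << 5:

```cpp
#include <cstdint>

// Placeholder stub for the feature-mask query (reports "no features" here).
static std::uint64_t cpu_feature_mask()
{
    return 0;
}

static bool detect_hw_amx_bf16()
{
#if defined(TARGET_X86_64) || defined(__x86_64__) || defined(_M_X64)
    // Numeric value of CpuFeature::amx_bf16 (1ULL << 5), used directly so
    // this compiles even when the TARGET_X86_64-gated enum member is absent.
    return (cpu_feature_mask() & (1ULL << 5)) != 0;
#else
    return false; // AMX-BF16 is x86-only; ARM / RISC-V fall back to false
#endif
}
```

Checking the compiler-provided __x86_64__ / _M_X64 alongside the build-system macro means the detection path is selected by what the compiler is actually targeting, not by how the build was configured.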
AMX BF16 fast path for float32 Euclidean distance GEMM
Summary
Adds an opt-in AMX-BF16 fast path for the float32 Euclidean distance computation
used by KNN (brute-force) and DBSCAN. Internally, the A×B' matrix multiply in
EuclideanDistances::computeABt converts float32 inputs to BF16 and dispatches
to Intel AMX tile instructions. Output remains float32.
Activation (opt-in; the default is unchanged strict / full float32):

ONEDAL_FLOAT32_MATMUL_PRECISION=ALLOW_BF16

Or programmatically:

daal_set_float32_matmul_precision(daal::internal::Float32MatmulPrecision::allow_bf16);

Design

- Float32MatmulPrecision API (internal, daal::internal namespace, not installed in public headers): strict (default), allow_bf16.
- Hardware check — once at process init (g_hw_amx_bf16). No HW queries in getter/setter/hot path.
- BF16GemmDispatcher<float, cpu> — partial specialization separates the dispatch struct from EuclideanDistances. The TU-init g_use_bf16_gemm bool combines HW capability and user hint; the hot path branches on a single static bool (branch-predictor-friendly).
- Eligibility gate: both GEMM dimensions (rows of A and rows of B) must be ≥ 64. Below this threshold, BF16 tiling overhead outweighs the benefit and the accuracy impact grows.
Benchmark results
Hardware: Intel Xeon 6975P-C (Granite Rapids), KVM, AMX-BF16 enabled.
Method: scikit-learn_bench (sklbench), n_jobs set via --parameters.

DBSCAN euclidean (feat ≥ 64):

KNN classification inference (feat ≥ 128):
No regressions observed: KNN training, KNN regressor (feat=10), DBSCAN (feat=3) all show ~1.00x.
Files changed
- cpp/daal/src/services/service_defines.h
- cpp/daal/src/services/compiler/generic/env_detect_precision.cpp
- cpp/daal/src/algorithms/service_kernel_math.h
- cpp/daal/src/services/cpu_type.h