
perf: AMX BF16 dispatch in EuclideanDistances::computeABt #3564

Draft

napetrov wants to merge 15 commits into uxlfoundation:main from napetrov:feature/amx-bf16-knn-euclidean

Conversation

@napetrov
Contributor

@napetrov napetrov commented Mar 17, 2026

AMX BF16 fast path for float32 Euclidean distance GEMM

Summary

Adds an opt-in AMX-BF16 fast path for the float32 Euclidean distance computation
used by KNN (brute-force) and DBSCAN. Internally, the A×B' matrix multiply in
EuclideanDistances::computeABt converts float32 inputs to BF16 and dispatches
to Intel AMX tile instructions. Output remains float32.

Activation (opt-in, default is unchanged strict / full float32):

ONEDAL_FLOAT32_MATMUL_PRECISION=ALLOW_BF16

Or programmatically:

daal_set_float32_matmul_precision(daal::internal::Float32MatmulPrecision::allow_bf16);

Design

Float32MatmulPrecision API (internal, daal::internal namespace, not installed in public headers):

| Level | Behaviour |
| --- | --- |
| `strict` (default) | Always float32. IEEE-754, bit-exact. |
| `allow_bf16` | Permit BF16 kernels where oneDAL determines accuracy impact is bounded. |

Hardware check — once at process init (g_hw_amx_bf16). No HW queries in getter/setter/hot path.

BF16GemmDispatcher<float, cpu> — partial specialization separates the dispatch struct from EuclideanDistances. The TU-init g_use_bf16_gemm bool combines HW capability and user hint; the hot path branches on a single static bool (branch predictor-friendly).

Eligibility gate: both GEMM dimensions (rows of A and rows of B) must be ≥ 64. Below this threshold, BF16 tiling overhead outweighs the benefit and accuracy impact grows.
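
The eligibility gate described above can be sketched as follows. This is a simplified illustration, not the actual oneDAL code; the names `kBF16MinDim` and `bf16Eligible` are hypothetical stand-ins (the PR's dispatcher owns the threshold internally):

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical sketch of the BF16 eligibility gate. In the real design the
// HW-capability-and-user-hint flag is a TU-init static bool; here it is a
// plain parameter so the logic is self-contained.
constexpr std::size_t kBF16MinDim = 64;

bool bf16Eligible(bool hwAndHintEnabled, std::size_t rowsA, std::size_t rowsB)
{
    // Both GEMM dimensions must reach the threshold; below it, BF16 tile
    // setup overhead dominates and the relative accuracy impact grows.
    return hwAndHintEnabled && rowsA >= kBF16MinDim && rowsB >= kBF16MinDim;
}
```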

Benchmark results

Hardware: Intel Xeon 6975P-C (Granite Rapids), KVM, AMX-BF16 enabled.
Method: scikit-learn_bench (sklbench), n_jobs set via --parameters.

DBSCAN euclidean (feat ≥ 64):

| config | speedup range |
| --- | --- |
| 8 vCPU, n_jobs 1–8 | 1.30–1.35x |
| 16 vCPU, best case (f512, n_jobs=8) | 1.60x |
| 48 vCPU, n_jobs 1–8 | 1.21–1.25x |

KNN classification inference (feat ≥ 128):

| config | speedup range |
| --- | --- |
| 8 vCPU | 1.31–1.44x |
| 16 vCPU | 1.11–1.22x |
| 48 vCPU | 1.08–1.17x |

No regressions observed: KNN training, KNN regressor (feat=10), DBSCAN (feat=3) all show ~1.00x.

Files changed

| File | Change |
| --- | --- |
| cpp/daal/src/services/service_defines.h | Float32MatmulPrecision enum + API declarations (internal) |
| cpp/daal/src/services/compiler/generic/env_detect_precision.cpp | HW check, env parse, getter/setter impl |
| cpp/daal/src/algorithms/service_kernel_math.h | BF16GemmDispatcher, refactored computeABt |
| cpp/daal/src/services/cpu_type.h | CpuFeature::amx_bf16 flag |

@napetrov napetrov force-pushed the feature/amx-bf16-knn-euclidean branch 2 times, most recently from a4d8723 to 5cb316e on March 17, 2026 at 04:24
@napetrov
Contributor Author

/intelci: run

```cpp
// compute (A x B')
#ifndef DAAL_REF
// AMX-BF16 capability check (MKL builds only; cross-platform)
static bool knn_has_amx_bf16()
```

It would be better to extend the function we already have for CPU features detection, if needed:
https://github.com/uxlfoundation/oneDAL/blob/main/cpp/daal/src/services/compiler/generic/env_detect_features.cpp#L303

@Vika-F
Contributor

Vika-F commented Mar 17, 2026

  1. This kind of PR is rather meaningless without performance & accuracy data.
    Need to ask an agent to run sklearn_bench on this.
    1.a) Probably the fp32->bf16 conversion done on every call to distances will "eat" all the potential performance.
  2. The conversion code can be improved with _mm512_cvtneps_pbh, but this won't help much with the performance.
  3. Use of unions in C++ can be harmful due to the strict aliasing rule (https://gist.github.com/shafik/848ae25ee209f698763cffee272a58f8). A union-free version of the code is needed.
  4. In DAAL we typically handle specialized implementations with [partial] template specializations rather than std::is_same. Template specializations reduce the amount of branching and decrease code complexity.
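
The specialization pattern the reviewer suggests in point 4 can be sketched like this. All names here are illustrative, not oneDAL's actual identifiers; the point is that the float-only fast path lives in a partial specialization instead of an `std::is_same` branch inside a shared function:

```cpp
#include <cassert>
#include <string>

// Primary template: standard sgemm/dgemm path. No BF16 logic is even
// compiled for non-float types, so there is zero runtime branching.
template <typename FPType, int cpu>
struct GemmDispatcher
{
    static std::string kernel() { return "gemm"; }
};

// float partial specialization: the only place BF16 dispatch logic lives.
template <int cpu>
struct GemmDispatcher<float, cpu>
{
    static std::string kernel() { return "bf16_or_sgemm"; }
};
```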

Add AMX BF16 fast path for kNN brute-force and DBSCAN distance
computation. When AMX BF16 hardware is available and all GEMM
dimensions >= 64, converts float32 operands to BF16 and uses
cblas_gemm_bf16bf16f32 instead of sgemm.

Benchmarks on Xeon 6975P-C (sklearnex, float32, cpu):
- KNeighborsClassifier: 2.7-3.4x speedup
- KNeighborsRegressor:  2.9-3.3x speedup
- DBSCAN:               3.5-4.0x speedup

Accuracy impact vs float32 baseline:
- Classifier: delta < 1.2% (within dataset noise)
- Regressor R2: delta < 0.001

Runtime detection via CPUID (AMX_BF16 bit) + XGETBV XCR0[18:17].
Falls back to sgemm if AMX unavailable or dims < 64.

Signed-off-by: Nikolay Petrov <[email protected]>
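
The CPUID/XGETBV detection described in the commit message above can be sketched as follows. This is a hedged, GCC/Clang-only illustration, not the oneDAL implementation (which reuses the existing env_detect_features.cpp infrastructure):

```cpp
#include <cassert>
#include <cstdint>

#if defined(__x86_64__)
#include <cpuid.h>
#endif

// Sketch: CPUID(7,0).EDX bit 22 reports AMX-BF16 support; XCR0 bits 17:18
// (XTILECFG, XTILEDATA) confirm the OS saves/restores AMX tile state.
bool detect_amx_bf16()
{
#if defined(__x86_64__)
    unsigned eax = 0, ebx = 0, ecx = 0, edx = 0;
    if (!__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) return false;
    if (((edx >> 22) & 1u) == 0) return false; // CPU lacks AMX-BF16

    // XGETBV with ECX=0 reads XCR0. Safe to execute here: any CPU with the
    // AMX-BF16 bit also supports XSAVE/XGETBV.
    std::uint32_t xcr0_lo, xcr0_hi;
    __asm__ volatile("xgetbv" : "=a"(xcr0_lo), "=d"(xcr0_hi) : "c"(0));
    return (xcr0_lo & (3u << 17)) == (3u << 17);
#else
    return false; // non-x86 targets: AMX never available
#endif
}
```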
- Add daal::CpuFeature::amx_bf16 to the CpuFeature enum in cpu_type.h
- Add check_amx_bf16_features() to env_detect_features.cpp using the
  existing DAAL CPUID/XGETBV infrastructure (CPUID(7,0).EDX[22] +
  XCR0[17:18]); remove standalone knn_has_amx_bf16() from
  service_kernel_math.h
- Replace std::is_same dispatch with template specialization: add
  computeABtImpl primary template (sgemm/dgemm fallback) and a float
  specialization (AMX-BF16 path, guarded by DAAL_REF)
- Replace union-based float→BF16 conversion with memcpy to avoid
  strict-aliasing UB (C++17 3.10/[basic.lval])

Signed-off-by: Nikolay Petrov <[email protected]>
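
The union-free float→BF16 conversion mentioned above can be sketched with `memcpy`, which has defined behaviour where a union type-pun does not. Truncation (dropping the low 16 mantissa bits) is shown for simplicity; the real kernel may round to nearest even, and the function names here are illustrative:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Read the float's bit pattern via memcpy (no strict-aliasing UB) and keep
// the top 16 bits: sign, 8 exponent bits, 7 mantissa bits == BF16.
inline std::uint16_t float_to_bf16_truncate(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return static_cast<std::uint16_t>(bits >> 16);
}

// Widening back is exact: place the 16 bits in the high half, zero the rest.
inline float bf16_to_float(std::uint16_t b)
{
    std::uint32_t bits = static_cast<std::uint32_t>(b) << 16;
    float x;
    std::memcpy(&x, &bits, sizeof(x));
    return x;
}
```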
Introduces a precision-hint API modelled after torch.set_float32_matmul_precision
/ jax_default_matmul_precision so users can opt in to reduced-precision
arithmetic where the library determines it is safe.

  enum daal::Float32MatmulPrecision { highest (default), high }

  daal_get_float32_matmul_precision()   -- query effective precision
  daal_set_float32_matmul_precision(p)  -- set at runtime

  Env var:  ONEDAL_FLOAT32_MATMUL_PRECISION=HIGH

* Hardware availability (AMX-BF16) is folded into the effective precision
  at initialisation time in env_detect_precision.cpp.  Call-sites never
  query hardware themselves; they only check daal_get_float32_matmul_precision().

* BF16GemmDispatcher<FPType, cpu> encapsulates all dispatch logic:
  - primary template: standard sgemm/dgemm, zero overhead for non-float types
  - float specialization: AMX-BF16 path when precision=='high' and all dims >=64
  - matrix-size threshold (kBF16MinDim=64) owned by the dispatcher, not callers

* g_use_bf16_gemm is a TU-level static bool evaluated once at startup,
  so the hot path is a single branch on a constant-like value.

* computeABt() in EuclideanDistances is now a one-liner delegating to
  BF16GemmDispatcher; no template ambiguity, no DAAL_REF ifdefs in caller.

* Default is Float32MatmulPrecision::highest (no precision loss without
  explicit opt-in); consistent with PyTorch/JAX defaults.

Signed-off-by: Nikolay Petrov <[email protected]>
Previously compute_effective_precision() was called in both the setter
and the getter, folding HW availability into the stored value.

New design:
  g_hw_amx_bf16   -- bool, computed ONCE at process init via
                     daal_serv_cpu_feature_detect(); never recomputed.
  daal_has_amx_bf16()  -- returns g_hw_amx_bf16 (constant for process lifetime)
  g_float32_matmul_precision -- stores the user's requested level as-is;
                                setter writes directly, no HW re-check.
  daal_get_float32_matmul_precision() -- plain atomic load, zero HW queries.

BF16GemmDispatcher gates on both:
  g_use_bf16_gemm = daal_has_amx_bf16() && (get_precision() == high)
This is evaluated once at TU init; the hot path sees a single bool.

Setter behaviour: daal_set_float32_matmul_precision(high) on a machine
without AMX-BF16 stores 'high' but g_use_bf16_gemm remains false, so
BF16GemmDispatcher silently falls through to sgemm. No HW check in setter.

Signed-off-by: Nikolay Petrov <[email protected]>
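
The setter/getter design described in this commit can be sketched as follows. Names are simplified stand-ins for the oneDAL symbols; the key properties are that the setter stores the hint as-is, the getter is a plain atomic load, and the effective gate combines the hint with the once-detected HW flag:

```cpp
#include <atomic>
#include <cassert>

enum class Precision { strict, allow_bf16 };

// Computed ONCE at process init in the real design; hard-coded false here to
// model a machine without AMX-BF16.
const bool g_hw_amx_bf16 = false;

// Stores the user's requested level as-is; no HW re-check in setter/getter.
std::atomic<Precision> g_precision{ Precision::strict };

void set_precision(Precision p) { g_precision.store(p, std::memory_order_relaxed); }
Precision get_precision() { return g_precision.load(std::memory_order_relaxed); }

// What the dispatcher's TU-init bool would compute: hot path sees one bool.
bool effective_bf16() { return g_hw_amx_bf16 && get_precision() == Precision::allow_bf16; }
```

Requesting `allow_bf16` on non-AMX hardware stores the hint but leaves the gate false, so dispatch silently falls through to sgemm, matching the setter behaviour described above.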
Add comprehensive doc comment for the Float32MatmulPrecision enum:
- Explains 'highest' vs 'high' semantics with concrete examples
- Documents current eligible operations for 'high' (Euclidean GEMM, AMX-BF16)
- Notes fallback behaviour on non-AMX hardware
- Documents env var + programmatic API
- Reserves 'medium' for future 2xBF16 pass

Signed-off-by: Nikolay Petrov <[email protected]>
The API is not yet stable; keep it internal until the design is finalised
and reviewed. Changes:

- Remove Float32MatmulPrecision from public cpu_type.h
- Define daal::internal::Float32MatmulPrecision in service_defines.h
  (internal header, not installed) with a clear 'INTERNAL, experimental'
  notice
- Update env_detect_precision.cpp and service_kernel_math.h to use
  daal::internal::Float32MatmulPrecision
- daal_has_amx_bf16 / daal_get_float32_matmul_precision /
  daal_set_float32_matmul_precision remain DAAL_EXPORT so internal
  oneDAL code and sklearnex can reach them, but they are not part of
  the public API surface

Signed-off-by: Nikolay Petrov <[email protected]>
Correct the ordering: highest > high > medium by decreasing precision,
matching torch.set_float32_matmul_precision semantics exactly.
Add 'medium' as a reserved enum value (not yet implemented, falls back
to 'high') for forward compatibility.

Signed-off-by: Nikolay Petrov <[email protected]>
Replace PyTorch-style highest/high/medium naming with explicit names:
  strict     -- always float32, IEEE-754, bit-exact (was: highest)
  allow_bf16 -- permit BF16 kernels where oneDAL deems safe (was: high)

Rationale: 'high'/'highest' are ambiguous (performance vs precision).
'strict' and 'allow_bf16' express intent directly with no ambiguity.
Extensible: allow_fp16, allow_tf32 follow the same pattern naturally.

Env var updated: ONEDAL_FLOAT32_MATMUL_PRECISION=ALLOW_BF16

Signed-off-by: Nikolay Petrov <[email protected]>
…on.cpp

sed substitution previously broke daal::CpuFeature::amx_bf16 → daal::amx_bf16.
amx_bf16 is a member of enum daal::CpuFeature, not daal namespace directly.

Signed-off-by: Nikolay Petrov <[email protected]>
@napetrov napetrov force-pushed the feature/amx-bf16-knn-euclidean branch from bef38d1 to 0a6c6c2 on March 28, 2026 at 16:20
napetrov added 6 commits April 1, 2026 09:38
The enum was defined in both cpu_type.h (with highest/high values) and
service_defines.h (with strict/allow_bf16 values). When both headers were
included in the same translation unit the compiler emitted:

  error: redefinition of 'Float32MatmulPrecision'

Float32MatmulPrecision is an internal API and belongs exclusively in
service_defines.h (internal header, not installed). Remove the definition
and its associated doc comment from cpu_type.h entirely.

Signed-off-by: Nikolay Petrov <[email protected]>
- env_detect_precision.cpp: fix include path from 'services/cpu_type.h'
  to 'src/services/cpu_type.h' (internal header, not public include path)
- Apply clang-format to service_kernel_math.h and env_detect_precision.cpp
  to pass FormatterChecks CI

Signed-off-by: Nikolay Petrov <[email protected]>
…on.cpp

CpuFeature is defined in daal::internal namespace (cpu_type.h), not
in daal namespace directly. Use daal::internal::CpuFeature::amx_bf16.

Signed-off-by: Nikolay Petrov <[email protected]>
AMX-BF16 is x86-only; CpuFeature::amx_bf16 is only defined when
TARGET_X86_64 is set. Add #if defined(TARGET_X86_64) guard around
the hardware detection, falling back to false on ARM/RISCV64.

Signed-off-by: Nikolay Petrov <[email protected]>
…uard

CpuFeature::amx_bf16 is defined inside #if defined(TARGET_X86_64) in
cpu_type.h. Without that macro set (e.g. non-Intel CI runners, make
builds without explicit TARGET define), the enum member does not exist.

Replace the raw static initialiser with detect_hw_amx_bf16() wrapper
that guards the CpuFeature reference with #if defined(TARGET_X86_64)
and returns false on non-x86 targets.

Signed-off-by: Nikolay Petrov <[email protected]>
TARGET_X86_64 is a oneDAL build-system macro not set in all build
configurations (e.g. plain make without explicit TARGET define).
Using CpuFeature::amx_bf16 directly fails when TARGET_X86_64 is not
defined, since the enum member is gated behind that macro in cpu_type.h.

Two changes:
1. Guard the x86 detection path with the portable compiler-provided
   __x86_64__ / _M_X64 in addition to TARGET_X86_64.
2. Replace CpuFeature::amx_bf16 with its numeric value (1ULL<<5) to
   avoid the header-guard dependency entirely. The comment documents
   the correspondence.

Signed-off-by: Nikolay Petrov <[email protected]>
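
The guard described in this commit can be sketched as follows. `kAmxBf16Bit` mirrors the numeric value the commit assigns to `CpuFeature::amx_bf16`, and the `feature_mask` parameter is a hypothetical stand-in for the result of the feature-detection call:

```cpp
#include <cassert>
#include <cstdint>

// Numeric equivalent of CpuFeature::amx_bf16 per the commit message, used
// directly to avoid depending on a macro-gated enum member.
constexpr std::uint64_t kAmxBf16Bit = 1ULL << 5;

bool detect_hw_amx_bf16(std::uint64_t feature_mask)
{
    // Guard on both the build-system macro and the portable compiler macros,
    // since TARGET_X86_64 is not set in every build configuration.
#if defined(TARGET_X86_64) || defined(__x86_64__) || defined(_M_X64)
    return (feature_mask & kAmxBf16Bit) != 0;
#else
    (void)feature_mask;
    return false; // ARM / RISC-V builds: AMX never available
#endif
}
```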
@napetrov napetrov requested a review from a team April 2, 2026 02:33