Skip to content

perf: AMX BF16 dispatch in EuclideanDistances::computeABt#3564

Open
napetrov wants to merge 15 commits intouxlfoundation:mainfrom
napetrov:feature/amx-bf16-knn-euclidean
Open

perf: AMX BF16 dispatch in EuclideanDistances::computeABt#3564
napetrov wants to merge 15 commits intouxlfoundation:mainfrom
napetrov:feature/amx-bf16-knn-euclidean

Conversation

@napetrov
Copy link
Copy Markdown
Contributor

@napetrov napetrov commented Mar 17, 2026

AMX BF16 fast path for float32 Euclidean distance GEMM

Summary

Adds an opt-in AMX-BF16 fast path for the float32 Euclidean distance computation
used by KNN (brute-force) and DBSCAN. Internally, the A×B' matrix multiply in
EuclideanDistances::computeABt converts float32 inputs to BF16 and dispatches
to Intel AMX tile instructions. Output remains float32.

Activation (opt-in, default is unchanged strict / full float32):

ONEDAL_FLOAT32_MATMUL_PRECISION=ALLOW_BF16

Or programmatically:

daal_set_float32_matmul_precision(daal::internal::Float32MatmulPrecision::allow_bf16);

Design

Float32MatmulPrecision API (internal, daal::internal namespace, not installed in public headers):

Level Behaviour
strict (default) Always float32. IEEE-754, bit-exact.
allow_bf16 Permit BF16 kernels where oneDAL determines accuracy impact is bounded.

Hardware check — once at process init (g_hw_amx_bf16). No HW queries in getter/setter/hot path.

BF16GemmDispatcher<float, cpu> — partial specialization separates the dispatch struct from EuclideanDistances. The TU-init g_use_bf16_gemm bool combines HW capability and user hint; the hot path branches on a single static bool (branch predictor-friendly).

Eligibility gate: both GEMM dimensions (rows of A and rows of B) must be ≥ 64. Below this threshold, BF16 tiling overhead outweighs the benefit and accuracy impact grows.

Benchmark results

Hardware: Intel Xeon 6975P-C (Granite Rapids), KVM, AMX-BF16 enabled.
Method: scikit-learn_bench (sklbench), n_jobs set via --parameters.

DBSCAN euclidean (feat ≥ 64):

config speedup range
8 vCPU, n_jobs 1–8 1.30–1.35x
16 vCPU, best case (f512, n_jobs=8) 1.60x
48 vCPU, n_jobs 1–8 1.21–1.25x

KNN classification inference (feat ≥ 128):

config speedup range
8 vCPU 1.31–1.44x
16 vCPU 1.11–1.22x
48 vCPU 1.08–1.17x

No regressions observed: KNN training, KNN regressor (feat=10), DBSCAN (feat=3) all show ~1.00x.

Files changed

File Change
cpp/daal/src/services/service_defines.h Float32MatmulPrecision enum + API declarations (internal)
cpp/daal/src/services/compiler/generic/env_detect_precision.cpp HW check, env parse, getter/setter impl
cpp/daal/src/algorithms/service_kernel_math.h BF16GemmDispatcher, refactored computeABt
cpp/daal/src/services/cpu_type.h CpuFeature::amx_bf16 flag

@napetrov napetrov force-pushed the feature/amx-bf16-knn-euclidean branch 2 times, most recently from a4d8723 to 5cb316e Compare March 17, 2026 04:24
@napetrov
Copy link
Copy Markdown
Contributor Author

/intelci: run

// compute (A x B')
#ifndef DAAL_REF
// AMX-BF16 capability check (MKL builds only; cross-platform)
static bool knn_has_amx_bf16()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to extend the function we already have for CPU features detection, if needed:
https://github.com/uxlfoundation/oneDAL/blob/main/cpp/daal/src/services/compiler/generic/env_detect_features.cpp#L303

@Vika-F
Copy link
Copy Markdown
Contributor

Vika-F commented Mar 17, 2026

  1. This kind of PRs are rather meaningless without performance & accuracy data.
    Need to ask an agent to run sklearn_bench on this.
    1.a) Probably the fp32->bf16 conversion done on every call to distances will "eat" all the potential performance.
  2. The conversion code can be improved with _mm512_cvtneps_pbh, but this won't help much with the performance.
  3. Using of unions in C++ can be harmful due to a strict aliasing rule (https://gist.github.com/shafik/848ae25ee209f698763cffee272a58f8). Need a union-free version of the code.
  4. In DAAL we typically handle the case of specialized implementations with [partial] template specializations and not with the std::is_same. Template specializations allow to reduce the amount of branching and decrease the code complexity.

Add AMX BF16 fast path for kNN brute-force and DBSCAN distance
computation. When AMX BF16 hardware is available and all GEMM
dimensions >= 64, converts float32 operands to BF16 and uses
cblas_gemm_bf16bf16f32 instead of sgemm.

Benchmarks on Xeon 6975P-C (sklearnex, float32, cpu):
- KNeighborsClassifier: 2.7-3.4x speedup
- KNeighborsRegressor:  2.9-3.3x speedup
- DBSCAN:               3.5-4.0x speedup

Accuracy impact vs float32 baseline:
- Classifier: delta < 1.2% (within dataset noise)
- Regressor R2: delta < 0.001

Runtime detection via CPUID (AMX_BF16 bit) + XGETBV XCR0[18:17].
Falls back to sgemm if AMX unavailable or dims < 64.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
- Add daal::CpuFeature::amx_bf16 to the CpuFeature enum in cpu_type.h
- Add check_amx_bf16_features() to env_detect_features.cpp using the
  existing DAAL CPUID/XGETBV infrastructure (CPUID(7,0).EDX[22] +
  XCR0[17:18]); remove standalone knn_has_amx_bf16() from
  service_kernel_math.h
- Replace std::is_same dispatch with template specialization: add
  computeABtImpl primary template (sgemm/dgemm fallback) and a float
  specialization (AMX-BF16 path, guarded by DAAL_REF)
- Replace union-based float→BF16 conversion with memcpy to avoid
  strict-aliasing UB (C++17 3.10/[basic.lval])

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
Introduces a precision-hint API modelled after torch.set_float32_matmul_precision
/ jax_default_matmul_precision so users can opt in to reduced-precision
arithmetic where the library determines it is safe.

  enum daal::Float32MatmulPrecision { highest (default), high }

  daal_get_float32_matmul_precision()   -- query effective precision
  daal_set_float32_matmul_precision(p)  -- set at runtime

  Env var:  ONEDAL_FLOAT32_MATMUL_PRECISION=HIGH

* Hardware availability (AMX-BF16) is folded into the effective precision
  at initialisation time in env_detect_precision.cpp.  Call-sites never
  query hardware themselves; they only check daal_get_float32_matmul_precision().

* BF16GemmDispatcher<FPType, cpu> encapsulates all dispatch logic:
  - primary template: standard sgemm/dgemm, zero overhead for non-float types
  - float specialization: AMX-BF16 path when precision=='high' and all dims >=64
  - matrix-size threshold (kBF16MinDim=64) owned by the dispatcher, not callers

* g_use_bf16_gemm is a TU-level static bool evaluated once at startup,
  so the hot path is a single branch on a constant-like value.

* computeABt() in EuclideanDistances is now a one-liner delegating to
  BF16GemmDispatcher; no template ambiguity, no DAAL_REF ifdefs in caller.

* Default is Float32MatmulPrecision::highest (no precision loss without
  explicit opt-in); consistent with PyTorch/JAX defaults.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
Previously compute_effective_precision() was called in both the setter
and the getter, folding HW availability into the stored value.

New design:
  g_hw_amx_bf16   -- bool, computed ONCE at process init via
                     daal_serv_cpu_feature_detect(); never recomputed.
  daal_has_amx_bf16()  -- returns g_hw_amx_bf16 (constant for process lifetime)
  g_float32_matmul_precision -- stores the user's requested level as-is;
                                setter writes directly, no HW re-check.
  daal_get_float32_matmul_precision() -- plain atomic load, zero HW queries.

BF16GemmDispatcher gates on both:
  g_use_bf16_gemm = daal_has_amx_bf16() && (get_precision() == high)
This is evaluated once at TU init; the hot path sees a single bool.

Setter behaviour: daal_set_float32_matmul_precision(high) on a machine
without AMX-BF16 stores 'high' but g_use_bf16_gemm remains false, so
BF16GemmDispatcher silently falls through to sgemm. No HW check in setter.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
Add comprehensive doc comment for the Float32MatmulPrecision enum:
- Explains 'highest' vs 'high' semantics with concrete examples
- Documents current eligible operations for 'high' (Euclidean GEMM, AMX-BF16)
- Notes fallback behaviour on non-AMX hardware
- Documents env var + programmatic API
- Reserves 'medium' for future 2xBF16 pass

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
The API is not yet stable; keep it internal until the design is finalised
and reviewed. Changes:

- Remove Float32MatmulPrecision from public cpu_type.h
- Define daal::internal::Float32MatmulPrecision in service_defines.h
  (internal header, not installed) with a clear 'INTERNAL, experimental'
  notice
- Update env_detect_precision.cpp and service_kernel_math.h to use
  daal::internal::Float32MatmulPrecision
- daal_has_amx_bf16 / daal_get_float32_matmul_precision /
  daal_set_float32_matmul_precision remain DAAL_EXPORT so internal
  oneDAL code and sklearnex can reach them, but they are not part of
  the public API surface

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
Correct the ordering: highest > high > medium by decreasing precision,
matching torch.set_float32_matmul_precision semantics exactly.
Add 'medium' as a reserved enum value (not yet implemented, falls back
to 'high') for forward compatibility.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
Replace PyTorch-style highest/high/medium naming with explicit names:
  strict     -- always float32, IEEE-754, bit-exact (was: highest)
  allow_bf16 -- permit BF16 kernels where oneDAL deems safe (was: high)

Rationale: 'high'/'highest' are ambiguous (performance vs precision).
'strict' and 'allow_bf16' express intent directly with no ambiguity.
Extensible: allow_fp16, allow_tf32 follow the same pattern naturally.

Env var updated: ONEDAL_FLOAT32_MATMUL_PRECISION=ALLOW_BF16

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
…on.cpp

sed substitution previously broke daal::CpuFeature::amx_bf16 → daal::amx_bf16.
amx_bf16 is a member of enum daal::CpuFeature, not daal namespace directly.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
@napetrov napetrov force-pushed the feature/amx-bf16-knn-euclidean branch from bef38d1 to 0a6c6c2 Compare March 28, 2026 16:20
napetrov added 6 commits April 1, 2026 09:38
The enum was defined in both cpu_type.h (with highest/high values) and
service_defines.h (with strict/allow_bf16 values). When both headers were
included in the same translation unit the compiler emitted:

  error: redefinition of 'Float32MatmulPrecision'

Float32MatmulPrecision is an internal API and belongs exclusively in
service_defines.h (internal header, not installed). Remove the definition
and its associated doc comment from cpu_type.h entirely.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
- env_detect_precision.cpp: fix include path from 'services/cpu_type.h'
  to 'src/services/cpu_type.h' (internal header, not public include path)
- Apply clang-format to service_kernel_math.h and env_detect_precision.cpp
  to pass FormatterChecks CI

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
…on.cpp

CpuFeature is defined in daal::internal namespace (cpu_type.h), not
in daal namespace directly. Use daal::internal::CpuFeature::amx_bf16.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
AMX-BF16 is x86-only; CpuFeature::amx_bf16 is only defined when
TARGET_X86_64 is set. Add #if defined(TARGET_X86_64) guard around
the hardware detection, falling back to false on ARM/RISCV64.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
…uard

CpuFeature::amx_bf16 is defined inside #if defined(TARGET_X86_64) in
cpu_type.h. Without that macro set (e.g. non-Intel CI runners, make
builds without explicit TARGET define), the enum member does not exist.

Replace the raw static initialiser with detect_hw_amx_bf16() wrapper
that guards the CpuFeature reference with #if defined(TARGET_X86_64)
and returns false on non-x86 targets.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
TARGET_X86_64 is a oneDAL build-system macro not set in all build
configurations (e.g. plain make without explicit TARGET define).
Using CpuFeature::amx_bf16 directly fails when TARGET_X86_64 is not
defined, since the enum member is gated behind that macro in cpu_type.h.

Two changes:
1. Guard the x86 detection path with the portable compiler-provided
   __x86_64__ / _M_X64 in addition to TARGET_X86_64.
2. Replace CpuFeature::amx_bf16 with its numeric value (1ULL<<5) to
   avoid the header-guard dependency entirely. The comment documents
   the correspondence.

Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
@napetrov napetrov requested a review from a team April 2, 2026 02:33
@napetrov
Copy link
Copy Markdown
Contributor Author

napetrov commented Apr 8, 2026

/intelci: run

@napetrov napetrov requested a review from Vika-F April 8, 2026 20:39
@napetrov napetrov marked this pull request as ready for review April 8, 2026 21:26
Copilot AI review requested due to automatic review settings April 8, 2026 21:26
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds an internal, opt-in float32 matmul precision hint and uses it to dispatch Euclidean distance GEMM to an AMX-BF16 fast path (BF16 operands, float32 output) when eligible.

Changes:

  • Introduces daal::internal::Float32MatmulPrecision and exported getter/setter + AMX-BF16 capability query.
  • Extends CPU feature detection with an amx_bf16 capability bit (CPUID + XCR0 OS-tile-state checks).
  • Refactors EuclideanDistances::computeABt() to route GEMM via a new BF16GemmDispatcher that can convert FP32→BF16 and call BF16 GEMM.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
cpp/daal/src/services/service_defines.h Adds internal precision-hint enum and exported getter/setter + capability API.
cpp/daal/src/services/cpu_type.h Adds CpuFeature::amx_bf16 flag bit.
cpp/daal/src/services/compiler/generic/env_detect_precision.cpp Implements env parsing + stored precision hint + AMX-BF16 capability query.
cpp/daal/src/services/compiler/generic/env_detect_features.cpp Detects AMX-BF16 feature (CPUID + XCR0 tile-state enabled).
cpp/daal/src/algorithms/service_kernel_math.h Adds BF16 GEMM dispatcher and rewires Euclidean distance GEMM through it.

Comment on lines +51 to +55
* Allow oneDAL to use BF16 kernels in operations where the library has
* determined the accuracy impact is bounded and acceptable for the algorithm.
* Currently: float32 Euclidean GEMM on AMX-BF16 hardware, dims >= 64.
* Output dtype remains float32; only internal GEMM accumulation uses BF16.
* Falls back to sgemm silently on hardware without AMX-BF16.
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment claims that only GEMM accumulation uses BF16, but the implementation uses BF16 inputs with float32 accumulation/output (e.g., cblas_gemm_bf16bf16f32 accumulates in FP32). Please update the documentation to reflect that BF16 is used for operand storage/multiply while accumulation/output remain FP32.

Copilot uses AI. Check for mistakes.
#include "src/externals/service_math.h"
#include "src/services/service_profiler.h"

#include <mutex>
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#include <mutex> appears to be unused in this header (no mutex symbols referenced). Please remove it to avoid unnecessary compile-time overhead in a widely included header.

Suggested change
#include <mutex>

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree on this.

Comment on lines +133 to +136
// Design principles:
// - Hardware availability is already folded into daal_get_float32_matmul_precision()
// at initialisation time (see env_detect_precision.cpp). Call-sites only
// query the effective precision; no hardware checks here.
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dispatcher documentation says hardware availability is “folded into daal_get_float32_matmul_precision()”, but the getter returns only the user hint; hardware capability is checked separately via daal_has_amx_bf16(). Please correct the comment to match the actual gating logic to avoid misleading future changes.

Copilot uses AI. Check for mistakes.
Comment on lines +149 to +154
/// Whether AMX-BF16 GEMM is usable in this process.
/// Combines hardware capability (checked once at init) with the user's precision hint.
/// This is the only place both conditions are AND-ed; call-sites check nothing else.
static const bool g_use_bf16_gemm =
daal_has_amx_bf16() && (daal_get_float32_matmul_precision() == daal::internal::Float32MatmulPrecision::allow_bf16);

Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

g_use_bf16_gemm is a TU-level static const bool computed during static initialization, so calling daal_set_float32_matmul_precision(...) at runtime will not affect dispatch in this TU (programmatic opt-in/out won’t work reliably). Consider making the cached flag an atomic/volatile that is updated by the setter, or check the atomic precision hint at call time (still allowing a fast path when it’s strict).

Copilot uses AI. Check for mistakes.
Comment on lines +197 to +198
// col-major: C = B16 * A16^T => cblas NoTrans B, Trans A
cblas_gemm_bf16bf16f32(CblasColMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, b16, m, a16, n, 0.0f, out, m);
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The BF16 GEMM call uses CblasColMajor with b16 treated as an m x k column-major matrix (lda=m) and a16 as n x k (lda=n), but the source data blocks are row-major and the existing float32 path relies on interpreting them as transposed column-major (lda=k with trans='t'). As written, the BF16 path is likely multiplying incorrectly and will produce wrong results. Please align the transposition/leading-dimension logic with the existing xxgemm call so the BF16 and float32 paths are mathematically equivalent.

Suggested change
// col-major: C = B16 * A16^T => cblas NoTrans B, Trans A
cblas_gemm_bf16bf16f32(CblasColMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, b16, m, a16, n, 0.0f, out, m);
// Match the float32 path:
// xxgemm('t', 'n', ..., b, lda=k, a, lda=k, ..., out, ldc=m)
// The source blocks are row-major, so interpret b16 as a k x m
// column-major matrix and transpose it, and interpret a16 as a
// k x n column-major matrix without transposition.
cblas_gemm_bf16bf16f32(CblasColMajor, CblasTrans, CblasNoTrans, m, n, k, 1.0f, b16, k, a16, k, 0.0f, out, m);

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is wrong because xxgemm uses Fortran memory alignment which is column-major, and cblas_gemm_bf16bf16f32 uses C memory alignment which is row-major.

Comment on lines +184 to +196
// BF16 = upper 16 bits of IEEE-754 float32; use memcpy to read bits safely.
for (size_t i = 0; i < szA; ++i)
{
uint32_t bits;
memcpy(&bits, &a[i], sizeof(bits));
a16[i] = static_cast<MKL_BF16>(bits >> 16);
}
for (size_t i = 0; i < szB; ++i)
{
uint32_t bits;
memcpy(&bits, &b[i], sizeof(bits));
b16[i] = static_cast<MKL_BF16>(bits >> 16);
}
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The BF16 conversion truncates float32 to the upper 16 bits, which is not the standard round-to-nearest-even BF16 conversion and can introduce extra bias/error vs real BF16 hardware behavior. Consider using an MKL-provided FP32→BF16 conversion routine (or implement correct rounding) so the AMX BF16 path matches expected BF16 semantics.

Copilot uses AI. Check for mistakes.
Comment on lines +176 to +205
if (g_use_bf16_gemm && m >= kBF16MinDim && n >= kBF16MinDim && k >= kBF16MinDim)
{
const size_t szA = static_cast<size_t>(n) * static_cast<size_t>(k);
const size_t szB = static_cast<size_t>(m) * static_cast<size_t>(k);
MKL_BF16 * a16 = static_cast<MKL_BF16 *>(mkl_malloc(szA * sizeof(MKL_BF16), 64));
MKL_BF16 * b16 = static_cast<MKL_BF16 *>(mkl_malloc(szB * sizeof(MKL_BF16), 64));
if (a16 && b16)
{
// BF16 = upper 16 bits of IEEE-754 float32; use memcpy to read bits safely.
for (size_t i = 0; i < szA; ++i)
{
uint32_t bits;
memcpy(&bits, &a[i], sizeof(bits));
a16[i] = static_cast<MKL_BF16>(bits >> 16);
}
for (size_t i = 0; i < szB; ++i)
{
uint32_t bits;
memcpy(&bits, &b[i], sizeof(bits));
b16[i] = static_cast<MKL_BF16>(bits >> 16);
}
// col-major: C = B16 * A16^T => cblas NoTrans B, Trans A
cblas_gemm_bf16bf16f32(CblasColMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, b16, m, a16, n, 0.0f, out, m);
mkl_free(a16);
mkl_free(b16);
return;
}
if (a16) mkl_free(a16);
if (b16) mkl_free(b16);
}
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This path allocates and frees BF16 buffers (mkl_malloc/mkl_free) on every GEMM call and reconverts both A and B each time. In computeFull(), B is reused across many A blocks, so repeatedly converting B can dominate runtime and reduce the benefit of AMX. Consider caching/reusing converted buffers (e.g., per-thread scratch or converting B once per outer call) to avoid repeated allocation and conversion overhead.

Copilot uses AI. Check for mistakes.
Comment on lines +103 to +104
/// If 'high' is requested but hardware lacks AMX-BF16, BF16GemmDispatcher
/// will fall through to sgemm transparently via its own g_use_bf16_gemm gate.
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The setter comment refers to requesting “high”, but the only non-default level in this API is allow_bf16. Please update the comment to match the actual enum values/behavior (and note that dispatch should remain correct when the precision hint is changed at runtime).

Suggested change
/// If 'high' is requested but hardware lacks AMX-BF16, BF16GemmDispatcher
/// will fall through to sgemm transparently via its own g_use_bf16_gemm gate.
/// If allow_bf16 is requested but hardware lacks AMX-BF16, BF16GemmDispatcher
/// will fall through to sgemm transparently via its own g_use_bf16_gemm gate.
/// Changing this hint at runtime remains correct because dispatch combines the
/// stored level with the cached hardware capability check on each use.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor

@david-cortes-intel david-cortes-intel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would require benchmarks to assess whether it leads to improved runtimes. Less precision could potentially result in requiring more iterations for convergence.

{
const size_t szA = static_cast<size_t>(n) * static_cast<size_t>(k);
const size_t szB = static_cast<size_t>(m) * static_cast<size_t>(k);
MKL_BF16 * a16 = static_cast<MKL_BF16 *>(mkl_malloc(szA * sizeof(MKL_BF16), 64));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have smart pointer classes to use. Would prevent potential errors from this sort of memory management.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's better to use TArray or other classes defined here:
https://github.com/uxlfoundation/oneDAL/blob/main/cpp/daal/src/services/service_arrays.h#L109

{
uint32_t bits;
memcpy(&bits, &a[i], sizeof(bits));
a16[i] = static_cast<MKL_BF16>(bits >> 16);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe MKL itself should have utilities for this sort of thing.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like no. Intrinsics like _mm512_cvtneps_pbh might be faster, but I am Ok to have it like this for now.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it have #pragma omp simd then?

static int check_amx_bf16_features()
{
/* CPUID.(EAX=07H, ECX=0H):EDX.AMX-BF16[bit 22]==1 */
if (!check_cpuid(7, 0, 3, (1 << 22)))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this segfault or otherwise break if executed on AMD? Is there perhaps some check that tells whether it's an intel CPU?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is Ok, as it is checked at line 318 below.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Vika-F Is it possible for the call to check_cpuid to segfault?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I see now that it does check for intel CPU.

static bool detect_hw_amx_bf16()
{
#if defined(TARGET_X86_64) || defined(__x86_64__) || defined(_M_X64)
// AMX-BF16 bit = (1ULL << 5) in daal::internal::CpuFeature (cpu_type.h).
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like it should be moved to that file instead.

CC @Vika-F for comments.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's better to move detect_hw_amx_bf16() and g_hw_amx_bf16 into env_detect_features.cpp.

I am also not sure that the functions in this file really need DAAL_EXPORT. Because they are only used internally in DAAL (i.e. libonedal_core.so).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding AMX-BF16 bit = (1ULL << 5).
I think it's better not to use magic constants here, and use:
static const DAAL_UINT64 amx_bf16_mask = static_cast<DAAL_UINT64>(daal::internal::amx_bf16);

for (size_t i = 0; i < szA; ++i)
{
uint32_t bits;
memcpy(&bits, &a[i], sizeof(bits));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Vika-F Is it a problem to have stdlib calls like this in a header?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

memcpy is deprecated due to security reasons. daal_memcpy_s should be used:
https://github.com/uxlfoundation/oneDAL/blob/main/cpp/daal/include/services/daal_memory.h#L71

#ifndef __SERVICE_KERNEL_MATH_H__
#define __SERVICE_KERNEL_MATH_H__

#include <cstring>
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is including a C++ header, but then calling the functions from it without std::.

{}

~EuclideanDistances() override {}
virtual ~EuclideanDistances() override {}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove redundant "virtual" here and in other similar places.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants