-
Notifications
You must be signed in to change notification settings - Fork 224
Expand file tree
/
Copy pathenv_detect_precision.cpp
More file actions
108 lines (97 loc) · 4.4 KB
/
env_detect_precision.cpp
File metadata and controls
108 lines (97 loc) · 4.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/* file: env_detect_precision.cpp */
/*******************************************************************************
* Copyright contributors to the oneDAL project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/*
//++
// Float32 matmul precision control.
//
// Inspired by torch.set_float32_matmul_precision / jax_default_matmul_precision.
//
// Precision levels:
// strict (default) -- always compute in float32. Full IEEE-754, bit-exact.
// allow_bf16 -- allow BF16 kernels where oneDAL determines the accuracy
// impact is acceptable. The library, not the user, decides
// which operations are eligible. Currently: float32
// Euclidean GEMM, dims >= 64, hardware with AMX-BF16.
//
// Hardware availability is checked ONCE at process initialisation (g_hw_amx_bf16).
// The setter stores the user's requested level as-is; BF16GemmDispatcher
// gates on both g_hw_amx_bf16 and the stored precision level.
// This way the getter is a plain atomic load with no HW queries.
//--
*/
#include "src/services/cpu_type.h"
#include "src/services/service_defines.h"
#include <atomic>
#include <cstdlib> /* std::getenv */
#include <cstring> /* std::strcmp */
namespace
{
/// Detect AMX-BF16 support on the current CPU.
///
/// daal_serv_cpu_feature_detect() returns a DAAL_UINT64 feature bitmask; this
/// helper tests the bit that daal::internal::CpuFeature assigns to amx_bf16.
/// Intended meaning (per the original design): true iff AMX-BF16 is present
/// and OS-enabled. Whether OS enablement is checked is up to
/// daal_serv_cpu_feature_detect().
/// AMX-BF16 is x86-only, so non-x86-64 builds (ARM, RISC-V 64) always get false.
bool detect_hw_amx_bf16()
{
#if defined(TARGET_X86_64) || defined(__x86_64__) || defined(_M_X64)
    // amx_bf16 == (1ULL << 5) in daal::internal::CpuFeature (cpu_type.h).
    // The numeric value is spelled out so this code does not depend on the
    // TARGET_X86_64 guard in cpu_type.h, which the build system may not set.
    constexpr DAAL_UINT64 amx_bf16_mask = (1ULL << 5);
    return (daal_serv_cpu_feature_detect() & amx_bf16_mask) != 0;
#else
    return false;
#endif
}

/// Hardware capability flag, computed once during dynamic initialisation of
/// this translation unit and constant afterwards.
/// NOTE(review): this is a dynamically initialised global. Code running in
/// another translation unit's static initialiser may observe it before it is
/// set (zero-initialised, i.e. false).
const bool g_hw_amx_bf16 = detect_hw_amx_bf16();

/// Parse the ONEDAL_FLOAT32_MATMUL_PRECISION environment variable.
///
/// Only the exact spellings "ALLOW_BF16" and "allow_bf16" select
/// Float32MatmulPrecision::allow_bf16. An unset variable or any other value
/// selects Float32MatmulPrecision::strict.
daal::internal::Float32MatmulPrecision parse_env_precision()
{
    const char * val = std::getenv("ONEDAL_FLOAT32_MATMUL_PRECISION");
    if (val != nullptr && (std::strcmp(val, "ALLOW_BF16") == 0 || std::strcmp(val, "allow_bf16") == 0))
    {
        return daal::internal::Float32MatmulPrecision::allow_bf16;
    }
    return daal::internal::Float32MatmulPrecision::strict;
}

/// Stored precision level (what the user requested), held as the integer value
/// of daal::internal::Float32MatmulPrecision.
/// Initialised from the environment at startup; can be overridden at runtime.
/// BF16GemmDispatcher combines this with g_hw_amx_bf16 to make the final
/// dispatch decision.
/// NOTE(review): like g_hw_amx_bf16, this is dynamically initialised. A store
/// made from another translation unit's static initialiser before this one
/// runs would be overwritten by the environment value.
std::atomic<int> g_float32_matmul_precision { static_cast<int>(parse_env_precision()) };
} // anonymous namespace
/// Report whether AMX-BF16 was detected on this machine.
/// The answer is a flag computed once when this translation unit is
/// initialised, so every call returns the same cached value without
/// re-querying the CPU.
DAAL_EXPORT bool daal_has_amx_bf16()
{
    const bool amx_bf16_available = g_hw_amx_bf16;
    return amx_bf16_available;
}
/// Fetch the float32 matmul precision currently requested by the user.
/// The initial value comes from ONEDAL_FLOAT32_MATMUL_PRECISION; later
/// daal_set_float32_matmul_precision() calls replace it.
/// This is a single relaxed atomic load; the hardware is not queried.
DAAL_EXPORT daal::internal::Float32MatmulPrecision daal_get_float32_matmul_precision()
{
    const int stored = std::atomic_load_explicit(&g_float32_matmul_precision, std::memory_order_relaxed);
    return static_cast<daal::internal::Float32MatmulPrecision>(stored);
}
/// Record the caller's float32 matmul precision preference.
/// The level is stored exactly as given and the hardware flag is NOT
/// re-checked here. If allow_bf16 is requested on a machine without AMX-BF16,
/// BF16GemmDispatcher is expected to fall back to the regular float32 GEMM
/// (sgemm) through its own hardware gate.
/// The store is relaxed: it publishes only this single value.
DAAL_EXPORT void daal_set_float32_matmul_precision(daal::internal::Float32MatmulPrecision p)
{
    const int requested = static_cast<int>(p);
    std::atomic_store_explicit(&g_float32_matmul_precision, requested, std::memory_order_relaxed);
}