Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 154 additions & 118 deletions hwy/contrib/math/fast_math-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,8 @@ HWY_INLINE V FastAtan2(const D d, V y, V x) {
* Fast approximation of tanh(x).
*
* Valid Lane Types: float32, float64
* Max Relative Error : 0.02% for float32, 0.02% for float64
* Average Relative Error : 0.0001% for float32, 0.000001% for float64
* Max Relative Error : 0.006% for float32, 0.006% for float64
* Average Relative Error : 0.00003% for float32, 0.000004% for float64
* Max Relative Error for [-0.01, 0.01] : 0.003%
* Average Relative Error for [-0.01, 0.01] : 0.000002%
* Valid Range: float32: [-1e35, +1e35]
Expand All @@ -526,7 +526,7 @@ HWY_INLINE V FastTanh(D d, V val) {
auto y = Abs(val);

constexpr size_t kLanes = HWY_MAX_LANES_D(D);
V a, b, c, d_val;
V a, b, c, d_val, e;

if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) ||
(HWY_HAVE_SCALABLE && sizeof(T) == 4 && detail::IsFull(d))) {
Expand All @@ -552,41 +552,50 @@ HWY_INLINE V FastTanh(D d, V val) {
idx_i = Min(idx_i, Set(DI(), 7));

HWY_ALIGN static constexpr T arr_a[8] = {
static_cast<T>(-0.32124064137467889),
static_cast<T>(-0.25063809221086503),
static_cast<T>(-0.12743143099276930),
static_cast<T>(0.00879257493380024),
static_cast<T>(0.09774602019349406),
static_cast<T>(0.09746926160817335),
static_cast<T>(0.03461152231207073),
static_cast<T>(0.00190152088495461)};
static_cast<T>(0.055972201057460494),
static_cast<T>(0.141142221987527),
static_cast<T>(0.168455721126857),
static_cast<T>(0.122316741409427),
static_cast<T>(0.0363144967431118),
static_cast<T>(-0.0207194964499277),
static_cast<T>(-0.0148971040423995),
static_cast<T>(-0.00075065644748894059)};
HWY_ALIGN static constexpr T arr_b[8] = {
static_cast<T>(-0.00191824037528361),
static_cast<T>(-0.04124816646249752),
static_cast<T>(-0.17298439557734449),
static_cast<T>(-0.40057378897277868),
static_cast<T>(-0.61249080708786208),
static_cast<T>(-0.60213708038791991),
static_cast<T>(-0.27823168655690367),
static_cast<T>(-0.02476274996694735)};
static_cast<T>(-0.34169330090059025),
static_cast<T>(-0.397428676355329),
static_cast<T>(-0.431042547951791),
static_cast<T>(-0.323126397208966),
static_cast<T>(-0.0427346425890441),
static_cast<T>(0.212762301109212),
static_cast<T>(0.161869495723915),
static_cast<T>(0.013314738818911605)};
HWY_ALIGN static constexpr T arr_c[8] = {
static_cast<T>(1.00009167744117367),
static_cast<T>(1.00761071731437957),
static_cast<T>(1.05510014553372966),
static_cast<T>(1.18269984996630884),
static_cast<T>(1.35178582776199496),
static_cast<T>(1.32900602936237289),
static_cast<T>(0.76904955465038061),
static_cast<T>(0.10755808660951535)};
static_cast<T>(0.00051593307922874086),
static_cast<T>(0.0149076074949383),
static_cast<T>(0.0304366282196960),
static_cast<T>(-0.0647369325706488),
static_cast<T>(-0.409688801213951),
static_cast<T>(-0.841374109580751),
static_cast<T>(-0.682231542238376),
static_cast<T>(-0.088759895679724005)};
HWY_ALIGN static constexpr T arr_d[8] = {
static_cast<T>(-0.00000105381180317),
static_cast<T>(-0.00049073110164177),
static_cast<T>(-0.00625688495915542),
static_cast<T>(-0.03025960590948565),
static_cast<T>(-0.07544878909500170),
static_cast<T>(-0.06251270311746010),
static_cast<T>(0.26202315539906595),
static_cast<T>(0.84371089018138146)};
static_cast<T>(0.99998859362000814),
static_cast<T>(0.998256258266747),
static_cast<T>(0.995064857371438),
static_cast<T>(1.03255144697396),
static_cast<T>(1.22230367094384),
static_cast<T>(1.54838530714448),
static_cast<T>(1.333832467345998),
static_cast<T>(0.26430484966813278)};
HWY_ALIGN static constexpr T arr_e[8] = {
static_cast<T>(4.614802448903236e-08),
static_cast<T>(8.13616326949790e-05),
static_cast<T>(0.000327414093187497),
static_cast<T>(-0.00523293563447093),
static_cast<T>(-0.0446002224255692),
static_cast<T>(-0.137517925537649),
static_cast<T>(-0.0313056210132493),
static_cast<T>(0.70225333959194747)};

// Since Lookup8 is available for HWY_MIN_BYTES / sizeof(T) >= 4, this
// condition covers all cases we encounter inside the top level if block
Expand All @@ -595,6 +604,7 @@ HWY_INLINE V FastTanh(D d, V val) {
b = Lookup8(d, arr_b, idx_i);
c = Lookup8(d, arr_c, idx_i);
d_val = Lookup8(d, arr_d, idx_i);
e = Lookup8(d, arr_e, idx_i);
} else {
// --- FALLBACK PATH: Blend Chain ---
// Thresholds for intervals
Expand All @@ -610,162 +620,188 @@ HWY_INLINE V FastTanh(D d, V val) {
// Split into two parallel chains to reduce dependency latency.

// -- Chain 1: Indices 0 to 3 (Evaluated starting from t3 down to t0)
auto a_low = Set(d, static_cast<T>(0.00879257493380024)); // idx 3
auto b_low = Set(d, static_cast<T>(-0.40057378897277868));
auto c_low = Set(d, static_cast<T>(1.18269984996630884));
auto d_low = Set(d, static_cast<T>(-0.03025960590948565));
auto a_low = Set(d, static_cast<T>(0.122316741409427)); // idx 3
auto b_low = Set(d, static_cast<T>(-0.323126397208966));
auto c_low = Set(d, static_cast<T>(-0.0647369325706488));
auto d_low = Set(d, static_cast<T>(1.03255144697396));
auto e_low = Set(d, static_cast<T>(-0.00523293563447093));

auto mask = Lt(y, t2);
a_low =
IfThenElse(mask, Set(d, static_cast<T>(-0.12743143099276930)), a_low);
IfThenElse(mask, Set(d, static_cast<T>(0.168455721126857)), a_low);
b_low =
IfThenElse(mask, Set(d, static_cast<T>(-0.17298439557734449)), b_low);
IfThenElse(mask, Set(d, static_cast<T>(-0.431042547951791)), b_low);
c_low =
IfThenElse(mask, Set(d, static_cast<T>(1.05510014553372966)), c_low);
IfThenElse(mask, Set(d, static_cast<T>(0.0304366282196960)), c_low);
d_low =
IfThenElse(mask, Set(d, static_cast<T>(-0.00625688495915542)), d_low);
IfThenElse(mask, Set(d, static_cast<T>(0.995064857371438)), d_low);
e_low =
IfThenElse(mask, Set(d, static_cast<T>(0.000327414093187497)), e_low);

mask = Lt(y, t1);
a_low =
IfThenElse(mask, Set(d, static_cast<T>(-0.25063809221086503)), a_low);
IfThenElse(mask, Set(d, static_cast<T>(0.141142221987527)), a_low);
b_low =
IfThenElse(mask, Set(d, static_cast<T>(-0.04124816646249752)), b_low);
IfThenElse(mask, Set(d, static_cast<T>(-0.397428676355329)), b_low);
c_low =
IfThenElse(mask, Set(d, static_cast<T>(1.00761071731437957)), c_low);
IfThenElse(mask, Set(d, static_cast<T>(0.0149076074949383)), c_low);
d_low =
IfThenElse(mask, Set(d, static_cast<T>(-0.00049073110164177)), d_low);
IfThenElse(mask, Set(d, static_cast<T>(0.998256258266747)), d_low);
e_low =
IfThenElse(mask, Set(d, static_cast<T>(8.13616326949790e-05)), e_low);

mask = Lt(y, t0);
a_low =
IfThenElse(mask, Set(d, static_cast<T>(-0.32124064137467889)), a_low);
IfThenElse(mask, Set(d, static_cast<T>(0.055972201057460494)), a_low);
b_low =
IfThenElse(mask, Set(d, static_cast<T>(-0.00191824037528361)), b_low);
c_low =
IfThenElse(mask, Set(d, static_cast<T>(1.00009167744117367)), c_low);
IfThenElse(mask, Set(d, static_cast<T>(-0.34169330090059025)), b_low);
c_low = IfThenElse(mask, Set(d, static_cast<T>(0.00051593307922874086)),
c_low);
d_low =
IfThenElse(mask, Set(d, static_cast<T>(-0.00000105381180317)), d_low);
IfThenElse(mask, Set(d, static_cast<T>(0.99998859362000814)), d_low);
e_low = IfThenElse(mask, Set(d, static_cast<T>(4.614802448903236e-08)),
e_low);

// -- Chain 2: Indices 4 to 7 (Evaluated starting from t6 down to t4)
auto a_high = Set(d, static_cast<T>(0.00190152088495461)); // idx 7
auto b_high = Set(d, static_cast<T>(-0.02476274996694735));
auto c_high = Set(d, static_cast<T>(0.10755808660951535));
auto d_high = Set(d, static_cast<T>(0.84371089018138146));
auto a_high = Set(d, static_cast<T>(-0.00075065644748894059)); // idx 7
auto b_high = Set(d, static_cast<T>(0.013314738818911605));
auto c_high = Set(d, static_cast<T>(-0.088759895679724005));
auto d_high = Set(d, static_cast<T>(0.26430484966813278));
auto e_high = Set(d, static_cast<T>(0.70225333959194747));

mask = Lt(y, t6);
a_high =
IfThenElse(mask, Set(d, static_cast<T>(0.03461152231207073)), a_high);
b_high = IfThenElse(mask, Set(d, static_cast<T>(-0.27823168655690367)),
b_high);
IfThenElse(mask, Set(d, static_cast<T>(-0.0148971040423995)), a_high);
b_high =
IfThenElse(mask, Set(d, static_cast<T>(0.161869495723915)), b_high);
c_high =
IfThenElse(mask, Set(d, static_cast<T>(0.76904955465038061)), c_high);
IfThenElse(mask, Set(d, static_cast<T>(-0.682231542238376)), c_high);
d_high =
IfThenElse(mask, Set(d, static_cast<T>(0.26202315539906595)), d_high);
IfThenElse(mask, Set(d, static_cast<T>(1.333832467345998)), d_high);
e_high =
IfThenElse(mask, Set(d, static_cast<T>(-0.0313056210132493)), e_high);

mask = Lt(y, t5);
a_high =
IfThenElse(mask, Set(d, static_cast<T>(0.09746926160817335)), a_high);
b_high = IfThenElse(mask, Set(d, static_cast<T>(-0.60213708038791991)),
b_high);
IfThenElse(mask, Set(d, static_cast<T>(-0.0207194964499277)), a_high);
b_high =
IfThenElse(mask, Set(d, static_cast<T>(0.212762301109212)), b_high);
c_high =
IfThenElse(mask, Set(d, static_cast<T>(1.32900602936237289)), c_high);
d_high = IfThenElse(mask, Set(d, static_cast<T>(-0.06251270311746010)),
d_high);
IfThenElse(mask, Set(d, static_cast<T>(-0.841374109580751)), c_high);
d_high =
IfThenElse(mask, Set(d, static_cast<T>(1.54838530714448)), d_high);
e_high =
IfThenElse(mask, Set(d, static_cast<T>(-0.137517925537649)), e_high);

mask = Lt(y, t4);
a_high =
IfThenElse(mask, Set(d, static_cast<T>(0.09774602019349406)), a_high);
b_high = IfThenElse(mask, Set(d, static_cast<T>(-0.61249080708786208)),
b_high);
IfThenElse(mask, Set(d, static_cast<T>(0.0363144967431118)), a_high);
b_high =
IfThenElse(mask, Set(d, static_cast<T>(-0.0427346425890441)), b_high);
c_high =
IfThenElse(mask, Set(d, static_cast<T>(1.35178582776199496)), c_high);
d_high = IfThenElse(mask, Set(d, static_cast<T>(-0.07544878909500170)),
d_high);
IfThenElse(mask, Set(d, static_cast<T>(-0.409688801213951)), c_high);
d_high =
IfThenElse(mask, Set(d, static_cast<T>(1.22230367094384)), d_high);
e_high =
IfThenElse(mask, Set(d, static_cast<T>(-0.0446002224255692)), e_high);

// -- Merge the two chains
auto merge_mask = Lt(y, t3);
a = IfThenElse(merge_mask, a_low, a_high);
b = IfThenElse(merge_mask, b_low, b_high);
c = IfThenElse(merge_mask, c_low, c_high);
d_val = IfThenElse(merge_mask, d_low, d_high);
e = IfThenElse(merge_mask, e_low, e_high);
} else {
// Start with highest index (7)
a = Set(d, static_cast<T>(0.00190152088495461));
b = Set(d, static_cast<T>(-0.02476274996694735));
c = Set(d, static_cast<T>(0.10755808660951535));
d_val = Set(d, static_cast<T>(0.84371089018138146));
a = Set(d, static_cast<T>(-0.00075065644748894059));
b = Set(d, static_cast<T>(0.013314738818911605));
c = Set(d, static_cast<T>(-0.088759895679724005));
d_val = Set(d, static_cast<T>(0.26430484966813278));
e = Set(d, static_cast<T>(0.70225333959194747));

// If y < t6 (idx 6)
auto mask = Lt(y, t6);
a = IfThenElse(mask, Set(d, static_cast<T>(0.03461152231207073)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(-0.27823168655690367)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(0.76904955465038061)), c);
a = IfThenElse(mask, Set(d, static_cast<T>(-0.0148971040423995)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(0.161869495723915)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(-0.682231542238376)), c);
d_val =
IfThenElse(mask, Set(d, static_cast<T>(0.26202315539906595)), d_val);
IfThenElse(mask, Set(d, static_cast<T>(1.333832467345998)), d_val);
e = IfThenElse(mask, Set(d, static_cast<T>(-0.0313056210132493)), e);

// If y < t5 (idx 5)
mask = Lt(y, t5);
a = IfThenElse(mask, Set(d, static_cast<T>(0.09746926160817335)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(-0.60213708038791991)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(1.32900602936237289)), c);
d_val =
IfThenElse(mask, Set(d, static_cast<T>(-0.06251270311746010)), d_val);
a = IfThenElse(mask, Set(d, static_cast<T>(-0.0207194964499277)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(0.212762301109212)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(-0.841374109580751)), c);
d_val = IfThenElse(mask, Set(d, static_cast<T>(1.54838530714448)), d_val);
e = IfThenElse(mask, Set(d, static_cast<T>(-0.137517925537649)), e);

// If y < t4 (idx 4)
mask = Lt(y, t4);
a = IfThenElse(mask, Set(d, static_cast<T>(0.09774602019349406)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(-0.61249080708786208)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(1.35178582776199496)), c);
d_val =
IfThenElse(mask, Set(d, static_cast<T>(-0.07544878909500170)), d_val);
a = IfThenElse(mask, Set(d, static_cast<T>(0.0363144967431118)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(-0.0427346425890441)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(-0.409688801213951)), c);
d_val = IfThenElse(mask, Set(d, static_cast<T>(1.22230367094384)), d_val);
e = IfThenElse(mask, Set(d, static_cast<T>(-0.0446002224255692)), e);

// If y < t3 (idx 3)
mask = Lt(y, t3);
a = IfThenElse(mask, Set(d, static_cast<T>(0.00879257493380024)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(-0.40057378897277868)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(1.18269984996630884)), c);
d_val =
IfThenElse(mask, Set(d, static_cast<T>(-0.03025960590948565)), d_val);
a = IfThenElse(mask, Set(d, static_cast<T>(0.122316741409427)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(-0.323126397208966)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(-0.0647369325706488)), c);
d_val = IfThenElse(mask, Set(d, static_cast<T>(1.03255144697396)), d_val);
e = IfThenElse(mask, Set(d, static_cast<T>(-0.00523293563447093)), e);

// If y < t2 (idx 2)
mask = Lt(y, t2);
a = IfThenElse(mask, Set(d, static_cast<T>(-0.12743143099276930)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(-0.17298439557734449)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(1.05510014553372966)), c);
a = IfThenElse(mask, Set(d, static_cast<T>(0.168455721126857)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(-0.431042547951791)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(0.0304366282196960)), c);
d_val =
IfThenElse(mask, Set(d, static_cast<T>(-0.00625688495915542)), d_val);
IfThenElse(mask, Set(d, static_cast<T>(0.995064857371438)), d_val);
e = IfThenElse(mask, Set(d, static_cast<T>(0.000327414093187497)), e);

// If y < t1 (idx 1)
mask = Lt(y, t1);
a = IfThenElse(mask, Set(d, static_cast<T>(-0.25063809221086503)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(-0.04124816646249752)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(1.00761071731437957)), c);
a = IfThenElse(mask, Set(d, static_cast<T>(0.141142221987527)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(-0.397428676355329)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(0.0149076074949383)), c);
d_val =
IfThenElse(mask, Set(d, static_cast<T>(-0.00049073110164177)), d_val);
IfThenElse(mask, Set(d, static_cast<T>(0.998256258266747)), d_val);
e = IfThenElse(mask, Set(d, static_cast<T>(8.13616326949790e-05)), e);

// If y < t0 (idx 0)
mask = Lt(y, t0);
a = IfThenElse(mask, Set(d, static_cast<T>(-0.32124064137467889)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(-0.00191824037528361)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(1.00009167744117367)), c);
a = IfThenElse(mask, Set(d, static_cast<T>(0.055972201057460494)), a);
b = IfThenElse(mask, Set(d, static_cast<T>(-0.34169330090059025)), b);
c = IfThenElse(mask, Set(d, static_cast<T>(0.00051593307922874086)), c);
d_val =
IfThenElse(mask, Set(d, static_cast<T>(-0.00000105381180317)), d_val);
IfThenElse(mask, Set(d, static_cast<T>(0.99998859362000814)), d_val);
e = IfThenElse(mask, Set(d, static_cast<T>(4.614802448903236e-08)), e);
}
}

// Math: f(y) = ay^3 + by^2 + cy + d
// Math: f(y) = ay^4 + by^3 + cy^2 + dy + e
// Using Estrin's scheme from FastExp
const auto y2 = Mul(y, y);
const auto pcd = MulAdd(c, y, d_val);
const auto pab = MulAdd(a, y, b);
auto result = MulAdd(y2, pab, pcd);

const auto kSmall = Set(d, static_cast<T>(0.01));
// term0 = d*y + e
const auto term0 = MulAdd(d_val, y, e);
// term1 = b*y + c
const auto term1 = MulAdd(b, y, c);
// term2 = a*y2 + term1
const auto term2 = MulAdd(a, y2, term1);
// result = term2 * y2 + term0
auto result = MulAdd(term2, y2, term0);

const auto kSmall = Set(d, static_cast<T>(0.001));
result = IfThenElse(Lt(y, kSmall), y, result);

const auto k1 = Set(d, static_cast<T>(1.0));
// We can take Min since cubic approximation for index 7 is monotonically
// increasing, so for inputs >5 the polynomial approximation will output >1.0
// allowing us to use Min() directly instead of IfThenElse()
result = Min(result, k1);
// We cannot use the Min() optimization here: the degree 4 polynomial for
// index 7 is not monotonically increasing and turns down for large inputs
// (coefficient 'a' is negative), so inputs >= 5.2 are set to 1.0 explicitly.
result = IfThenElse(Lt(y, Set(d, static_cast<T>(5.2))), result, k1);

return CopySign(result, val); // Restore sign
}
Expand Down
Loading
Loading