Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions g3doc/op_wishlist.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ _mm512_getmant (f32/f64)

High-precision! Consider copying from SLEEF. See #1650.

cosh, erf, fmod, ilogb, lgamma, logb, modf, nextafter, nexttoward, pow,
cosh, fmod, ilogb, lgamma, logb, modf, nextafter, nexttoward, pow,
scalbn, tan, tgamma

### Remaining STL functions for hwy/contrib/algo
Expand Down Expand Up @@ -184,4 +184,5 @@ For SVE (svld1sb_u32)+WASM? Compiler can probably already fuse.
* ~~exp2~~ - by johnplatts
* ~~Min/MaxValue~~
* ~~Count(If) (https://en.algorithmica.org/hpc/simd/masking/)~~
* ~~cbrt~~
* ~~cbrt~~
* ~~erf~~
86 changes: 86 additions & 0 deletions hwy/contrib/math/math-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,21 @@ HWY_NOINLINE V CallCos(const D d, VecArg<V> x) {
return Cos(d, x);
}

/**
 * Highway SIMD version of std::erf(x).
 *
 * Valid Lane Types: float32, float64
 * Max Error: ULP = 4 for float32.
 *   NOTE(review): the float64 entry in math_test.cc currently uses a
 *   temporary tolerance of 100000000 ULP, so the 4 ULP bound is not yet
 *   verified for float64 - confirm before relying on it.
 * Valid Range: [-FLT_MAX, +FLT_MAX] for float32; the float64 test sweeps
 *   [-DBL_MAX, +DBL_MAX].
 * @return error function of 'x'
 */
template <class D, class V>
HWY_INLINE V Erf(D d, V x);
// Forwards to Erf(d, x). Marked HWY_NOINLINE; math_test.cc passes this
// wrapper to DEFINE_MATH_TEST as the vector function under test.
template <class D, class V>
HWY_NOINLINE V CallErf(const D d, VecArg<V> x) {
return Erf(d, x);
}

/**
* Highway SIMD version of std::exp(x).
*
Expand Down Expand Up @@ -1640,6 +1655,77 @@ HWY_INLINE V Cos(const D d, V x) {
d, Xor(impl.CosReduce(d, y, q), impl.CosSignFromQuadrant(d, q)));
}

// Erf
// Vectorized error function following the Cephes erff/erf scheme by
// Stephen Moshier (public domain, 1989); see single/ndtrf.c at
// https://www.netlib.org/cephes/
//
// The sign of x is split off, the result is computed for |x|, and the sign is
// OR-ed back in at the end. All ranges are evaluated for every lane and then
// blended:
//   |x| < 1       : |x| * T(x * x)
//   1 <= |x| < 14 : 1 - exp(-x * x) * poly(1 / (x * x)) / |x|, where poly is
//                   P for |x| < 2 and R otherwise
//   |x| >= 14     : 1
// NaN lanes return NaN.
template <class D, class V>
HWY_INLINE V Erf(const D d, V x) {
  using T = TFromD<D>;

  // Strip the sign bit so everything below operates on |x|.
  const V sign_bit = And(SignBit(d), x);
  x = Xor(x, sign_bit);

  const V one = Set(d, static_cast<T>(1));
  const V x_sq = Mul(x, x);

  // ---- |x| < 1: erf(x) = x * T(x * x); leading coefficient is 2/sqrt(pi).
  const V kSmall0 = Set(d, static_cast<T>(1.128379165726710));
  const V kSmall1 = Set(d, static_cast<T>(-3.761262582423300E-1));
  const V kSmall2 = Set(d, static_cast<T>(1.128358514861418E-1));
  const V kSmall3 = Set(d, static_cast<T>(-2.685381193529856E-2));
  const V kSmall4 = Set(d, static_cast<T>(5.188327685732524E-3));
  const V kSmall5 = Set(d, static_cast<T>(-8.010193625184903E-4));
  const V kSmall6 = Set(d, static_cast<T>(7.853861353153693E-5));
  const V small_poly = impl::Estrin(x_sq, kSmall0, kSmall1, kSmall2, kSmall3,
                                    kSmall4, kSmall5, kSmall6);
  const V small_branch = Mul(x, small_poly);

  // ---- |x| in [1, 14): both polynomials take 1 / (x * x) as their argument.
  const V recip_x = Div(one, x);
  const V recip_x_sq = Mul(recip_x, recip_x);

  // |x| in [1, 2); leading coefficient is close to 1/sqrt(pi).
  const V kMid0 = Set(d, static_cast<T>(5.638259427386472E-1));
  const V kMid1 = Set(d, static_cast<T>(-2.741127028184656E-1));
  const V kMid2 = Set(d, static_cast<T>(3.404879937665872E-1));
  const V kMid3 = Set(d, static_cast<T>(-4.944515323274145E-1));
  const V kMid4 = Set(d, static_cast<T>(6.210004621745983E-1));
  const V kMid5 = Set(d, static_cast<T>(-5.824733027278666E-1));
  const V kMid6 = Set(d, static_cast<T>(3.687424674597105E-1));
  const V kMid7 = Set(d, static_cast<T>(-1.387039388740657E-1));
  const V kMid8 = Set(d, static_cast<T>(2.326819970068386E-2));
  const V mid_poly = impl::Estrin(recip_x_sq, kMid0, kMid1, kMid2, kMid3,
                                  kMid4, kMid5, kMid6, kMid7, kMid8);

  // |x| in [2, 14); leading coefficient is close to 1/sqrt(pi).
  const V kTail0 = Set(d, static_cast<T>(5.641895067754075E-1));
  const V kTail1 = Set(d, static_cast<T>(-2.820767439740514E-1));
  const V kTail2 = Set(d, static_cast<T>(4.218463358204948E-1));
  const V kTail3 = Set(d, static_cast<T>(-1.015265279202700E+0));
  const V kTail4 = Set(d, static_cast<T>(2.921019019210786E+0));
  const V kTail5 = Set(d, static_cast<T>(-7.495518717768503E+0));
  const V kTail6 = Set(d, static_cast<T>(1.297719955372516E+1));
  const V kTail7 = Set(d, static_cast<T>(-1.047766399936249E+1));
  const V tail_poly = impl::Estrin(recip_x_sq, kTail0, kTail1, kTail2, kTail3,
                                   kTail4, kTail5, kTail6, kTail7);

  // Choose the polynomial per lane, then form 1 - exp(-x^2) * poly / |x|.
  const auto below_two = Lt(x, Set(d, static_cast<T>(2)));
  const V chosen_poly = IfThenElse(below_two, mid_poly, tail_poly);
  const V exp_term = Exp(d, Neg(x_sq));
  const V large_branch = NegMulAdd(exp_term, Mul(recip_x, chosen_poly), one);

  // ---- Blend: default 1, large-range formula below 14, small-range formula
  // below 1. The small-range select is outermost so it wins for |x| < 1,
  // where the large-range expression divides by |x|.
  const auto below_fourteen = Lt(x, Set(d, static_cast<T>(14)));
  const auto below_one = Lt(x, one);
  const V saturated = IfThenElse(below_fourteen, large_branch, one);
  const V blended = IfThenElse(below_one, small_branch, saturated);

  // Ordered comparisons are false for NaN, which would leave 1 in NaN lanes;
  // pass the NaN input through instead.
  const V magnitude = IfThenElse(IsNaN(x), x, blended);

  // Restore the original sign.
  return Or(magnitude, sign_bit);
}

template <class D, class V>
HWY_INLINE V Exp(const D d, V x) {
using T = TFromD<D>;
Expand Down
5 changes: 5 additions & 0 deletions hwy/contrib/math/math_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ HWY_NOINLINE void TestMath(const char* name, T (*fx1)(T),
DEFINE_MATH_TEST_FUNC(NAME)

// clang-format off
// NOTE(review): per precision (f32 line, then f64 line) the arguments appear
// to be: scalar reference, vector wrapper, range min, range max, max ULP
// error - confirm against the DEFINE_MATH_TEST definition.
DEFINE_MATH_TEST(Erf,
std::erf, CallErf, -FLT_MAX, +FLT_MAX, 4,
// TODO: the f64 tolerance below (100000000) is a temporary placeholder, not
// an accuracy claim; tighten it once the f32 test passes and the approach for
// f64 is settled.
std::erf, CallErf, -DBL_MAX, +DBL_MAX, 100000000)
DEFINE_MATH_TEST(Exp,
std::exp, CallExp, -FLT_MAX, +104.0f, 1,
std::exp, CallExp, -DBL_MAX, +104.0, 1)
Expand Down Expand Up @@ -623,6 +627,7 @@ HWY_AFTER_NAMESPACE();
namespace hwy {
namespace {
HWY_BEFORE_TEST(HwyMathTest);
HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllErf);
HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExp);
HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExp2);
HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExpm1);
Expand Down
Loading