diff --git a/g3doc/op_wishlist.md b/g3doc/op_wishlist.md index 9056e18b27..8e2501184c 100644 --- a/g3doc/op_wishlist.md +++ b/g3doc/op_wishlist.md @@ -38,7 +38,7 @@ _mm512_getmant (f32/f64) High-precision! Consider copying from SLEEF. See #1650. -cosh, erf, fmod, ilogb, lgamma, logb, modf, nextafter, nexttoward, pow, +cosh, fmod, ilogb, lgamma, logb, modf, nextafter, nexttoward, pow, scalbn, tan, tgamma ### Remaining STL functions for hwy/contrib/algo @@ -184,4 +184,5 @@ For SVE (svld1sb_u32)+WASM? Compiler can probably already fuse. * ~~exp2~~ - by johnplatts * ~~Min/MaxValue~~ * ~~Count(If) (https://en.algorithmica.org/hpc/simd/masking/)~~ -* ~~cbrt~~ \ No newline at end of file +* ~~cbrt~~ +* ~~erf~~ \ No newline at end of file diff --git a/hwy/contrib/math/math-inl.h b/hwy/contrib/math/math-inl.h index 8ef3842a60..ca23a0ea48 100644 --- a/hwy/contrib/math/math-inl.h +++ b/hwy/contrib/math/math-inl.h @@ -171,6 +171,21 @@ HWY_NOINLINE V CallCos(const D d, VecArg x) { return Cos(d, x); } +/** + * Highway SIMD version of std::erf(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: [-FLT_MAX, +FLT_MAX] + * @return error function of 'x' + */ +template +HWY_INLINE V Erf(D d, V x); +template +HWY_NOINLINE V CallErf(const D d, VecArg x) { + return Erf(d, x); +} + /** * Highway SIMD version of std::exp(x). * @@ -1640,6 +1655,77 @@ HWY_INLINE V Cos(const D d, V x) { d, Xor(impl.CosReduce(d, y, q), impl.CosSignFromQuadrant(d, q))); } +// Erf +// Based on Cephes erff/erf by Stephen Moshier (public domain, 1989) +// See https://www.netlib.org/cephes/ — single/ndtrf.c +template +HWY_INLINE V Erf(const D d, V x) { + using T = TFromD; + + // |x| < 1, 2/sqrt(pi) + const V kT0 = Set(d, static_cast(1.128379165726710)); + const V kT1 = Set(d, static_cast(-3.761262582423300E-1)); + const V kT2 = Set(d, static_cast(1.128358514861418E-1)); + const V kT3 = Set(d, static_cast(-2.685381193529856E-2)); + const V kT4 = Set(d, static_cast(5.188327685732524E-3)); + const V kT5 = Set(d, static_cast(-8.010193625184903E-4)); + const V kT6 = Set(d, static_cast(7.853861353153693E-5)); + + // |x| in [1, 2), 1/sqrt(pi) + const V kP0 = Set(d, static_cast(5.638259427386472E-1)); + const V kP1 = Set(d, static_cast(-2.741127028184656E-1)); + const V kP2 = Set(d, static_cast(3.404879937665872E-1)); + const V kP3 = Set(d, static_cast(-4.944515323274145E-1)); + const V kP4 = Set(d, static_cast(6.210004621745983E-1)); + const V kP5 = Set(d, static_cast(-5.824733027278666E-1)); + const V kP6 = Set(d, static_cast(3.687424674597105E-1)); + const V kP7 = Set(d, static_cast(-1.387039388740657E-1)); + const V kP8 = Set(d, static_cast(2.326819970068386E-2)); + + // |x| in [2, 14), 1/sqrt(pi) + const V kR0 = Set(d, static_cast(5.641895067754075E-1)); + const V kR1 = Set(d, static_cast(-2.820767439740514E-1)); + const V kR2 = Set(d, static_cast(4.218463358204948E-1)); + const V kR3 = Set(d, static_cast(-1.015265279202700E+0)); + const V kR4 = Set(d, static_cast(2.921019019210786E+0)); + const V kR5 = Set(d, static_cast(-7.495518717768503E+0)); + const V kR6 = Set(d, static_cast(1.297719955372516E+1)); + const V kR7 = Set(d, static_cast(-1.047766399936249E+1)); + + const V kOne = Set(d, static_cast(1)); + + const V sign = And(SignBit(d), x); + x = Xor(x, sign); + const V z = Mul(x, x); + + // |x| < 1: erf(x) = x * T(x * x) + const V poly_T = impl::Estrin(z, kT0, kT1, kT2, kT3, kT4, kT5, kT6); + const V erf_small = Mul(x, poly_T); + + // |x| in [1, 14): erf(x) = 1 - exp(-x * x) * poly(1 / (x * x)) / x + const V exp_neg_z = Exp(d, Neg(z)); + const V inv_x = Div(kOne, x); + const V w = Mul(inv_x, inv_x); + const V poly_P = impl::Estrin(w, kP0, kP1, kP2, kP3, kP4, kP5, kP6, kP7, kP8); + const V poly_R = impl::Estrin(w, kR0, kR1, kR2, kR3, kR4, kR5, kR6, kR7); + const auto is_mid = Lt(x, Set(d, static_cast(2))); + const V poly_pick = IfThenElse(is_mid, poly_P, poly_R); + + const V scaled_poly = Mul(inv_x, poly_pick); + const V erf_large = NegMulAdd(exp_neg_z, scaled_poly, kOne); + + const auto is_below_saturation = Lt(x, Set(d, static_cast(14))); + const auto is_small = Lt(x, kOne); + + V result = kOne; + result = IfThenElse(is_below_saturation, erf_large, result); + result = IfThenElse(is_small, erf_small, result); + + result = IfThenElse(IsNaN(x), x, result); + + return Or(result, sign); +} + template HWY_INLINE V Exp(const D d, V x) { using T = TFromD; diff --git a/hwy/contrib/math/math_test.cc b/hwy/contrib/math/math_test.cc index 1bdca9838d..0fd5419d5d 100644 --- a/hwy/contrib/math/math_test.cc +++ b/hwy/contrib/math/math_test.cc @@ -173,6 +173,10 @@ HWY_NOINLINE void TestMath(const char* name, T (*fx1)(T), DEFINE_MATH_TEST_FUNC(NAME) // clang-format off +DEFINE_MATH_TEST(Erf, + std::erf, CallErf, -FLT_MAX, +FLT_MAX, 4, + // temp rigging this, updating once f32 is passing and get a better idea of approach + std::erf, CallErf, -DBL_MAX, +DBL_MAX, 100000000) DEFINE_MATH_TEST(Exp, std::exp, CallExp, -FLT_MAX, +104.0f, 1, std::exp, CallExp, -DBL_MAX, +104.0, 1) @@ -623,6 +627,7 @@ HWY_AFTER_NAMESPACE(); namespace hwy { namespace { HWY_BEFORE_TEST(HwyMathTest); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllErf); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExp); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExp2); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExpm1);