diff --git a/.github/workflows/array_api.yml b/.github/workflows/array_api.yml index 2cb6a855..8d5eddda 100644 --- a/.github/workflows/array_api.yml +++ b/.github/workflows/array_api.yml @@ -12,7 +12,7 @@ jobs: test_array_api: strategy: matrix: - os: [ubuntu-latest, ubuntu-22.04, ubuntu-22.04-arm] + os: [ubuntu-latest, ubuntu-22.04-arm] python-version: [3.11, 3.12, 3.13] runs-on: ${{ matrix.os }} defaults: diff --git a/.github/workflows/base-cancelling.yml b/.github/workflows/base-cancelling.yml index a89293bc..eb29e59b 100644 --- a/.github/workflows/base-cancelling.yml +++ b/.github/workflows/base-cancelling.yml @@ -9,7 +9,7 @@ jobs: name: "Cancel duplicate workflow runs" strategy: matrix: - os: [ubuntu-latest, ubuntu-22.04, ubuntu-22.04-arm] + os: [ubuntu-latest, ubuntu-22.04-arm] runs-on: ${{matrix.os}} steps: diff --git a/.github/workflows/base-linux-ci.yml b/.github/workflows/base-linux-ci.yml index 091bcd9e..ff4a48d0 100644 --- a/.github/workflows/base-linux-ci.yml +++ b/.github/workflows/base-linux-ci.yml @@ -18,7 +18,7 @@ jobs: matrix: python-version: [3.11, 3.12, 3.13] kokkos-branch: ['4.7.01'] - os: [ubuntu-latest, ubuntu-22.04, ubuntu-22.04-arm] + os: [ubuntu-latest, ubuntu-22.04-arm] runs-on: ${{matrix.os}} defaults: diff --git a/.github/workflows/base-python-package.yml b/.github/workflows/base-python-package.yml index 26bd0a3d..ce2c4583 100644 --- a/.github/workflows/base-python-package.yml +++ b/.github/workflows/base-python-package.yml @@ -17,7 +17,7 @@ jobs: strategy: matrix: python-version: [3.11, 3.12, 3.13] - os: [ubuntu-latest, ubuntu-22.04, ubuntu-22.04-arm] + os: [ubuntu-latest, ubuntu-22.04-arm] runs-on: ${{matrix.os}} defaults: @@ -48,7 +48,7 @@ jobs: strategy: matrix: python-version: [3.13] - os: [ubuntu-latest, ubuntu-22.04, ubuntu-22.04-arm] + os: [ubuntu-latest, ubuntu-22.04-arm] runs-on: ${{matrix.os}} defaults: diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index 722afc25..8a41f807 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: python-version: [3.13] - os: [ubuntu-latest, ubuntu-22.04, ubuntu-22.04-arm] + os: [ubuntu-latest] runs-on: ${{matrix.os}} steps: diff --git a/.github/workflows/main_ci.yml b/.github/workflows/main_ci.yml index 93acbf16..e03b7e9f 100644 --- a/.github/workflows/main_ci.yml +++ b/.github/workflows/main_ci.yml @@ -12,7 +12,7 @@ jobs: test_pykokkos: strategy: matrix: - os: [ubuntu-latest, ubuntu-22.04, ubuntu-22.04-arm] + os: [ubuntu-latest, ubuntu-22.04-arm] python-version: ["3.11", "3.12", "3.13"] runs-on: ${{ matrix.os }} steps: @@ -24,7 +24,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install --upgrade numpy mypy==1.0.1 cmake pytest pybind11 scikit-build patchelf + python -m pip install --upgrade numpy mypy==1.0.1 cmake pytest pybind11 scikit-build patchelf scipy - name: Install pykokkos-base working-directory: base run: | diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index ca803a07..ddfe47f5 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -10,7 +10,7 @@ jobs: test_vs_kokkos_develop: strategy: matrix: - os: [ubuntu-latest, ubuntu-22.04, ubuntu-22.04-arm] + os: [ubuntu-latest, ubuntu-22.04-arm] python-version: ["3.13"] runs-on: ${{ matrix.os }} steps: @@ -24,7 +24,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install --upgrade numpy mypy==0.981 cmake pytest pybind11 scikit-build patchelf ninja + python -m pip install --upgrade numpy mypy==0.981 cmake pytest pybind11 scikit-build patchelf ninja scipy - name: Build Kokkos develop branch run: | cd /tmp diff --git a/pykokkos/core/translators/static.py b/pykokkos/core/translators/static.py index dced1afe..3f2b29f9 100644 --- a/pykokkos/core/translators/static.py +++ b/pykokkos/core/translators/static.py @@ -184,7 +184,7 @@ def check_symbols(self, classtypes: List[PyKokkosEntity], path: str) -> None: for error in error_messages: print(error) - sys.exit() + raise Exception("PyKokkos translation failed") def translate_classtypes( self, classtypes: List[PyKokkosEntity], restrict_views: Set[str] @@ -430,9 +430,9 @@ def translate_workunits( workunit: cppast.MethodDecl = workunits[n][1] self.add_rand_pool_state(workunit) node_visitor.has_rand_call = False - except: + except Exception as e: print(f"Translation of {w} {w.name} failed") - sys.exit(1) + raise e return workunits, has_rand_call diff --git a/pykokkos/core/translators/symbols_pass.py b/pykokkos/core/translators/symbols_pass.py index 142050d2..6d425a39 100644 --- a/pykokkos/core/translators/symbols_pass.py +++ b/pykokkos/core/translators/symbols_pass.py @@ -7,6 +7,7 @@ from pykokkos.core.visitors.visitors_util import ( math_constants, math_functions, + math_special_functions, allowed_types, view_dtypes, ) @@ -62,6 +63,7 @@ def __init__(self, members: PyKokkosMembers, pk_import: str, path: str): self.global_symbols.update(math_constants) self.global_symbols.update(math_functions) + self.global_symbols.update(math_special_functions) self.global_symbols.update(allowed_types) self.global_symbols.update(view_dtypes) self.global_symbols.update( diff --git a/pykokkos/core/visitors/pykokkos_visitor.py b/pykokkos/core/visitors/pykokkos_visitor.py index e953f5ad..96340b76 100644 --- a/pykokkos/core/visitors/pykokkos_visitor.py +++ b/pykokkos/core/visitors/pykokkos_visitor.py @@ -465,6 +465,10 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr: ]: return cppast.CallExpr(function, args) + if name in visitors_util.math_special_functions: + function = cppast.DeclRefExpr("Kokkos::Experimental::" + name) + return cppast.CallExpr(function, args) + if function in self.kokkos_functions: if "PK_RESTRICT" in os.environ: return adjust_kokkos_function_call( diff --git a/pykokkos/core/visitors/visitors_util.py b/pykokkos/core/visitors/visitors_util.py index ecb10bcc..54e506bc 100644 --- a/pykokkos/core/visitors/visitors_util.py +++ b/pykokkos/core/visitors/visitors_util.py @@ -75,6 +75,25 @@ def pretty_print(node): # ast.NotIn: "not in", } +# see Kokkos_MathematicalSpecialFunctions.hpp +math_special_functions: Set = { + "expint1", + "erf", + "erfcx", + "cyl_bessel_y0", + "cyl_bessel_j0", + "cyl_bessel_y1", + "cyl_bessel_j1", + "cyl_bessel_i0", + "cyl_bessel_k0", + "cyl_bessel_i1", + "cyl_bessel_k1", + "cyl_bessel_h10", + "cyl_bessel_h11", + "cyl_bessel_h20", + "cyl_bessel_h21", +} + # TODO: provide mapping to cmath versions math_functions: Set = { "acos", @@ -89,7 +108,6 @@ def pretty_print(node): "cos", "cosh", "degrees", - "erf", "erfc", "exp", "expm1", diff --git a/pykokkos/core/visitors/workunit_visitor.py b/pykokkos/core/visitors/workunit_visitor.py index c0bec889..8533fb22 100644 --- a/pykokkos/core/visitors/workunit_visitor.py +++ b/pykokkos/core/visitors/workunit_visitor.py @@ -343,16 +343,24 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr: return rand_call - if name in {"cyl_bessel_j0", "cyl_bessel_j1"}: + if name.startswith("cyl_bessel_"): if len(args) != 1: - self.error(node, "pk.cyl_bessel_j0/j1 accepts only one argument") + self.error(node, "bessel functions accepts only one argument") + # Instantiate bessel functions with an explicit complex type derived + # from the argument, but make sure we strip references so we don't + # end up with invalid nested types like + # Kokkos::complex&>. s = cppast.Serializer() arg_str = s.serialize(args[0]) + cmplx_type = f"std::remove_reference_t" + if name.endswith("h", 0, -2): + # special case of Hankel functions + arg_types = f"Kokkos::Experimental::{name}<{cmplx_type}>" + else: + arg_types = f"Kokkos::Experimental::{name}<{cmplx_type}, double, int>" math_call = cppast.CallExpr( - cppast.DeclRefExpr( - f"Kokkos::Experimental::{name}, double, int>" - ), + cppast.DeclRefExpr(arg_types), args, ) real_number_call = cppast.MemberCallExpr( diff --git a/tests/test_special_functions.py b/tests/test_special_functions.py new file mode 100644 index 00000000..12b9cb94 --- /dev/null +++ b/tests/test_special_functions.py @@ -0,0 +1,105 @@ +""" +Test the Kokkos Special Functions defined in +Kokkos_MathematicalSpecialFunctions.hpp + + - expint1 + - erf + - erfcx + - cyl_bessel_j0 + - cyl_bessel_y0 + - cyl_bessel_j1 + - cyl_bessel_y1 + - cyl_bessel_i0 + - cyl_bessel_k0 + - cyl_bessel_i1 + - cyl_bessel_k1 + - cyl_bessel_h10 + - cyl_bessel_h11 + - cyl_bessel_h20 + - cyl_bessel_h21 + +For np.float32, np.float64, np.complex64 and np.complex128 dtypes +""" + +import pytest +import numpy as np +import pykokkos as pk +import scipy.special as spsp +from functools import partial + + +@pk.workunit +def special_function_workunit(tid, out, arr, flag): + if flag == 0: + out[tid] = expint1(arr[tid]) + elif flag == 1: + out[tid] = erf(arr[tid]) + elif flag == 2: + out[tid] = erfcx(arr[tid]) + elif flag == 3: + out[tid] = cyl_bessel_j0(arr[tid]) + elif flag == 4: + out[tid] = cyl_bessel_y0(arr[tid]) + elif flag == 5: + out[tid] = cyl_bessel_j1(arr[tid]) + elif flag == 6: + out[tid] = cyl_bessel_y1(arr[tid]) + elif flag == 7: + out[tid] = cyl_bessel_i0(arr[tid]) + elif flag == 8: + out[tid] = cyl_bessel_k0(arr[tid]) + elif flag == 9: + out[tid] = cyl_bessel_i1(arr[tid]) + elif flag == 10: + out[tid] = cyl_bessel_k1(arr[tid]) + elif flag == 11: + out[tid] = cyl_bessel_h10(arr[tid]) + elif flag == 12: + out[tid] = cyl_bessel_h11(arr[tid]) + elif flag == 13: + out[tid] = cyl_bessel_h20(arr[tid]) + elif flag == 14: + out[tid] = cyl_bessel_h21(arr[tid]) + + +@pytest.mark.parametrize( + "flag, sp_func", + [ + (0, spsp.exp1), + (1, spsp.erf), + (2, spsp.erfcx), + (3, partial(spsp.jv, v=0)), + (4, partial(spsp.yv, v=0)), + (5, partial(spsp.jv, v=1)), + (6, partial(spsp.yv, v=1)), + (7, partial(spsp.iv, v=0)), + (8, partial(spsp.kv, v=0)), + (9, partial(spsp.iv, v=1)), + (10, partial(spsp.kv, v=1)), + (11, partial(spsp.hankel1, v=0)), + (12, partial(spsp.hankel1, v=1)), + (13, partial(spsp.hankel2, v=0)), + (14, partial(spsp.hankel2, v=1)), + ], +) +@pytest.mark.parametrize("dtype", [np.complex128]) # TODO: add more types +def test_kokkos_special_functions(flag, sp_func, dtype): + # generate random numpy data + N = 400 + rng = np.random.default_rng() + if dtype == np.complex64: + real = rng.random(size=N, dtype=np.float32) * N - (N // 2) + imag = rng.random(size=N, dtype=np.float32) * N - (N // 2) + arr = real + 1j * imag + elif dtype == np.complex128: + real = rng.random(size=N, dtype=np.float32) * N - (N // 2) + imag = rng.random(size=N, dtype=np.float32) * N - (N // 2) + arr = real + 1j * imag + else: + arr = rng.random(size=N, dtype=dtype) * N - (N // 2) + + expected = np.empty_like(arr) + pk.parallel_for(N, special_function_workunit, out=expected, arr=arr, flag=flag) + actual = pk_function(arr) + + np.testing.assert_allequal(actual, expected)