diff --git a/pykokkos/core/visitors/workunit_visitor.py b/pykokkos/core/visitors/workunit_visitor.py index 807eafe4..447201c3 100644 --- a/pykokkos/core/visitors/workunit_visitor.py +++ b/pykokkos/core/visitors/workunit_visitor.py @@ -295,16 +295,18 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr: else: return cppast.CallExpr(function, [args[0], f"pk_id_{work_unit}"]) - atomic_fetch_op: re.Pattern = re.compile("atomic_fetch_*") - is_atomic_fetch_op: bool = atomic_fetch_op.match(name) + atomic_op: re.Pattern = re.compile("atomic_.*") is_atomic_increment: bool = name == "atomic_increment" is_atomic_compare_exchange: bool = name == "atomic_compare_exchange" + is_atomic_op: bool = atomic_op.match(name) and not ( + is_atomic_increment or is_atomic_compare_exchange + ) - if is_atomic_fetch_op or is_atomic_compare_exchange or is_atomic_increment: + if is_atomic_op or is_atomic_compare_exchange or is_atomic_increment: if is_atomic_increment and len(args) != 2: self.error(node, "is_atomic_increment takes exactly 2 arguments") - if not is_atomic_increment and is_atomic_fetch_op and len(args) != 3: - self.error(node, "atomic_fetch_op functions take exactly 3 arguments") + if not is_atomic_increment and is_atomic_op and len(args) != 3: + self.error(node, "atomic_op functions take exactly 3 arguments") if is_atomic_compare_exchange and len(args) != 4: self.error(node, "atomic_compare_exchange takes exactly 4 arguments") @@ -312,11 +314,7 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr: args[0] = cppast.CallExpr(args[0], args[1].exprs) del args[1] - # if not isinstance(args[0], cppast.CallExpr): - # self.error( - # node, "atomic_fetch_op functions only support views") - - # atomic_fetch_* operations need to have an address as + # atomic_* & atomic_fetch_* operations need to have an address as # their first argument args[0] = cppast.UnaryOperator(args[0], cppast.BinaryOperatorKind.AddrOf) return cppast.CallExpr(function, args) diff --git 
a/tests/test_atomics.py b/tests/test_atomics.py index dcb89668..c543cc5d 100644 --- a/tests/test_atomics.py +++ b/tests/test_atomics.py @@ -24,6 +24,7 @@ def __init__(self, threads: int, i_1: int, i_2: int, f_1: float, f_2: float): self.view1D_rshift: pk.View1D[pk.int32] = pk.View([1], pk.int32) self.view1D_sub: pk.View1D[pk.double] = pk.View([1], pk.double) self.view1D_xor: pk.View1D[pk.int32] = pk.View([1], pk.int32) + self.view1D_ace: pk.View1D[pk.int32] = pk.View([1], pk.int32) self.view1D_add[0] = f_1 self.view1D_and[0] = i_1 @@ -37,54 +38,109 @@ def __init__(self, threads: int, i_1: int, i_2: int, f_1: float, f_2: float): self.view1D_rshift[0] = i_1 self.view1D_sub[0] = f_1 self.view1D_xor[0] = i_1 + self.view1D_ace[0] = i_1 @pk.workunit def atomic_add(self, tid: int) -> None: - pk.atomic_fetch_add(self.view1D_add, [0], self.f_2) + pk.atomic_add(self.view1D_add, [0], self.f_2) @pk.workunit def atomic_and(self, tid: int) -> None: - pk.atomic_fetch_and(self.view1D_and, [0], self.i_2) + pk.atomic_and(self.view1D_and, [0], self.i_2) @pk.workunit def atomic_div(self, tid: int) -> None: - pk.atomic_fetch_div(self.view1D_div, [0], self.f_2) + pk.atomic_div(self.view1D_div, [0], self.f_2) @pk.workunit def atomic_lshift(self, tid: int) -> None: - pk.atomic_fetch_lshift(self.view1D_lshift, [0], self.i_2) + pk.atomic_lshift(self.view1D_lshift, [0], self.i_2) @pk.workunit def atomic_max(self, tid: int) -> None: - pk.atomic_fetch_max(self.view1D_max, [0], self.f_2) + pk.atomic_max(self.view1D_max, [0], self.f_2) @pk.workunit def atomic_min(self, tid: int) -> None: - pk.atomic_fetch_min(self.view1D_min, [0], self.f_2) + pk.atomic_min(self.view1D_min, [0], self.f_2) @pk.workunit def atomic_mod(self, tid: int) -> None: - pk.atomic_fetch_mod(self.view1D_mod, [0], self.i_2) + pk.atomic_mod(self.view1D_mod, [0], self.i_2) @pk.workunit def atomic_mul(self, tid: int) -> None: - pk.atomic_fetch_mul(self.view1D_mul, [0], self.f_2) + pk.atomic_mul(self.view1D_mul, [0], self.f_2) 
@pk.workunit def atomic_or(self, tid: int) -> None: - pk.atomic_fetch_or(self.view1D_or, [0], self.i_2) + pk.atomic_or(self.view1D_or, [0], self.i_2) @pk.workunit def atomic_rshift(self, tid: int) -> None: - pk.atomic_fetch_rshift(self.view1D_rshift, [0], self.i_2) + pk.atomic_rshift(self.view1D_rshift, [0], self.i_2) @pk.workunit def atomic_sub(self, tid: int) -> None: - pk.atomic_fetch_sub(self.view1D_sub, [0], self.f_2) + pk.atomic_sub(self.view1D_sub, [0], self.f_2) @pk.workunit def atomic_xor(self, tid: int) -> None: - pk.atomic_fetch_xor(self.view1D_xor, [0], self.i_2) + pk.atomic_xor(self.view1D_xor, [0], self.i_2) + + @pk.workunit + def atomic_fetch_add(self, tid: int) -> None: + old_value: pk.double = pk.atomic_fetch_add(self.view1D_add, [0], self.f_2) + + @pk.workunit + def atomic_fetch_and(self, tid: int) -> None: + old_value: pk.double = pk.atomic_fetch_and(self.view1D_and, [0], self.i_2) + + @pk.workunit + def atomic_fetch_div(self, tid: int) -> None: + old_value: pk.double = pk.atomic_fetch_div(self.view1D_div, [0], self.f_2) + + @pk.workunit + def atomic_fetch_lshift(self, tid: int) -> None: + old_value: pk.double = pk.atomic_fetch_lshift(self.view1D_lshift, [0], self.i_2) + + @pk.workunit + def atomic_fetch_max(self, tid: int) -> None: + old_value: pk.double = pk.atomic_fetch_max(self.view1D_max, [0], self.f_2) + + @pk.workunit + def atomic_fetch_min(self, tid: int) -> None: + old_value: pk.double = pk.atomic_fetch_min(self.view1D_min, [0], self.f_2) + + @pk.workunit + def atomic_fetch_mod(self, tid: int) -> None: + old_value: pk.double = pk.atomic_fetch_mod(self.view1D_mod, [0], self.i_2) + + @pk.workunit + def atomic_fetch_mul(self, tid: int) -> None: + old_value: pk.double = pk.atomic_fetch_mul(self.view1D_mul, [0], self.f_2) + + @pk.workunit + def atomic_fetch_or(self, tid: int) -> None: + old_value: pk.double = pk.atomic_fetch_or(self.view1D_or, [0], self.i_2) + + @pk.workunit + def atomic_fetch_rshift(self, tid: int) -> None: + old_value: 
pk.double = pk.atomic_fetch_rshift(self.view1D_rshift, [0], self.i_2) + + @pk.workunit + def atomic_fetch_sub(self, tid: int) -> None: + old_value: pk.double = pk.atomic_fetch_sub(self.view1D_sub, [0], self.f_2) + + @pk.workunit + def atomic_fetch_xor(self, tid: int) -> None: + old_value: pk.double = pk.atomic_fetch_xor(self.view1D_xor, [0], self.i_2) + + @pk.workunit + def atomic_compare_exchange(self, tid: int) -> None: + old_value: pk.int32 = pk.atomic_compare_exchange( + self.view1D_ace, [0], self.i_1, self.i_2 + ) class TestAtomic(unittest.TestCase): @@ -196,6 +252,110 @@ def test_atomic_xor(self): self.assertEqual(expected_result, result) + def test_atomic_fetch_add(self): + expected_result: float = self.f_1 + self.f_2 + + pk.parallel_for(self.range_policy, self.functor.atomic_fetch_add) + result: float = self.functor.view1D_add[0] + + self.assertEqual(expected_result, result) + + def test_atomic_fetch_and(self): + expected_result: int = self.i_1 & self.i_2 + + pk.parallel_for(self.range_policy, self.functor.atomic_fetch_and) + result: int = self.functor.view1D_and[0] + + self.assertEqual(expected_result, result) + + def test_atomic_fetch_div(self): + expected_result: float = self.f_1 / self.f_2 + + pk.parallel_for(self.range_policy, self.functor.atomic_fetch_div) + result: float = self.functor.view1D_div[0] + + self.assertEqual(expected_result, result) + + def test_atomic_fetch_lshift(self): + expected_result: int = self.i_1 << self.i_2 + + pk.parallel_for(self.range_policy, self.functor.atomic_fetch_lshift) + result: int = self.functor.view1D_lshift[0] + + self.assertEqual(expected_result, result) + + def test_atomic_fetch_max(self): + expected_result: float = max(self.f_1, self.f_2) + + pk.parallel_for(self.range_policy, self.functor.atomic_fetch_max) + result: float = self.functor.view1D_max[0] + + self.assertEqual(expected_result, result) + + def test_atomic_fetch_min(self): + expected_result: float = min(self.f_1, self.f_2) + + 
pk.parallel_for(self.range_policy, self.functor.atomic_fetch_min) + result: float = self.functor.view1D_min[0] + + self.assertEqual(expected_result, result) + + def test_atomic_fetch_mod(self): + expected_result: int = self.i_1 % self.i_2 + + pk.parallel_for(self.range_policy, self.functor.atomic_fetch_mod) + result: int = self.functor.view1D_mod[0] + + self.assertEqual(expected_result, result) + + def test_atomic_fetch_mul(self): + expected_result: float = self.f_1 * self.f_2 + + pk.parallel_for(self.range_policy, self.functor.atomic_fetch_mul) + result: float = self.functor.view1D_mul[0] + + self.assertEqual(expected_result, result) + + def test_atomic_fetch_or(self): + expected_result: int = self.i_1 | self.i_2 + + pk.parallel_for(self.range_policy, self.functor.atomic_fetch_or) + result: int = self.functor.view1D_or[0] + + self.assertEqual(expected_result, result) + + def test_atomic_fetch_rshift(self): + expected_result: int = self.i_1 >> self.i_2 + + pk.parallel_for(self.range_policy, self.functor.atomic_fetch_rshift) + result: int = self.functor.view1D_rshift[0] + + self.assertEqual(expected_result, result) + + def test_atomic_fetch_sub(self): + expected_result: float = self.f_1 - self.f_2 + + pk.parallel_for(self.range_policy, self.functor.atomic_fetch_sub) + result: float = self.functor.view1D_sub[0] + + self.assertEqual(expected_result, result) + + def test_atomic_fetch_xor(self): + expected_result: int = self.i_1 ^ self.i_2 + + pk.parallel_for(self.range_policy, self.functor.atomic_fetch_xor) + result: int = self.functor.view1D_xor[0] + + self.assertEqual(expected_result, result) + + def test_atomic_compare_exchange(self): + expected_result: int = self.i_2 + + pk.parallel_for(self.range_policy, self.functor.atomic_compare_exchange) + result: int = self.functor.view1D_ace[0] + + self.assertEqual(expected_result, result) + if __name__ == "__main__": unittest.main()