Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions pykokkos/core/visitors/workunit_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,28 +295,26 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:
else:
return cppast.CallExpr(function, [args[0], f"pk_id_{work_unit}"])

atomic_fetch_op: re.Pattern = re.compile("atomic_fetch_*")
is_atomic_fetch_op: bool = atomic_fetch_op.match(name)
atomic_op: re.Pattern = re.compile("atomic_.*")
is_atomic_increment: bool = name == "atomic_increment"
is_atomic_compare_exchange: bool = name == "atomic_compare_exchange"
is_atomic_op: bool = atomic_op.match(name) and not (
is_atomic_increment or is_atomic_compare_exchange
)

if is_atomic_fetch_op or is_atomic_compare_exchange or is_atomic_increment:
if is_atomic_op or is_atomic_compare_exchange or is_atomic_increment:
if is_atomic_increment and len(args) != 2:
self.error(node, "is_atomic_increment takes exactly 2 arguments")
if not is_atomic_increment and is_atomic_fetch_op and len(args) != 3:
self.error(node, "atomic_fetch_op functions take exactly 3 arguments")
if not is_atomic_increment and is_atomic_op and len(args) != 3:
self.error(node, "atomic_op functions take exactly 3 arguments")
if is_atomic_compare_exchange and len(args) != 4:
self.error(node, "atomic_compare_exchange takes exactly 4 arguments")

# convert indices
args[0] = cppast.CallExpr(args[0], args[1].exprs)
del args[1]

# if not isinstance(args[0], cppast.CallExpr):
# self.error(
# node, "atomic_fetch_op functions only support views")

# atomic_fetch_* operations need to have an address as
# atomic_* & atomic_fetch_* operations need to have an address as
# their first argument
args[0] = cppast.UnaryOperator(args[0], cppast.BinaryOperatorKind.AddrOf)
return cppast.CallExpr(function, args)
Expand Down
184 changes: 172 additions & 12 deletions tests/test_atomics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self, threads: int, i_1: int, i_2: int, f_1: float, f_2: float):
self.view1D_rshift: pk.View1D[pk.int32] = pk.View([1], pk.int32)
self.view1D_sub: pk.View1D[pk.double] = pk.View([1], pk.double)
self.view1D_xor: pk.View1D[pk.int32] = pk.View([1], pk.int32)
self.view1D_ace: pk.View1D[pk.int32] = pk.View([1], pk.int32)

self.view1D_add[0] = f_1
self.view1D_and[0] = i_1
Expand All @@ -37,54 +38,109 @@ def __init__(self, threads: int, i_1: int, i_2: int, f_1: float, f_2: float):
self.view1D_rshift[0] = i_1
self.view1D_sub[0] = f_1
self.view1D_xor[0] = i_1
self.view1D_ace[0] = i_1

@pk.workunit
def atomic_add(self, tid: int) -> None:
pk.atomic_fetch_add(self.view1D_add, [0], self.f_2)
pk.atomic_add(self.view1D_add, [0], self.f_2)

@pk.workunit
def atomic_and(self, tid: int) -> None:
pk.atomic_fetch_and(self.view1D_and, [0], self.i_2)
pk.atomic_and(self.view1D_and, [0], self.i_2)

@pk.workunit
def atomic_div(self, tid: int) -> None:
pk.atomic_fetch_div(self.view1D_div, [0], self.f_2)
pk.atomic_div(self.view1D_div, [0], self.f_2)

@pk.workunit
def atomic_lshift(self, tid: int) -> None:
pk.atomic_fetch_lshift(self.view1D_lshift, [0], self.i_2)
pk.atomic_lshift(self.view1D_lshift, [0], self.i_2)

@pk.workunit
def atomic_max(self, tid: int) -> None:
pk.atomic_fetch_max(self.view1D_max, [0], self.f_2)
pk.atomic_max(self.view1D_max, [0], self.f_2)

@pk.workunit
def atomic_min(self, tid: int) -> None:
pk.atomic_fetch_min(self.view1D_min, [0], self.f_2)
pk.atomic_min(self.view1D_min, [0], self.f_2)

@pk.workunit
def atomic_mod(self, tid: int) -> None:
pk.atomic_fetch_mod(self.view1D_mod, [0], self.i_2)
pk.atomic_mod(self.view1D_mod, [0], self.i_2)

@pk.workunit
def atomic_mul(self, tid: int) -> None:
pk.atomic_fetch_mul(self.view1D_mul, [0], self.f_2)
pk.atomic_mul(self.view1D_mul, [0], self.f_2)

@pk.workunit
def atomic_or(self, tid: int) -> None:
pk.atomic_fetch_or(self.view1D_or, [0], self.i_2)
pk.atomic_or(self.view1D_or, [0], self.i_2)

@pk.workunit
def atomic_rshift(self, tid: int) -> None:
pk.atomic_fetch_rshift(self.view1D_rshift, [0], self.i_2)
pk.atomic_rshift(self.view1D_rshift, [0], self.i_2)

@pk.workunit
def atomic_sub(self, tid: int) -> None:
pk.atomic_fetch_sub(self.view1D_sub, [0], self.f_2)
pk.atomic_sub(self.view1D_sub, [0], self.f_2)

@pk.workunit
def atomic_xor(self, tid: int) -> None:
pk.atomic_fetch_xor(self.view1D_xor, [0], self.i_2)
pk.atomic_xor(self.view1D_xor, [0], self.i_2)

@pk.workunit
def atomic_fetch_add(self, tid: int) -> None:
old_value: pk.double = pk.atomic_fetch_add(self.view1D_add, [0], self.f_2)

@pk.workunit
def atomic_fetch_and(self, tid: int) -> None:
old_value: pk.double = pk.atomic_fetch_and(self.view1D_and, [0], self.i_2)

@pk.workunit
def atomic_fetch_div(self, tid: int) -> None:
old_value: pk.double = pk.atomic_fetch_div(self.view1D_div, [0], self.f_2)

@pk.workunit
def atomic_fetch_lshift(self, tid: int) -> None:
old_value: pk.double = pk.atomic_fetch_lshift(self.view1D_lshift, [0], self.i_2)

@pk.workunit
def atomic_fetch_max(self, tid: int) -> None:
old_value: pk.double = pk.atomic_fetch_max(self.view1D_max, [0], self.f_2)

@pk.workunit
def atomic_fetch_min(self, tid: int) -> None:
old_value: pk.double = pk.atomic_fetch_min(self.view1D_min, [0], self.f_2)

@pk.workunit
def atomic_fetch_mod(self, tid: int) -> None:
old_value: pk.double = pk.atomic_fetch_mod(self.view1D_mod, [0], self.i_2)

@pk.workunit
def atomic_fetch_mul(self, tid: int) -> None:
old_value: pk.double = pk.atomic_fetch_mul(self.view1D_mul, [0], self.f_2)

@pk.workunit
def atomic_fetch_or(self, tid: int) -> None:
old_value: pk.double = pk.atomic_fetch_or(self.view1D_or, [0], self.i_2)

@pk.workunit
def atomic_fetch_rshift(self, tid: int) -> None:
old_value: pk.double = pk.atomic_fetch_rshift(self.view1D_rshift, [0], self.i_2)

@pk.workunit
def atomic_fetch_sub(self, tid: int) -> None:
old_value: pk.double = pk.atomic_fetch_sub(self.view1D_sub, [0], self.f_2)

@pk.workunit
def atomic_fetch_xor(self, tid: int) -> None:
old_value: pk.double = pk.atomic_fetch_xor(self.view1D_xor, [0], self.i_2)

@pk.workunit
def atomic_compare_exchange(self, tide: int) -> None:
old_value: pk.int32 = pk.atomic_compare_exchange(
self.view1D_ace, [0], self.i_1, self.i_2
)


class TestAtomic(unittest.TestCase):
Expand Down Expand Up @@ -196,6 +252,110 @@ def test_atomic_xor(self):

self.assertEqual(expected_result, result)

def test_atomic_fetch_add(self):
expected_result: float = self.f_1 + self.f_2

pk.parallel_for(self.range_policy, self.functor.atomic_fetch_add)
result: float = self.functor.view1D_add[0]

self.assertEqual(expected_result, result)

def test_atomic_fetch_and(self):
expected_result: int = self.i_1 & self.i_2

pk.parallel_for(self.range_policy, self.functor.atomic_fetch_and)
result: int = self.functor.view1D_and[0]

self.assertEqual(expected_result, result)

def test_atomic_fetch_div(self):
expected_result: float = self.f_1 / self.f_2

pk.parallel_for(self.range_policy, self.functor.atomic_fetch_div)
result: float = self.functor.view1D_div[0]

self.assertEqual(expected_result, result)

def test_atomic_fetch_lshift(self):
expected_result: int = self.i_1 << self.i_2

pk.parallel_for(self.range_policy, self.functor.atomic_fetch_lshift)
result: int = self.functor.view1D_lshift[0]

self.assertEqual(expected_result, result)

def test_atomic_fetch_max(self):
expected_result: float = max(self.f_1, self.f_2)

result: float = self.functor.view1D_max[0]
pk.parallel_for(self.range_policy, self.functor.atomic_fetch_max)

self.assertEqual(expected_result, result)

def test_atomic_fetch_min(self):
expected_result: float = min(self.f_1, self.f_2)

pk.parallel_for(self.range_policy, self.functor.atomic_fetch_min)
result: float = self.functor.view1D_min[0]

self.assertEqual(expected_result, result)

def test_atomic_fetch_mod(self):
expected_result: int = self.i_1 % self.i_2

pk.parallel_for(self.range_policy, self.functor.atomic_fetch_mod)
result: int = self.functor.view1D_mod[0]

self.assertEqual(expected_result, result)

def test_atomic_fetch_mul(self):
expected_result: float = self.f_1 * self.f_2

pk.parallel_for(self.range_policy, self.functor.atomic_fetch_mul)
result: float = self.functor.view1D_mul[0]

self.assertEqual(expected_result, result)

def test_atomic_fetch_or(self):
expected_result: int = self.i_1 | self.i_2

pk.parallel_for(self.range_policy, self.functor.atomic_fetch_or)
result: int = self.functor.view1D_or[0]

self.assertEqual(expected_result, result)

def test_atomic_fetch_rshift(self):
expected_result: int = self.i_1 >> self.i_2

pk.parallel_for(self.range_policy, self.functor.atomic_fetch_rshift)
result: int = self.functor.view1D_rshift[0]

self.assertEqual(expected_result, result)

def test_atomic_fetch_sub(self):
expected_result: float = self.f_1 - self.f_2

pk.parallel_for(self.range_policy, self.functor.atomic_fetch_sub)
result: float = self.functor.view1D_sub[0]

self.assertEqual(expected_result, result)

def test_atomic_fetch_xor(self):
expected_result: int = self.i_1 ^ self.i_2

pk.parallel_for(self.range_policy, self.functor.atomic_fetch_xor)
result: int = self.functor.view1D_xor[0]

self.assertEqual(expected_result, result)

def test_atomic_compare_exchange(self):
expected_result: int = self.i_2

pk.parallel_for(self.range_policy, self.functor.atomic_compare_exchange)
result: int = self.functor.view1D_ace[0]

self.assertEqual(expected_result, result)


if __name__ == "__main__":
unittest.main()
Loading