Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
0dae1c5
add inclusive scan translation along with Kokkos `begin()` and `end()`
IvanGrigorik Dec 15, 2025
8704ef5
add inclusive scan example inside of the kernel
IvanGrigorik Dec 15, 2025
99ac006
add init file
IvanGrigorik Dec 15, 2025
dad8aae
add lower and upper bound functions, using binary search
IvanGrigorik Dec 15, 2025
e4acc6f
remove `inclusive_scan` from other branch
IvanGrigorik Dec 15, 2025
d666913
fix mypy issues
IvanGrigorik Dec 15, 2025
5b32aba
rename `test` to `example`
IvanGrigorik Dec 15, 2025
c389845
remove unnecessary comments and annotations
IvanGrigorik Dec 18, 2025
a908a0f
update examples, remove annotations and comments
IvanGrigorik Dec 18, 2025
08177c7
add support of multiple ranges
IvanGrigorik Dec 18, 2025
92ad949
Merge branch 'grigorik/upper_bound' of https://github.com/kokkos/pyko…
IvanGrigorik Dec 18, 2025
7c8f17a
add ignore flags to runtests
IvanGrigorik Dec 18, 2025
5fe1bb6
Revert "add support of multiple ranges"
IvanGrigorik Dec 19, 2025
866a0a5
Revert "add ignore flags to runtests"
IvanGrigorik Dec 19, 2025
a122cd3
trying to figure out the issue with CIs
IvanGrigorik Dec 19, 2025
312d86e
Revert "Revert "add ignore flags to runtests""
IvanGrigorik Dec 19, 2025
c2a9225
Revert "Revert "add support of multiple ranges""
IvanGrigorik Dec 19, 2025
e50875c
Merge branch 'main' into grigorik/upper_bound
IvanGrigorik Dec 23, 2025
297a4d9
Merge branch 'main' into grigorik/upper_bound
IvanGrigorik Dec 25, 2025
e142d10
try to fix test CI
IvanGrigorik Dec 25, 2025
c88da9b
add class method annotation
IvanGrigorik Dec 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions examples/kokkos/lower_bound_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import pykokkos as pk


@pk.workunit
def init_data(i, view):
    # Fill the view with the ascending sequence 1, 2, 3, ... (element i holds
    # i + 1), so the data searched by the bound kernels is already sorted.
    view[i] = i + 1


# Test lower_bound with scratch memory
@pk.workunit
def team_lower_bound(team_member, view, result_view):
    # Each team stages its own contiguous slice of `view` in a team scratch
    # view; every thread of the team then searches that slice.
    team_size: int = team_member.team_size()
    team_rank: int = team_member.team_rank()
    offset: int = team_member.league_rank() * team_size
    globalIdx: int = offset + team_rank

    scratch: pk.ScratchView1D[int] = pk.ScratchView1D(
        team_member.team_scratch(0), team_size
    )

    scratch[team_rank] = view[globalIdx]
    # Every thread must have written its element before any thread searches
    # the scratch view.
    team_member.team_barrier()

    # Each thread looks for a different value: twice its rank in the team.
    search_value: int = team_rank * 2
    # Index of the first scratch element that is >= search_value (standard
    # lower-bound semantics, matching the expected values checked in main).
    bound_idx: int = pk.lower_bound(scratch, team_size, search_value)
    result_view[globalIdx] = bound_idx


# Test lower_bound with regular view
# Find lower bound for value i in the first 10 elements
@pk.workunit
def lower_bound_view(i, view, result_view):
    # The work index doubles as the value being searched for.
    search_value: int = i
    # Only the first 10 elements of `view` are searched (size argument is 10),
    # so the stored index is never larger than 10.
    bound_idx: int = pk.lower_bound(view, 10, search_value)
    result_view[i] = bound_idx


def _assert_matches(result_view, expected):
    """Check ``result_view`` element by element against ``expected``.

    Args:
        result_view: indexable container holding the kernel output.
        expected: sequence of expected values, one per index to check.

    Raises:
        AssertionError: at the first index whose value differs.
    """
    for i, want in enumerate(expected):
        got = result_view[i]
        assert got == want, f"Mismatch at index {i}: got {got}, expected {want}"


def main():
    """Run both lower_bound kernels and verify their output.

    The input view is filled with 1..N by ``init_data``. ``team_lower_bound``
    is then launched with a TeamPolicy and ``lower_bound_view`` with a
    RangePolicy; each result is compared against a precomputed table.

    Raises:
        AssertionError: if a kernel result differs from its expected table.
    """
    N = 64
    team_size = 32
    num_teams = (N + team_size - 1) // team_size

    view: pk.View1D[int] = pk.View([N], int)
    result_view: pk.View1D[int] = pk.View([N], int)

    # Expected output of team_lower_bound: team 0 searches the values 1..32,
    # team 1 searches 33..64, and thread r of each team looks for 2 * r.
    expected_scratch = [
        0, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29,
    ]

    # Expected output of lower_bound_view: work item i searches for the value
    # i in the first 10 elements (1..10), so the result saturates at 10.
    expected_view = [
        0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10,
    ]

    # The tables are written out for N elements; catch a mismatch early if N
    # is ever changed without regenerating them.
    assert len(expected_scratch) == N, "expected_scratch must have N entries"
    assert len(expected_view) == N, "expected_view must have N entries"

    p_init = pk.RangePolicy(pk.ExecutionSpace.OpenMP, 0, N)
    pk.parallel_for(p_init, init_data, view=view)

    print(f"Total elements: {N}, Team size: {team_size}, Number of teams: {num_teams}")
    print(f"Initial view: {view}")

    # Test with TeamPolicy (scratch memory)
    # NOTE(review): no per-team scratch size is requested on the policy even
    # though the kernel allocates a ScratchView1D of team_size ints; confirm
    # that PyKokkos sizes team scratch automatically.
    team_policy = pk.TeamPolicy(pk.ExecutionSpace.OpenMP, num_teams, team_size)

    pk.parallel_for(team_policy, team_lower_bound, view=view, result_view=result_view)
    print(f"Result (scratch lower_bound): {result_view}")

    # Assert scratch lower_bound results
    _assert_matches(result_view, expected_scratch)

    # Test with RangePolicy (regular view); result_view is reused and fully
    # overwritten because the policy covers every index 0..N-1.
    pk.parallel_for(p_init, lower_bound_view, view=view, result_view=result_view)
    print(f"Result (view lower_bound): {result_view}")

    # Assert view lower_bound results
    _assert_matches(result_view, expected_view)


if __name__ == "__main__":
    main()
100 changes: 100 additions & 0 deletions examples/kokkos/upper_bound_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import pykokkos as pk


@pk.workunit
def init_data(i, view):
    # Fill the view with the ascending sequence 1, 2, 3, ... (element i holds
    # i + 1), so the data searched by the bound kernels is already sorted.
    view[i] = i + 1


# Test upper_bound with scratch memory
@pk.workunit
def team_upper_bound(team_member, view, result_view):
    # Each team stages its own contiguous slice of `view` in a team scratch
    # view; every thread of the team then searches that slice.
    team_size: int = team_member.team_size()
    team_rank: int = team_member.team_rank()
    offset: int = team_member.league_rank() * team_size
    globalIdx: int = offset + team_rank

    scratch: pk.ScratchView1D[int] = pk.ScratchView1D(
        team_member.team_scratch(0), team_size
    )

    scratch[team_rank] = view[globalIdx]
    # Every thread must have written its element before any thread searches
    # the scratch view.
    team_member.team_barrier()

    # Each thread looks for a different value: twice its rank in the team.
    search_value: int = team_rank * 2
    # Index of the first scratch element that is > search_value (standard
    # upper-bound semantics, matching the expected values checked in main).
    bound_idx: int = pk.upper_bound(scratch, team_size, search_value)
    result_view[globalIdx] = bound_idx


# Test upper_bound with regular view
# Find upper bound for value i in the first 10 elements
@pk.workunit
def upper_bound_view(i, view, result_view):
    # The work index doubles as the value being searched for.
    search_value: int = i
    # Only the first 10 elements of `view` are searched (size argument is 10),
    # so the stored index is never larger than 10.
    bound_idx: int = pk.upper_bound(view, 10, search_value)
    result_view[i] = bound_idx


def _assert_matches(result_view, expected):
    """Check ``result_view`` element by element against ``expected``.

    Args:
        result_view: indexable container holding the kernel output.
        expected: sequence of expected values, one per index to check.

    Raises:
        AssertionError: at the first index whose value differs.
    """
    for i, want in enumerate(expected):
        got = result_view[i]
        assert got == want, f"Mismatch at index {i}: got {got}, expected {want}"


def main():
    """Run both upper_bound kernels and verify their output.

    The input view is filled with 1..N by ``init_data``. ``team_upper_bound``
    is then launched with a TeamPolicy and ``upper_bound_view`` with a
    RangePolicy; each result is compared against a precomputed table.

    Raises:
        AssertionError: if a kernel result differs from its expected table.
    """
    N = 64
    team_size = 32
    num_teams = (N + team_size - 1) // team_size

    view: pk.View1D[int] = pk.View([N], int)
    result_view: pk.View1D[int] = pk.View([N], int)

    # Expected output of team_upper_bound: team 0 searches the values 1..32,
    # team 1 searches 33..64, and thread r of each team looks for 2 * r.
    expected_scratch = [
        0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
    ]

    # Expected output of upper_bound_view: work item i searches for the value
    # i in the first 10 elements (1..10), so the result saturates at 10.
    expected_view = [
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10,
    ]

    # The tables are written out for N elements; catch a mismatch early if N
    # is ever changed without regenerating them.
    assert len(expected_scratch) == N, "expected_scratch must have N entries"
    assert len(expected_view) == N, "expected_view must have N entries"

    # NOTE(review): this example hard-codes the Cuda execution space, while
    # lower_bound_example.py uses OpenMP; confirm that this is intentional.
    p_init = pk.RangePolicy(pk.ExecutionSpace.Cuda, 0, N)
    pk.parallel_for(p_init, init_data, view=view)

    print(f"Total elements: {N}, Team size: {team_size}, Number of teams: {num_teams}")
    print(f"Initial view: {view}")

    # Test with TeamPolicy (scratch memory)
    # NOTE(review): no per-team scratch size is requested on the policy even
    # though the kernel allocates a ScratchView1D of team_size ints; confirm
    # that PyKokkos sizes team scratch automatically.
    team_policy = pk.TeamPolicy(pk.ExecutionSpace.Cuda, num_teams, team_size)

    pk.parallel_for(team_policy, team_upper_bound, view=view, result_view=result_view)
    print(f"Result (scratch upper_bound): {result_view}")

    # Assert scratch upper_bound results
    _assert_matches(result_view, expected_scratch)
    print("Scratch upper_bound test passed")

    # Test with RangePolicy (regular view); result_view is reused and fully
    # overwritten because the policy covers every index 0..N-1.
    pk.parallel_for(p_init, upper_bound_view, view=view, result_view=result_view)
    print(f"Result (view upper_bound): {result_view}")

    # Assert view upper_bound results
    _assert_matches(result_view, expected_view)
    print("View upper_bound test passed")


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions pykokkos/core/translators/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def generate_includes(self) -> str:
"Kokkos_Core.hpp",
"Kokkos_Random.hpp",
"Kokkos_Sort.hpp",
"Kokkos_StdAlgorithms.hpp",
"fstream",
"iostream",
"cmath",
Expand All @@ -290,6 +291,7 @@ def generate_cast_includes(self) -> str:
"Kokkos_Core.hpp",
"Kokkos_Random.hpp",
"Kokkos_Sort.hpp",
"Kokkos_StdAlgorithms.hpp",
"fstream",
"iostream",
"cmath",
Expand Down
2 changes: 1 addition & 1 deletion pykokkos/core/translators/symbols_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, members: PyKokkosMembers, pk_import: str, path: str):
self.global_symbols.update(math_functions)
self.global_symbols.update(allowed_types)
self.global_symbols.update(view_dtypes)
self.global_symbols.update(["self", "range", "math", "List", "abs"])
self.global_symbols.update(["self", "range", "math", "List", "abs", "upper_bound", "lower_bound"])
self.global_symbols.add(pk_import)

self.global_symbols.update([field.declname for field in members.fields])
Expand Down
70 changes: 70 additions & 0 deletions pykokkos/core/visitors/workunit_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,76 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:

return real_number_call

# Custom `upper_bound` implementation using binary search
if name == "upper_bound":
# Check if it's called via pk.upper_bound
is_pk_call = (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == self.pk_import
)

if not is_pk_call:
return super().visit_Call(node)

# Expected signature: pk.upper_bound(view, size, value)
if len(args) != 3:
self.error(
node,
"pk.upper_bound() takes 3 arguments: view, size, value",
)

view_expr = args[0]
size_expr = args[1]
value_expr = args[2]

# Generate binary search lambda inline
from pykokkos.interface.algorithms.upper_bound import generate_upper_bound_binary_search

# Create lambda body with binary search
lambda_body = generate_upper_bound_binary_search(view_expr, size_expr, value_expr)

# Create and invoke lambda
lambda_expr = cppast.LambdaExpr("[&]", [], lambda_body)
lambda_call = cppast.CallExpr(lambda_expr, [])

return lambda_call

# Custom `lower_bound` implementation using binary search
if name == "lower_bound":
# Check if it's called via pk.lower_bound
is_pk_call = (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == self.pk_import
)

if not is_pk_call:
return super().visit_Call(node)

# Expected signature: pk.lower_bound(view, size, value)
if len(args) != 3:
self.error(
node,
"pk.lower_bound() takes 3 arguments: view, size, value",
)

view_expr = args[0]
size_expr = args[1]
value_expr = args[2]

# Generate binary search lambda inline
from pykokkos.interface.algorithms.lower_bound import generate_lower_bound_binary_search

# Create lambda body with binary search
lambda_body = generate_lower_bound_binary_search(view_expr, size_expr, value_expr)

# Create and invoke lambda
lambda_expr = cppast.LambdaExpr("[&]", [], lambda_body)
lambda_call = cppast.CallExpr(lambda_expr, [])

return lambda_call

return super().visit_Call(node)

def is_nested_call(self, node: ast.FunctionDef) -> bool:
Expand Down
1 change: 1 addition & 0 deletions pykokkos/interface/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .accumulator import Acc
from .algorithms import lower_bound, upper_bound
from .atomic.atomic_fetch_op import (
atomic_fetch_add, atomic_fetch_and, atomic_fetch_div,
atomic_fetch_lshift, atomic_fetch_max, atomic_fetch_min,
Expand Down
4 changes: 4 additions & 0 deletions pykokkos/interface/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .lower_bound import lower_bound
from .upper_bound import upper_bound

__all__ = ["lower_bound", "upper_bound"]
Loading
Loading