Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
0dae1c5
add inclusive scan translation along with Kokkos `begin()` and `end()`
IvanGrigorik Dec 15, 2025
8704ef5
add inclusive scan example inside of the kernel
IvanGrigorik Dec 15, 2025
99ac006
add init file
IvanGrigorik Dec 15, 2025
dad8aae
add lower and upper bound functions, using binary search
IvanGrigorik Dec 15, 2025
e4acc6f
remove `inclusive_scan` from other branch
IvanGrigorik Dec 15, 2025
d666913
fix mypy issues
IvanGrigorik Dec 15, 2025
5b32aba
rename `test` to `example`
IvanGrigorik Dec 15, 2025
c389845
remove unnecessary comments and annotations
IvanGrigorik Dec 18, 2025
a908a0f
update examples, remove annotations and comments
IvanGrigorik Dec 18, 2025
08177c7
add support of multiple ranges
IvanGrigorik Dec 18, 2025
92ad949
Merge branch 'grigorik/upper_bound' of https://github.com/kokkos/pyko…
IvanGrigorik Dec 18, 2025
7c8f17a
add ignore flags to runtests
IvanGrigorik Dec 18, 2025
5fe1bb6
Revert "add support of multiple ranges"
IvanGrigorik Dec 19, 2025
866a0a5
Revert "add ignore flags to runtests"
IvanGrigorik Dec 19, 2025
a122cd3
trying to figure out the issue with CIs
IvanGrigorik Dec 19, 2025
312d86e
Revert "Revert "add ignore flags to runtests""
IvanGrigorik Dec 19, 2025
c2a9225
Revert "Revert "add support of multiple ranges""
IvanGrigorik Dec 19, 2025
e50875c
Merge branch 'main' into grigorik/upper_bound
IvanGrigorik Dec 23, 2025
297a4d9
Merge branch 'main' into grigorik/upper_bound
IvanGrigorik Dec 25, 2025
e142d10
try to fix test CI
IvanGrigorik Dec 25, 2025
c88da9b
add class method annotation
IvanGrigorik Dec 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions examples/kokkos/lower_bound_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pykokkos as pk


@pk.workunit
def init_data(i: int, view: pk.View1D[int]):
view[i] = i + 1


# Test lower_bound with scratch memory
@pk.workunit
def team_lower_bound(team_member: pk.TeamMember, view: pk.View1D[int], result_view: pk.View1D[int]):
Comment thread
IvanGrigorik marked this conversation as resolved.
Outdated
team_size: int = team_member.team_size()
offset: int = team_member.league_rank() * team_size
localIdx: int = team_member.team_rank()
globalIdx: int = offset + localIdx
team_rank: int = team_member.team_rank()

# Allocate scratch memory for sorted data
scratch: pk.ScratchView1D[int] = pk.ScratchView1D(
team_member.team_scratch(0), team_size
)

# Copy data to scratch and make it sorted within the team
scratch[team_rank] = view[globalIdx]
team_member.team_barrier()

# Now use lower_bound to find position in scratch
# For example, find lower bound for the value at team_rank position
search_value: int = team_rank * 2 # Search for a value
Comment thread
IvanGrigorik marked this conversation as resolved.

# Find lower bound in scratch memory
bound_idx: int = pk.lower_bound(scratch, team_size, search_value)

# Store result
result_view[globalIdx] = bound_idx


# Test lower_bound with regular view
@pk.workunit
def lower_bound_view(i: int, view: pk.View1D[int], result_view: pk.View1D[int]):
Comment thread
IvanGrigorik marked this conversation as resolved.
Outdated
# Find lower bound for value i in the first 10 elements
search_value: int = i
bound_idx: int = pk.lower_bound(view, 10, search_value)
result_view[i] = bound_idx


def main():
N = 64
team_size = 32
num_teams = (N + team_size - 1) // team_size

view: pk.View1D[int] = pk.View([N], int)
result_view: pk.View1D[int] = pk.View([N], int)

p_init = pk.RangePolicy(pk.ExecutionSpace.OpenMP, 0, N)
pk.parallel_for(p_init, init_data, view=view)

print(f"Total elements: {N}, Team size: {team_size}, Number of teams: {num_teams}")
print(f"Initial view: {view}")

# Test with TeamPolicy (scratch memory)
team_policy = pk.TeamPolicy(pk.ExecutionSpace.OpenMP, num_teams, team_size)

print("\nRunning lower_bound with scratch memory...")
Comment thread
IvanGrigorik marked this conversation as resolved.
Outdated
pk.parallel_for(team_policy, team_lower_bound, view=view, result_view=result_view)
print(f"Result (scratch lower_bound): {result_view}")

# Test with RangePolicy (regular view)
print("\nRunning lower_bound with regular view...")
pk.parallel_for(p_init, lower_bound_view, view=view, result_view=result_view)
print(f"Result (view lower_bound): {result_view}")
Comment thread
IvanGrigorik marked this conversation as resolved.


if __name__ == "__main__":
main()
75 changes: 75 additions & 0 deletions examples/kokkos/upper_bound_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pykokkos as pk


@pk.workunit
def init_data(i: int, view: pk.View1D[int]):
view[i] = i + 1


# Test upper_bound with scratch memory
@pk.workunit
def team_upper_bound(team_member: pk.TeamMember, view: pk.View1D[int], result_view: pk.View1D[int]):
Comment thread
IvanGrigorik marked this conversation as resolved.
Outdated
team_size: int = team_member.team_size()
offset: int = team_member.league_rank() * team_size
localIdx: int = team_member.team_rank()
globalIdx: int = offset + localIdx
team_rank: int = team_member.team_rank()

# Allocate scratch memory for sorted data
scratch: pk.ScratchView1D[int] = pk.ScratchView1D(
team_member.team_scratch(0), team_size
)

# Copy data to scratch and make it sorted within the team
Comment thread
IvanGrigorik marked this conversation as resolved.
Outdated
scratch[team_rank] = view[globalIdx]
team_member.team_barrier()

# Now use upper_bound to find position in scratch
# For example, find upper bound for the value at team_rank position
search_value: int = team_rank * 2 # Search for a value
Comment thread
IvanGrigorik marked this conversation as resolved.

# Find upper bound in scratch memory
Comment thread
IvanGrigorik marked this conversation as resolved.
Outdated
bound_idx: int = pk.upper_bound(scratch, team_size, search_value)

# Store result
Comment thread
IvanGrigorik marked this conversation as resolved.
Outdated
result_view[globalIdx] = bound_idx


# Test upper_bound with regular view
@pk.workunit
def upper_bound_view(i: int, view: pk.View1D[int], result_view: pk.View1D[int]):
Comment thread
IvanGrigorik marked this conversation as resolved.
Outdated
# Find upper bound for value i in the first 10 elements
Comment thread
IvanGrigorik marked this conversation as resolved.
Outdated
search_value: int = i
bound_idx: int = pk.upper_bound(view, 10, search_value)
result_view[i] = bound_idx


def main():
N = 64
team_size = 32
num_teams = (N + team_size - 1) // team_size

view: pk.View1D[int] = pk.View([N], int)
result_view: pk.View1D[int] = pk.View([N], int)

p_init = pk.RangePolicy(pk.ExecutionSpace.Cuda, 0, N)
pk.parallel_for(p_init, init_data, view=view)

print(f"Total elements: {N}, Team size: {team_size}, Number of teams: {num_teams}")
print(f"Initial view: {view}")

# Test with TeamPolicy (scratch memory)
Comment thread
IvanGrigorik marked this conversation as resolved.
team_policy = pk.TeamPolicy(pk.ExecutionSpace.Cuda, num_teams, team_size)

print("\nRunning upper_bound with scratch memory...")
pk.parallel_for(team_policy, team_upper_bound, view=view, result_view=result_view)
print(f"Result (scratch upper_bound): {result_view}")
Comment thread
IvanGrigorik marked this conversation as resolved.

# Test with RangePolicy (regular view)
Comment thread
IvanGrigorik marked this conversation as resolved.
print("\nRunning upper_bound with regular view...")
pk.parallel_for(p_init, upper_bound_view, view=view, result_view=result_view)
print(f"Result (view upper_bound): {result_view}")
Comment thread
IvanGrigorik marked this conversation as resolved.


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions pykokkos/core/translators/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def generate_includes(self) -> str:
"Kokkos_Core.hpp",
"Kokkos_Random.hpp",
"Kokkos_Sort.hpp",
"Kokkos_StdAlgorithms.hpp",
"fstream",
"iostream",
"cmath",
Expand All @@ -290,6 +291,7 @@ def generate_cast_includes(self) -> str:
"Kokkos_Core.hpp",
"Kokkos_Random.hpp",
"Kokkos_Sort.hpp",
"Kokkos_StdAlgorithms.hpp",
"fstream",
"iostream",
"cmath",
Expand Down
2 changes: 1 addition & 1 deletion pykokkos/core/translators/symbols_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, members: PyKokkosMembers, pk_import: str, path: str):
self.global_symbols.update(math_functions)
self.global_symbols.update(allowed_types)
self.global_symbols.update(view_dtypes)
self.global_symbols.update(["self", "range", "math", "List", "abs"])
self.global_symbols.update(["self", "range", "math", "List", "abs", "upper_bound", "lower_bound"])
self.global_symbols.add(pk_import)

self.global_symbols.update([field.declname for field in members.fields])
Expand Down
70 changes: 70 additions & 0 deletions pykokkos/core/visitors/workunit_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,76 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:

return real_number_call

# Custom `upper_bound` implementation using binary search
if name == "upper_bound":
# Check if it's called via pk.upper_bound
is_pk_call = (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == self.pk_import
)

if not is_pk_call:
return super().visit_Call(node)

# Expected signature: pk.upper_bound(view, size, value)
if len(args) != 3:
self.error(
node,
"pk.upper_bound() takes 3 arguments: view, size, value",
)

view_expr = args[0]
size_expr = args[1]
value_expr = args[2]

# Generate binary search lambda inline
from pykokkos.interface.algorithms.upper_bound import generate_upper_bound_binary_search

# Create lambda body with binary search
lambda_body = generate_upper_bound_binary_search(view_expr, size_expr, value_expr)

# Create and invoke lambda
lambda_expr = cppast.LambdaExpr("[&]", [], lambda_body)
lambda_call = cppast.CallExpr(lambda_expr, [])

return lambda_call

# Custom `lower_bound` implementation using binary search
if name == "lower_bound":
# Check if it's called via pk.lower_bound
is_pk_call = (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == self.pk_import
)

if not is_pk_call:
return super().visit_Call(node)

# Expected signature: pk.lower_bound(view, size, value)
if len(args) != 3:
self.error(
node,
"pk.lower_bound() takes 3 arguments: view, size, value",
)

view_expr = args[0]
size_expr = args[1]
value_expr = args[2]

# Generate binary search lambda inline
from pykokkos.interface.algorithms.lower_bound import generate_lower_bound_binary_search

# Create lambda body with binary search
lambda_body = generate_lower_bound_binary_search(view_expr, size_expr, value_expr)

# Create and invoke lambda
lambda_expr = cppast.LambdaExpr("[&]", [], lambda_body)
lambda_call = cppast.CallExpr(lambda_expr, [])

return lambda_call

return super().visit_Call(node)

def is_nested_call(self, node: ast.FunctionDef) -> bool:
Expand Down
1 change: 1 addition & 0 deletions pykokkos/interface/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .accumulator import Acc
from .algorithms import lower_bound, upper_bound
from .atomic.atomic_fetch_op import (
atomic_fetch_add, atomic_fetch_and, atomic_fetch_div,
atomic_fetch_lshift, atomic_fetch_max, atomic_fetch_min,
Expand Down
4 changes: 4 additions & 0 deletions pykokkos/interface/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .lower_bound import lower_bound
from .upper_bound import upper_bound

__all__ = ["lower_bound", "upper_bound"]
91 changes: 91 additions & 0 deletions pykokkos/interface/algorithms/lower_bound.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from pykokkos.interface.views import ViewType
from pykokkos.core import cppast

def lower_bound(view: ViewType, size: int, value) -> int:
"""
Perform a lower bound search on a view

Returns the index of the first element not less than (i.e. greater or equal to) value,
similar to std::lower_bound or thrust::lower_bound.

:param view: the view to search (must be sorted)
:param size: the number of elements to search
:param value: the value to search for
:returns: the index of the first element >= value
"""
pass


def generate_lower_bound_binary_search(
view_expr: cppast.Expr, size_expr: cppast.Expr, value_expr: cppast.Expr
) -> cppast.CompoundStmt:
"""
Generate binary search implementation for lower_bound.
Returns a CompoundStmt that implements:

int left = 0;
int right = size;
int mid;
while (left < right) {
mid = left + (right - left) / 2;
if (view[mid] < value) {
left = mid + 1;
} else {
right = mid;
}
}
return left;
"""

# Variable declarations
int_type = cppast.PrimitiveType("int32_t")

# int left = 0;
left_var = cppast.DeclRefExpr("left")
left_init = cppast.IntegerLiteral(0)
left_decl = cppast.VarDecl(int_type, left_var, left_init)
left_stmt = cppast.DeclStmt(left_decl)

# int right = size;
right_var = cppast.DeclRefExpr("right")
right_decl = cppast.VarDecl(int_type, right_var, size_expr)
right_stmt = cppast.DeclStmt(right_decl)

# int mid;
mid_var = cppast.DeclRefExpr("mid")
mid_decl = cppast.VarDecl(int_type, mid_var, None)
mid_stmt = cppast.DeclStmt(mid_decl)

# while (left < right)
while_cond = cppast.BinaryOperator(left_var, right_var, cppast.BinaryOperatorKind.LT)

# mid = left + (right - left) / 2;
right_minus_left = cppast.BinaryOperator(right_var, left_var, cppast.BinaryOperatorKind.Sub)
div_expr = cppast.BinaryOperator(right_minus_left, cppast.IntegerLiteral(2), cppast.BinaryOperatorKind.Div)
mid_calc = cppast.BinaryOperator(left_var, div_expr, cppast.BinaryOperatorKind.Add)
mid_assign = cppast.AssignOperator([mid_var], mid_calc, cppast.BinaryOperatorKind.Assign)

# if (view[mid] < value)
view_ref = view_expr if isinstance(view_expr, cppast.DeclRefExpr) else cppast.DeclRefExpr("view")
view_access = cppast.ArraySubscriptExpr(view_ref, [mid_var])
if_cond = cppast.BinaryOperator(view_access, value_expr, cppast.BinaryOperatorKind.LT)

# left = mid + 1;
mid_plus_one = cppast.BinaryOperator(mid_var, cppast.IntegerLiteral(1), cppast.BinaryOperatorKind.Add)
left_assign = cppast.AssignOperator([left_var], mid_plus_one, cppast.BinaryOperatorKind.Assign)

# right = mid;
right_assign = cppast.AssignOperator([right_var], mid_var, cppast.BinaryOperatorKind.Assign)

# if-else statement
if_stmt = cppast.IfStmt(if_cond, left_assign, right_assign)

# while body
while_body = cppast.CompoundStmt([mid_assign, if_stmt])
while_stmt = cppast.WhileStmt(while_cond, while_body)

# return left;
return_stmt = cppast.ReturnStmt(left_var)

# Complete function body
return cppast.CompoundStmt([left_stmt, right_stmt, mid_stmt, while_stmt, return_stmt])
Loading