diff --git a/CHANGELOG.md b/CHANGELOG.md index e75a0215a8..ad08f24119 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,8 @@ and can be scoped temporarily with `wp.ScopedLogLevel`; all Python-side diagnostic output now routes through the logger ([GH-1315](https://github.com/NVIDIA/warp/issues/1315), [GH-1434](https://github.com/NVIDIA/warp/issues/1434)). +- Support passing user-defined Warp functions to `Callable` parameters in `@wp.func` functions + ([GH-1424](https://github.com/NVIDIA/warp/issues/1424)). ### Removed diff --git a/design/callable-func-parameters.md b/design/callable-func-parameters.md new file mode 100644 index 0000000000..a5e8aac421 --- /dev/null +++ b/design/callable-func-parameters.md @@ -0,0 +1,144 @@ +# Callable Function Parameters + +**Status**: Implemented + +**Issue**: [GH-1424](https://github.com/NVIDIA/warp/issues/1424) + +## Motivation + +Warp users can write higher-order Python helpers that accept a callable and +apply it to values, but the same pattern did not work inside user-defined +`@wp.func` code. A function parameter annotated as `Callable` was not matched +consistently during Warp overload resolution, especially across the two standard +library import paths: + +```python +from typing import Callable as TypingCallable +from collections.abc import Callable as AbcCallable +``` + +The aliases are illustrative; Warp treats both origins as the same type-erased +`Callable` marker. This design implements the user-defined function portion of +GH-1424 as a first pass: support direct inline calls such as +`apply(double_it, 3.0)` from a kernel or another `@wp.func`, without depending +on local function-variable assignment behavior. + +## Requirements + +| ID | Requirement | Priority | Notes | +| --- | ----------- | -------- | ----- | +| R1 | Recognize `typing.Callable` and `collections.abc.Callable` as callable annotations. | Must | Includes bare and parameterized forms. 
| +| R2 | Allow user-defined `@wp.func` objects to match `Callable` parameters. | Must | Callable values are type-erased. | +| R3 | Preserve the existing `"c"` type code for callable annotations. | Must | Used by module hashing and overload keys. | +| R4 | Keep annotation recognition compatible with Python 3.10 through 3.14. | Must | Avoid private `typing` internals. | +| R5 | Include callable argument and default targets in module hashes and dependencies. | Must | Prevent stale compiled modules when callable targets change. | +| R6 | Reject unsupported callable specializations explicitly. | Must | Built-ins are deferred; custom grad/replay functions are first-pass non-goals. | + +**Non-goals**: + +- Validate parameterized callable signatures such as + `Callable[[float], float]`. +- Implement runtime dispatch for arbitrary Python callables. +- Support kernel-local function variable assignment. +- Support built-in Warp functions such as `wp.sin` or `wp.add` as callable + arguments in this first pass. They are rejected until built-in callable + identity can participate safely in specialization hashing and dependency + tracking. +- Support callable-specialized functions that have custom gradient or replay + functions. + +Parameterized callable annotations are recognized as callable markers, but +their argument and return types remain unchecked. + +## Design + +### Approach + +The implementation treats `Callable` parameters as compile-time function +references. A callable parameter never becomes a runtime C++ function pointer or +kernel parameter. Instead, code generation specializes the user-defined function +for the concrete user-defined Warp function passed at each call site. + +This means: + +- `apply(double_it, x)` and `apply(triple_it, x)` produce separate specialized + native functions. +- Callable parameters are bound into the specialized function's codegen symbol + table. 
+- Callable parameters are omitted from the emitted forward and reverse C++ + function signatures. +- Runtime calls pass only the non-callable arguments. + +The callable annotation predicate lives in `warp/_src/types.py` as +`is_callable_annotation(annotation)`. It returns true for: + +- bare `typing.Callable`, +- bare `collections.abc.Callable`, +- `typing.Callable[...]`, +- `collections.abc.Callable[...]`. + +The helper uses `typing.get_origin()` plus the canonical +`collections.abc.Callable` object, which covers the supported Python versions +without relying on private implementation details. + +### Alternatives Considered + +One option was to normalize annotations globally during `@wp.func` registration. +That would make all `Callable` forms identical up front, but it would also +change the raw annotations stored on every function and could affect unrelated +signature handling. A small predicate keeps the behavior localized. + +Another option was to support built-in Warp functions as callable values. +GH-1424 includes built-ins, but this first pass narrows support to user-defined +targets. Built-ins require hashing and dependency behavior for callable +identity, so they are rejected explicitly to avoid partial support that could +reuse stale cached modules. + +### Key Implementation Details + +`get_type_code()` returns `"c"` for callable annotations before the generic +type branches. This keeps module hash type-code behavior stable for bare and +parameterized `Callable` forms. + +`func_match_args()` treats a `warp._src.context.Function` value as compatible +with any callable annotation. Non-function values continue to fail normal +overload resolution. + +During `Adjoint.add_call()`, default arguments are applied first. Callable +arguments and callable defaults are then collected. Built-in function values are +rejected, and user-defined function values trigger creation of a specialized +`Function` clone. 
The clone receives: + +- a hash-suffixed native function name, +- a fresh `Adjoint` with the original annotations, including the return + annotation, +- `callable_arg_values` used to bind callable parameter names during codegen. + +Module hashing and dependency discovery both inspect callable arguments and +callable defaults. User-defined callable targets are included in the referenced +function set for hashes and in module references/dependents for invalidation. + +Callable-specialized functions with custom gradients or custom replay functions +are rejected. Their custom functions are tied to the unspecialized native +function signature, while callable specialization removes callable runtime +parameters from the emitted signature. + +## Testing Strategy + +`warp/tests/test_func_callable.py` covers: + +- bare `typing.Callable` and `collections.abc.Callable` runtime calls, +- parameterized `Callable[[float], float]` runtime calls, +- generic user functions that combine `Callable` with `Any`, +- default callable arguments, +- keyword callable arguments, +- nested user-defined function calls, +- callable targets affecting module hashes, +- callable targets updating cross-module dependents, +- return annotation preservation on specialized functions, +- explicit rejection for built-in callable targets, +- explicit rejection for callable-specialized functions with custom grad or + replay functions. + +Local verification should include the focused `TestFuncCallable` suite and pre-commit +over the changed files. CI provides the full Python 3.10 through 3.14 matrix. diff --git a/docs/user_guide/basics.rst b/docs/user_guide/basics.rst index 23f9c28445..d46907e4a1 100644 --- a/docs/user_guide/basics.rst +++ b/docs/user_guide/basics.rst @@ -291,6 +291,58 @@ User functions may also be overloaded by defining multiple function signatures w def custom(x: wp.vec3): return x + wp.vec3(1.0, 0.0, 0.0) +.. 
_callable-parameters: + +Callable Parameters +^^^^^^^^^^^^^^^^^^^ + +User functions can accept another user-defined Warp function by annotating the +parameter as ``Callable`` from ``collections.abc`` or ``typing``. +The callable target is chosen where the user function is called and can be +invoked directly inside the user function body: + +.. code-block:: python + + from collections.abc import Callable + + import warp as wp + + @wp.func + def square(x: float): + return x * x + + + @wp.func + def cube(x: float): + return x * x * x + + + @wp.func + def apply(f: Callable, x: float): + return f(x) + + + @wp.kernel + def apply_kernel( + values: wp.array[float], + square_out: wp.array[float], + cube_out: wp.array[float], + ): + i = wp.tid() + square_out[i] = apply(square, values[i]) + cube_out[i] = apply(cube, values[i]) + +Callable parameters may also use defaults and keyword arguments: + +.. code-block:: python + + @wp.func + def apply_default(f: Callable = square, x: float = 0.0): + return f(x) + +Pass only user-defined :func:`@wp.func ` functions as callable targets. +See :doc:`limitations` for unsupported callable targets and other restrictions. + Tiles may also be passed to user functions. The function signature tile argument should include dtype and shape parameters to match the tile type intended to be used in the function. For example: diff --git a/docs/user_guide/limitations.rst b/docs/user_guide/limitations.rst index 0f68d4c1cd..e4c6814667 100644 --- a/docs/user_guide/limitations.rst +++ b/docs/user_guide/limitations.rst @@ -35,6 +35,13 @@ Kernels and User Functions (e.g., ``wp.float64(wp.PI)`` or ``wp.int64(large_value)``). * Python ``IntFlag`` values behave like raw integers in Warp kernels: bitwise negation (``~``) produces the integer negation, not a masked combination of flags as in standard Python ``IntFlag`` behavior. +* :ref:`Callable parameters ` in user functions only support direct inline calls with + user-defined :func:`@wp.func ` targets. 
+ Arbitrary Python callables and built-in Warp functions such as ``wp.sin`` or ``wp.add`` are not supported. + Assigning callable targets to local variables inside kernels or user functions is not supported. + User functions with ``Callable`` parameters also cannot define custom gradient or replay functions. + Type parameters in annotations such as ``Callable[[float], float]`` are accepted but are not validated against + the target function signature. A limitation of Warp is that each dimension of the grid used to launch a kernel must be representable as a 32-bit signed integer. Therefore, no single dimension of a grid should exceed :math:`2^{31}-1`. diff --git a/warp/_src/codegen.py b/warp/_src/codegen.py index 2164a6cc74..398561451e 100644 --- a/warp/_src/codegen.py +++ b/warp/_src/codegen.py @@ -18,6 +18,7 @@ import threading import types from collections.abc import Callable, Mapping, Sequence +from copy import copy as shallowcopy from typing import Any, ClassVar, get_args, get_origin import warp.config @@ -881,8 +882,11 @@ def func_match_args(func, arg_types, kwarg_types): if func_arg_type is Any: continue - # handle function refs as a special case - if func_arg_type is Callable and isinstance(bound_arg_type, warp._src.context.Function): + # Callable parameters are type-erased during overload matching; the + # concrete target is bound later during call specialization. + if warp._src.types.is_callable_annotation(func_arg_type) and isinstance( + bound_arg_type, warp._src.context.Function + ): continue bound_arg_type_stripped = strip_reference(bound_arg_type) @@ -904,6 +908,151 @@ def func_match_args(func, arg_types, kwarg_types): return True +def get_callable_arg_values(func, bound_args): + """Return concrete user-defined function targets for ``Callable`` parameters. + + ``bound_args`` already includes defaults. 
A non-empty result means the call + needs a specialized clone where callable parameter names resolve directly to + function objects during codegen instead of runtime variables. + """ + + if func.is_builtin(): + return None + + callable_arg_values = {} + + for name, value in bound_args.items(): + if not warp._src.types.is_callable_annotation(func.input_types.get(name)): + continue + + if not isinstance(value, warp._src.context.Function): + continue + + if value.is_builtin(): + raise WarpCodegenError( + "Callable parameters currently require user-defined Warp functions, " + f"but parameter '{name}' of '{func.key}' received built-in function '{value.key}'." + ) + + callable_arg_values[name] = value + + if callable_arg_values: + return callable_arg_values + + return None + + +def get_default_arg_value(func, name, value): + if warp._src.types.is_callable_annotation(func.input_types.get(name)) and isinstance( + value, warp._src.context.Function + ): + # Callable defaults need the same specialization path as explicit + # callable arguments. 
+ return value + + return Var(None, type=type(value), constant=value) + + +def bind_call_arg_nodes(func, call_node): + """Bind a call AST to ``func`` and return AST/default arguments by name.""" + + try: + bound_args = func.signature.bind(*call_node.args, **{kw.arg: kw.value for kw in call_node.keywords}) + except TypeError: + return {} + + default_args = {k: v for k, v in func.defaults.items() if k not in bound_args.arguments and v is not None} + apply_defaults(bound_args, default_args) + return bound_args.arguments + + +def resolve_callable_arg_target(adj, arg_node, callable_arg_values=None): + """Resolve a callable argument node or default to a concrete Warp function.""" + + if isinstance(arg_node, warp._src.context.Function): + return arg_node + + if callable_arg_values and isinstance(arg_node, ast.Name): + callable_func = callable_arg_values.get(arg_node.id) + if callable_func is not None: + return callable_func + + callable_func, _ = adj.resolve_static_expression(arg_node, eval_types=False) + return callable_func + + +def iter_call_callable_arg_targets(adj, func, call_node, callable_arg_values=None): + """Yield Warp function targets passed to ``Callable`` parameters.""" + + if not isinstance(func, warp._src.context.Function) or func.is_builtin(): + return + + bound_arg_nodes = bind_call_arg_nodes(func, call_node) + + for arg_name, arg_node in bound_arg_nodes.items(): + if not warp._src.types.is_callable_annotation(func.input_types.get(arg_name)): + continue + + callable_func = resolve_callable_arg_target(adj, arg_node, callable_arg_values) + if isinstance(callable_func, warp._src.context.Function): + yield callable_func + + +def specialize_callable_func(func, callable_arg_values): + """Clone ``func`` for a concrete set of callable parameter targets.""" + + if func.custom_grad_func is not None or func.custom_replay_func is not None: + raise WarpCodegenError( + "Callable parameters are not supported on functions with custom gradients or replay functions: " + 
f"'{func.key}'" + ) + + specialization_key = tuple( + (name, callable_arg_values[name]) for name in func.input_types if name in callable_arg_values + ) + + specializations = getattr(func, "_callable_specializations", None) + if specializations is None: + specializations = {} + func._callable_specializations = specializations + + specialized_func = specializations.get(specialization_key) + if specialized_func is not None: + return specialized_func + + # The callable targets are inlined by name while being omitted from the C++ + # function parameters, so each target set needs a distinct native name. + suffix_hash = hashlib.sha256() + suffix_hash.update(bytes(func.native_func, "utf-8")) + for name, callable_func in specialization_key: + suffix_hash.update(bytes(name, "utf-8")) + suffix_hash.update(bytes(callable_func.key, "utf-8")) + suffix_hash.update(bytes(callable_func.native_func, "utf-8")) + + specialized_func = shallowcopy(func) + # Specialization clones should not share the parent specialization cache. 
+ specialized_func.__dict__.pop("_callable_specializations", None) + specialized_func.native_func = f"{func.native_func}_callable_{suffix_hash.hexdigest()[:12]}" + specialized_func.value_func = None + specialized_func.adj = Adjoint( + func.func, + overload_annotations=func.adj.arg_types, + is_user_function=func.adj.is_user_function, + skip_forward_codegen=func.adj.skip_forward_codegen, + skip_reverse_codegen=func.adj.skip_reverse_codegen, + custom_reverse_mode=func.adj.custom_reverse_mode, + custom_reverse_num_input_args=func.adj.custom_reverse_num_input_args, + transformers=func.adj.transformers, + source=func.adj.source, + ) + specialized_func.adj.callable_arg_values = dict(callable_arg_values) + specialized_func.adj.used_by_backward_kernel = func.adj.used_by_backward_kernel + specialized_func.adj.force_adjoint_codegen = func.adj.force_adjoint_codegen + + specializations[specialization_key] = specialized_func + return specialized_func + + def get_arg_type(arg: Var | Any) -> type: arg = strip_reference(arg) @@ -1141,7 +1290,7 @@ def extract_function_source(func: Callable) -> tuple[str, int]: # generate function ssa form and adjoint @synchronized - def build(adj, builder, default_builder_options=None): + def build(adj, builder, default_builder_options=None, callable_arg_values=None): # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build for arg in adj.args: arg.is_read = False @@ -1150,6 +1299,9 @@ def build(adj, builder, default_builder_options=None): if adj.skip_build: return + if callable_arg_values is None: + callable_arg_values = getattr(adj, "callable_arg_values", None) + adj.builder = builder if default_builder_options is None: @@ -1186,9 +1338,13 @@ def build(adj, builder, default_builder_options=None): # tracks how much additional shared memory is required by any dependent function calls adj.max_required_extra_shared_memory = 0 - # update symbol map for each argument + # Callable-specialized functions 
replace selected argument Vars with + # Function objects so calls like `op(x)` resolve statically. for a in adj.args: - adj.symbols[a.label] = a + if callable_arg_values is not None and a.label in callable_arg_values: + adj.symbols[a.label] = callable_arg_values[a.label] + else: + adj.symbols[a.label] = a # recursively evaluate function body try: @@ -1588,13 +1744,16 @@ def add_call(adj, func, args, kwargs, type_args, min_outputs=None): if func.defaults: default_vars = { - k: Var(None, type=type(v), constant=v) + k: get_default_arg_value(func, k, v) for k, v in func.defaults.items() if k not in bound_args.arguments and v is not None } apply_defaults(bound_args, default_vars) bound_args = bound_args.arguments + callable_arg_values = get_callable_arg_values(func, bound_args) + if callable_arg_values is not None: + func = specialize_callable_func(func, callable_arg_values) # Constant precision preservation: when calling a 64-bit scalar type # constructor with a single compile-time constant argument, emit @@ -1694,7 +1853,12 @@ def add_call(adj, func, args, kwargs, type_args, min_outputs=None): elif func.dispatch_func is not None: func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args) else: - func_args = tuple(bound_args.values()) + func_args = tuple( + value + for name, value in bound_args.items() + if func.is_builtin() or not warp._src.types.is_callable_annotation(func.input_types.get(name)) + ) + # Callable parameters are specialization inputs, not C++ arguments. template_args = () func_args = tuple(adj.register_var(x) for x in func_args) @@ -4392,6 +4556,7 @@ def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp._src. 
constants: dict[str, Any] = {} types: dict[Struct | type, Any] = {} functions: dict[warp._src.context.Function, Any] = {} + callable_arg_values = getattr(adj, "callable_arg_values", None) or {} for node in ast.walk(adj.tree): if isinstance(node, ast.Name) and node.id not in local_variables: @@ -4407,9 +4572,18 @@ def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp._src. elif isinstance(node, ast.Call): func, _ = adj.resolve_static_expression(node.func, eval_types=False) + if func is None and isinstance(node.func, ast.Name): + func = callable_arg_values.get(node.func.id) + if isinstance(func, warp._src.context.Function) and not func.is_builtin(): # calling user-defined function functions[func] = None + + # Callable targets are passed as values, so they must be + # added explicitly to the function reference set. + for callable_func in iter_call_callable_arg_targets(adj, func, node, callable_arg_values): + if not callable_func.is_builtin(): + functions[callable_func] = None elif isinstance(func, Struct): # calling struct constructor types[func] = None @@ -5022,6 +5196,8 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None, forward_only # forward args for i, arg in enumerate(adj.args): + if warp._src.types.is_callable_annotation(arg.type): + continue if is_tile(arg.type) or is_tile_stack(arg.type): tname = f"tile_{arg.label}" template_params.append(tname) @@ -5038,6 +5214,8 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None, forward_only # reverse args for i, arg in enumerate(adj.args): + if warp._src.types.is_callable_annotation(arg.type): + continue if adj.custom_reverse_mode and i >= adj.custom_reverse_num_input_args: break # indexed array gradients are regular arrays diff --git a/warp/_src/context.py b/warp/_src/context.py index 8ed20f522c..d9b041aaae 100644 --- a/warp/_src/context.py +++ b/warp/_src/context.py @@ -478,6 +478,11 @@ def get_overload(self, arg_types: list[type], kwarg_types: Mapping[str, 
type]) - args_matched = True for i in range(len(arg_types)): + # Callable annotations stay type-erased here; specialization + # handles the concrete function target. + if warp._src.types.is_callable_annotation(template_types[i]) and isinstance(arg_types[i], Function): + continue + if not warp._src.types.type_matches_template(arg_types[i], template_types[i]): args_matched = False break @@ -486,11 +491,22 @@ def get_overload(self, arg_types: list[type], kwarg_types: Mapping[str, type]) - # instantiate this function with the specified argument types arg_names = f.input_types.keys() - overload_annotations = dict(zip(arg_names, arg_types, strict=False)) + overload_annotations = {} + for name, arg_type, template_type in zip(arg_names, arg_types, template_types, strict=False): + if warp._src.types.is_callable_annotation(template_type) and isinstance(arg_type, Function): + overload_annotations[name] = template_type + else: + overload_annotations[name] = arg_type + # add defaults for k, d in f.defaults.items(): if k not in overload_annotations: - overload_annotations[k] = warp._src.codegen.strip_reference(warp._src.codegen.get_arg_type(d)) + template_type = f.input_types[k] + default_type = warp._src.codegen.strip_reference(warp._src.codegen.get_arg_type(d)) + if warp._src.types.is_callable_annotation(template_type) and isinstance(default_type, Function): + overload_annotations[k] = template_type + else: + overload_annotations[k] = default_type ovl = shallowcopy(f) ovl.adj = warp._src.codegen.Adjoint(f.func, overload_annotations, source=f.adj.source) @@ -498,7 +514,9 @@ def get_overload(self, arg_types: list[type], kwarg_types: Mapping[str, type]) - ovl.value_func = None ovl.generic_parent = f - sig = warp._src.types.get_signature(arg_types, func_name=self.key) + sig = warp._src.types.get_signature( + list(overload_annotations.values()), func_name=self.key, arg_names=list(overload_annotations.keys()) + ) self.user_overloads[sig] = ovl return ovl @@ -2765,17 +2783,30 @@ def 
add_ref(ref): self.references.add(ref) ref.dependents.add(self) + callable_arg_values = getattr(adj, "callable_arg_values", None) or {} + # scan for function calls for node in ast.walk(adj.tree): if isinstance(node, ast.Call): try: # try to resolve the function func, _ = adj.resolve_static_expression(node.func, eval_types=False) + if func is None and isinstance(node.func, ast.Name): + func = callable_arg_values.get(node.func.id) # if this is a user-defined function, add a module reference if isinstance(func, warp._src.context.Function) and func.module is not None: add_ref(func.module) + if isinstance(func, warp._src.context.Function) and not func.is_builtin(): + # Callable targets can come from arguments or defaults; + # either way their modules must invalidate this module. + for callable_func in warp._src.codegen.iter_call_callable_arg_targets( + adj, func, node, callable_arg_values + ): + if not callable_func.is_builtin() and callable_func.module is not None: + add_ref(callable_func.module) + except Exception: # Lookups may fail for builtins, but that's ok. # Lookups may also fail for functions in this module that haven't been imported yet, diff --git a/warp/_src/types.py b/warp/_src/types.py index 92812eb1c4..440e789d82 100644 --- a/warp/_src/types.py +++ b/warp/_src/types.py @@ -7132,11 +7132,19 @@ def infer_argument_types(args: list[Any], template_types, arg_names: list[str] | } +def is_callable_annotation(annotation) -> bool: + """Return whether an annotation denotes a type-erased callable.""" + + return annotation is Callable or get_origin(annotation) is Callable + + def get_type_code(arg_type) -> str: if arg_type is Any: # special case for generics # note: since Python 3.11 Any is a type, so we check for it first return "?" + elif is_callable_annotation(arg_type): + return "c" elif ( sys.version_info < (3, 11) and hasattr(types, "GenericAlias") @@ -7216,9 +7224,6 @@ def get_type_code(arg_type) -> str: elif arg_type == Int: # generic int return "i?" 
- elif isinstance(arg_type, Callable): - # TODO: elaborate on Callable type? - return "c" elif arg_type is Ellipsis: return "?" else: diff --git a/warp/tests/test_func_callable.py b/warp/tests/test_func_callable.py new file mode 100644 index 0000000000..82f9a8df19 --- /dev/null +++ b/warp/tests/test_func_callable.py @@ -0,0 +1,352 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for ``Callable`` parameters in user-defined Warp functions. + +These tests live outside ``test_func.py`` because ``Callable`` parameter support +needs a dedicated set of helper functions, kernels, module dependency checks, +specialization checks, and rejection-path coverage. Keep future ``Callable`` +parameter tests in this module so ``test_func.py`` remains focused on general +``@wp.func`` behavior. +""" + +import unittest +from collections.abc import Callable as CollectionsCallable +from typing import Any +from typing import Callable as TypingCallable # noqa: UP035 + +import numpy as np + +import warp as wp +from warp.tests.unittest_utils import * + + +@wp.func +def callable_double_it(x: float): + return x * 2.0 + + +@wp.func +def callable_triple_it(x: float): + return x * 3.0 + + +@wp.func +def callable_apply_typing(g: TypingCallable, x: float): + return g(x) + + +@wp.func +def callable_apply_collections(g: CollectionsCallable, x: float): + return g(x) + + +@wp.func +def callable_apply_typing_parameterized(g: TypingCallable[[float], float], x: float): + return g(x) + + +@wp.func +def callable_apply_collections_parameterized(g: CollectionsCallable[[float], float], x: float): + return g(x) + + +@wp.func +def callable_apply_generic(g: TypingCallable, x: Any): + return g(x) + + +@wp.func +def callable_apply_default(g: TypingCallable = callable_double_it, x: float = 3.0): + return g(x) + + +@wp.func +def callable_apply_nested(x: float): + return callable_apply_typing(callable_double_it, x) 
+ + +@wp.func +def callable_forward_to_apply(g: TypingCallable, x: float): + return callable_apply_typing(g, x) + + +@wp.kernel +def callable_func_parameter_kernel(out: wp.array(dtype=float)): + out[0] = callable_apply_typing(callable_double_it, 3.0) + out[1] = callable_apply_typing(callable_triple_it, 4.0) + out[2] = callable_apply_collections(callable_double_it, 5.0) + out[3] = callable_apply_collections(callable_triple_it, 6.0) + out[4] = callable_apply_nested(7.0) + out[5] = callable_apply_default() + out[6] = callable_apply_typing(g=callable_triple_it, x=8.0) + out[7] = callable_apply_typing_parameterized(callable_double_it, 9.0) + out[8] = callable_apply_collections_parameterized(callable_triple_it, 10.0) + out[9] = callable_apply_generic(callable_double_it, 11.0) + + +def test_callable_func_parameter(test, device): + out = wp.empty(10, dtype=float, device=device) + + wp.launch(callable_func_parameter_kernel, dim=1, outputs=[out], device=device) + + assert_np_equal( + out.numpy(), + np.array([6.0, 12.0, 10.0, 18.0, 14.0, 6.0, 24.0, 18.0, 30.0, 22.0], dtype=np.float32), + ) + + +CALLABLE_TARGET = callable_double_it + + +@wp.kernel +def callable_global_target_kernel(out: wp.array(dtype=float)): + out[0] = callable_apply_typing(CALLABLE_TARGET, 3.0) + + +@wp.kernel +def callable_default_target_kernel(out: wp.array(dtype=float)): + out[0] = callable_apply_default() + + +CALLABLE_DEPENDENCY_EXPLICIT_PROVIDER_MODULE = wp.Module("callable_dependency_explicit_provider") +CALLABLE_DEPENDENCY_EXPLICIT_CONSUMER_MODULE = wp.Module("callable_dependency_explicit_consumer") +CALLABLE_DEPENDENCY_DEFAULT_PROVIDER_MODULE = wp.Module("callable_dependency_default_provider") +CALLABLE_DEPENDENCY_DEFAULT_CONSUMER_MODULE = wp.Module("callable_dependency_default_consumer") + + +def callable_dependency_explicit_target(x: float): + return x + 1.0 + + +callable_dependency_explicit_target = wp.func( + callable_dependency_explicit_target, + 
module=CALLABLE_DEPENDENCY_EXPLICIT_PROVIDER_MODULE,
+)
+
+
+def callable_dependency_apply_explicit(g: TypingCallable, x: float):
+    return g(x)
+
+
+callable_dependency_apply_explicit = wp.func(
+    callable_dependency_apply_explicit,
+    module=CALLABLE_DEPENDENCY_EXPLICIT_CONSUMER_MODULE,
+)
+
+
+@wp.kernel(module=CALLABLE_DEPENDENCY_EXPLICIT_CONSUMER_MODULE)
+def callable_dependency_explicit_kernel(out: wp.array(dtype=float)):
+    out[0] = callable_dependency_apply_explicit(callable_dependency_explicit_target, 2.0)
+
+
+def callable_dependency_default_target(x: float):
+    return x + 1.0
+
+
+callable_dependency_default_target = wp.func(
+    callable_dependency_default_target,
+    module=CALLABLE_DEPENDENCY_DEFAULT_PROVIDER_MODULE,
+)
+
+
+def callable_dependency_apply_default(g: TypingCallable = callable_dependency_default_target, x: float = 2.0):
+    return g(x)
+
+
+callable_dependency_apply_default = wp.func(
+    callable_dependency_apply_default,
+    module=CALLABLE_DEPENDENCY_DEFAULT_CONSUMER_MODULE,
+)
+
+
+@wp.kernel(module=CALLABLE_DEPENDENCY_DEFAULT_CONSUMER_MODULE)
+def callable_dependency_default_kernel(out: wp.array(dtype=float)):
+    out[0] = callable_dependency_apply_default()
+
+
+@wp.func
+def callable_custom_grad_unsupported(g: TypingCallable, x: float):
+    return x
+
+
+@wp.func_grad(callable_custom_grad_unsupported)
+def adj_callable_custom_grad_unsupported(g: TypingCallable, x: float, adj_ret: float):
+    wp.adjoint[x] += adj_ret
+
+
+@wp.func
+def callable_custom_replay_unsupported(g: TypingCallable, x: float):
+    return x
+
+
+@wp.func_replay(callable_custom_replay_unsupported)
+def replay_callable_custom_replay_unsupported(g: TypingCallable, x: float):
+    return x
+
+
+class TestFuncCallable(unittest.TestCase):
+    def test_callable_annotation_type_code(self):
+        from warp._src.types import get_type_code  # noqa: PLC0415
+
+        callable_annotations = (
+            TypingCallable,
+            CollectionsCallable,
+            TypingCallable[[float], float],
+            CollectionsCallable[[float], float],
+        )
+
+        for annotation in callable_annotations:
+            with self.subTest(annotation=annotation):
+                self.assertEqual(get_type_code(annotation), "c")
+
+    def test_callable_argument_target_affects_module_hash(self):
+        global CALLABLE_TARGET
+
+        original_target = CALLABLE_TARGET
+        try:
+            CALLABLE_TARGET = callable_double_it
+            double_hash = callable_global_target_kernel.module.hash_module()
+
+            CALLABLE_TARGET = callable_triple_it
+            triple_hash = callable_global_target_kernel.module.hash_module()
+        finally:
+            CALLABLE_TARGET = original_target
+
+        self.assertNotEqual(double_hash, triple_hash)
+
+    def test_callable_default_target_affects_module_hash(self):
+        original_defaults = callable_apply_default.defaults.copy()
+        try:
+            callable_apply_default.defaults["g"] = callable_double_it
+            double_hash = callable_default_target_kernel.module.hash_module()
+
+            callable_apply_default.defaults["g"] = callable_triple_it
+            triple_hash = callable_default_target_kernel.module.hash_module()
+        finally:
+            callable_apply_default.defaults = original_defaults
+
+        self.assertNotEqual(double_hash, triple_hash)
+
+    def test_callable_specialization_cache_not_shared_with_clone(self):
+        from warp._src.codegen import specialize_callable_func  # noqa: PLC0415
+
+        apply_overload = callable_forward_to_apply.user_overloads["c_f4"]
+        specialized_apply = specialize_callable_func(apply_overload, {"g": callable_double_it})
+
+        self.assertIn("_callable_specializations", apply_overload.__dict__)
+        self.assertNotIn("_callable_specializations", specialized_apply.__dict__)
+
+    def test_callable_specialized_adjoint_references_forwarded_target(self):
+        from warp._src.codegen import specialize_callable_func  # noqa: PLC0415
+
+        apply_overload = callable_forward_to_apply.user_overloads["c_f4"]
+        specialized_apply = specialize_callable_func(apply_overload, {"g": callable_double_it})
+        _, _, functions = specialized_apply.adj.get_references()
+
+        self.assertIn(callable_apply_typing, functions)
+        self.assertIn(callable_double_it, functions)
+
+    def test_callable_specialization_preserves_return_annotation(self):
+        func_module = wp.Module(f"callable_wrong_return_annotation_func_{id(self)}")
+        kernel_module = wp.Module(f"callable_wrong_return_annotation_kernel_{id(self)}")
+
+        @wp.func(module=func_module)
+        def callable_wrong_return_annotation(g: TypingCallable, x: float) -> int:
+            return g(x)
+
+        @wp.kernel(module=kernel_module)
+        def callable_wrong_return_annotation_kernel(out: wp.array(dtype=float)):
+            out[0] = float(callable_wrong_return_annotation(callable_double_it, 2.0))
+
+        out = wp.empty(1, dtype=float, device="cpu")
+
+        with self.assertRaisesRegex(
+            wp.WarpCodegenError,
+            r"The function `callable_wrong_return_annotation` has its return type "
+            r"annotated as `int` but the code returns a value of type `float32`.",
+        ):
+            wp.launch(callable_wrong_return_annotation_kernel, dim=1, outputs=[out], device="cpu")
+
+    def test_callable_argument_target_updates_module_dependents(self):
+        def unload_recursive(module, visited):
+            module.unload()
+            visited.add(module)
+            for dependent in module.dependents:
+                if dependent not in visited:
+                    unload_recursive(dependent, visited)
+
+        cases = (
+            (
+                "explicit",
+                CALLABLE_DEPENDENCY_EXPLICIT_PROVIDER_MODULE,
+                CALLABLE_DEPENDENCY_EXPLICIT_CONSUMER_MODULE,
+                callable_dependency_explicit_kernel,
+            ),
+            (
+                "default",
+                CALLABLE_DEPENDENCY_DEFAULT_PROVIDER_MODULE,
+                CALLABLE_DEPENDENCY_DEFAULT_CONSUMER_MODULE,
+                callable_dependency_default_kernel,
+            ),
+        )
+
+        for name, provider_module, consumer_module, kernel in cases:
+            with self.subTest(name=name):
+                out = wp.empty(1, dtype=float, device="cpu")
+                wp.launch(kernel, dim=1, outputs=[out], device="cpu")
+
+                assert_np_equal(out.numpy(), np.array([3.0], dtype=np.float32))
+                self.assertIn(provider_module, consumer_module.references)
+                self.assertIn(consumer_module, provider_module.dependents)
+                self.assertTrue(consumer_module.hashers)
+
+                unload_recursive(provider_module, visited=set())
+
+                self.assertFalse(consumer_module.hashers)
+
+    def test_callable_custom_grad_rejected(self):
+        @wp.kernel(module="unique")
+        def custom_grad_rejection_kernel(out: wp.array(dtype=float)):
+            out[0] = callable_custom_grad_unsupported(callable_double_it, 2.0)
+
+        @wp.kernel(module="unique")
+        def custom_replay_rejection_kernel(out: wp.array(dtype=float)):
+            out[0] = callable_custom_replay_unsupported(callable_double_it, 2.0)
+
+        for kernel in (custom_grad_rejection_kernel, custom_replay_rejection_kernel):
+            with self.subTest(kernel=kernel.key):
+                out = wp.empty(1, dtype=float, device="cpu")
+
+                with self.assertRaisesRegex(
+                    wp.WarpCodegenError,
+                    "Callable parameters.*custom gradients or replay functions",
+                ):
+                    wp.launch(kernel, dim=1, outputs=[out], device="cpu")
+
+    def test_callable_builtin_target_rejected(self):
+        @wp.kernel(module=wp.Module(f"callable_builtin_target_kernel_{id(self)}"))
+        def callable_builtin_target_kernel(out: wp.array(dtype=float)):
+            out[0] = callable_apply_typing(wp.sin, 0.5)
+
+        out = wp.empty(1, dtype=float, device="cpu")
+
+        with self.assertRaisesRegex(
+            wp.WarpCodegenError,
+            "Callable parameters currently require user-defined Warp functions",
+        ):
+            wp.launch(callable_builtin_target_kernel, dim=1, outputs=[out], device="cpu")
+
+
+devices = get_test_devices()
+
+add_function_test(
+    TestFuncCallable,
+    func=test_callable_func_parameter,
+    name="test_callable_func_parameter",
+    devices=devices,
+)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/warp/tests/unittest_suites.py b/warp/tests/unittest_suites.py
index 797d06025a..bf8cb60b6a 100644
--- a/warp/tests/unittest_suites.py
+++ b/warp/tests/unittest_suites.py
@@ -164,6 +164,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
     from warp.tests.test_fastcall import TestFastcall, TestFastcallAvailable
     from warp.tests.test_fp16 import TestFp16
     from warp.tests.test_func import TestFunc
+    from warp.tests.test_func_callable import TestFuncCallable
     from warp.tests.test_future_annotations import TestFutureAnnotations
     from warp.tests.test_generics import TestGenerics
     from warp.tests.test_grad import TestGrad
@@ -295,6 +296,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
         TestFemShape,
         TestFp16,
         TestFunc,
+        TestFuncCallable,
         TestFutureAnnotations,
         TestGenerics,
         TestGrad,
@@ -437,6 +439,7 @@ def debug_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader):
     from warp.tests.test_fast_math import TestFastMath
     from warp.tests.test_fp16 import TestFp16
     from warp.tests.test_func import TestFunc
+    from warp.tests.test_func_callable import TestFuncCallable
     from warp.tests.test_generics import TestGenerics
     from warp.tests.test_grad import TestGrad
     from warp.tests.test_grad_customs import TestGradCustoms
@@ -472,6 +475,7 @@ def debug_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader):
         TestEnum,
         TestFastMath,
         TestFunc,
+        TestFuncCallable,
         TestGenerics,
         TestMath,
         TestModuleHashing,