Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
and can be scoped temporarily with `wp.ScopedLogLevel`; all Python-side diagnostic output now routes through
the logger ([GH-1315](https://github.com/NVIDIA/warp/issues/1315),
[GH-1434](https://github.com/NVIDIA/warp/issues/1434)).
- Support passing user-defined Warp functions to `Callable` parameters in `@wp.func` functions
([GH-1424](https://github.com/NVIDIA/warp/issues/1424)).

### Removed

Expand Down
144 changes: 144 additions & 0 deletions design/callable-func-parameters.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Callable Function Parameters

**Status**: Implemented

**Issue**: [GH-1424](https://github.com/NVIDIA/warp/issues/1424)

## Motivation

Warp users can write higher-order Python helpers that accept a callable and
apply it to values, but the same pattern did not work inside user-defined
`@wp.func` code. A function parameter annotated as `Callable` was not matched
consistently during Warp overload resolution, especially across the two standard
library import paths:

```python
from typing import Callable as TypingCallable
from collections.abc import Callable as AbcCallable
```
Comment thread
coderabbitai[bot] marked this conversation as resolved.

The aliases are illustrative; Warp treats both origins as the same type-erased
`Callable` marker. This design implements the user-defined function portion of
GH-1424 as a first pass: support direct inline calls such as
`apply(double_it, 3.0)` from a kernel or another `@wp.func`, without depending
on local function-variable assignment behavior.

## Requirements

| ID | Requirement | Priority | Notes |
| --- | ----------- | -------- | ----- |
| R1 | Recognize `typing.Callable` and `collections.abc.Callable` as callable annotations. | Must | Includes bare and parameterized forms. |
| R2 | Allow user-defined `@wp.func` objects to match `Callable` parameters. | Must | Callable values are type-erased. |
| R3 | Preserve the existing `"c"` type code for callable annotations. | Must | Used by module hashing and overload keys. |
| R4 | Keep annotation recognition compatible with Python 3.10 through 3.14. | Must | Avoid private `typing` internals. |
| R5 | Include callable argument and default targets in module hashes and dependencies. | Must | Prevent stale compiled modules when callable targets change. |
| R6 | Reject unsupported callable specializations explicitly. | Must | Built-ins are deferred; custom grad/replay functions are first-pass non-goals. |

**Non-goals**:

- Validate parameterized callable signatures such as
`Callable[[float], float]`.
- Implement runtime dispatch for arbitrary Python callables.
- Support kernel-local function variable assignment.
- Support built-in Warp functions such as `wp.sin` or `wp.add` as callable
arguments in this first pass. They are rejected until built-in callable
identity can participate safely in specialization hashing and dependency
tracking.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
- Support callable-specialized functions that have custom gradient or replay
Comment thread
coderabbitai[bot] marked this conversation as resolved.
functions.

Parameterized callable annotations are recognized as callable markers, but
their argument and return types remain unchecked.

## Design

### Approach

The implementation treats `Callable` parameters as compile-time function
references. A callable parameter never becomes a runtime C++ function pointer or
kernel parameter. Instead, code generation specializes the user-defined function
for the concrete user-defined Warp function passed at each call site.

This means:

- `apply(double_it, x)` and `apply(triple_it, x)` produce separate specialized
native functions.
- Callable parameters are bound into the specialized function's codegen symbol
table.
- Callable parameters are omitted from the emitted forward and reverse C++
function signatures.
- Runtime calls pass only the non-callable arguments.

The callable annotation predicate lives in `warp/_src/types.py` as
`is_callable_annotation(annotation)`. It returns true for:

- bare `typing.Callable`,
- bare `collections.abc.Callable`,
- `typing.Callable[...]`,
- `collections.abc.Callable[...]`.

The helper uses `typing.get_origin()` plus the canonical
`collections.abc.Callable` object, which covers the supported Python versions
without relying on private implementation details.

### Alternatives Considered

One option was to normalize annotations globally during `@wp.func` registration.
That would make all `Callable` forms identical up front, but it would also
change the raw annotations stored on every function and could affect unrelated
signature handling. A small predicate keeps the behavior localized.

Another option was to support built-in Warp functions as callable values.
GH-1424 includes built-ins, but this first pass narrows support to user-defined
targets. Built-ins require hashing and dependency behavior for callable
identity, so they are rejected explicitly to avoid partial support that could
reuse stale cached modules.

### Key Implementation Details

`get_type_code()` returns `"c"` for callable annotations before the generic
type branches. This keeps module hash type-code behavior stable for bare and
parameterized `Callable` forms.

`func_match_args()` treats a `warp._src.context.Function` value as compatible
with any callable annotation. Non-function values continue to fail normal
overload resolution.

During `Adjoint.add_call()`, default arguments are applied first. Callable
arguments and callable defaults are then collected. Built-in function values are
rejected, and user-defined function values trigger creation of a specialized
`Function` clone. The clone receives:

- a hash-suffixed native function name,
- a fresh `Adjoint` with the original annotations, including the return
annotation,
- `callable_arg_values` used to bind callable parameter names during codegen.

Module hashing and dependency discovery both inspect callable arguments and
callable defaults. User-defined callable targets are included in the referenced
function set for hashes and in module references/dependents for invalidation.

Callable-specialized functions with custom gradients or custom replay functions
are rejected. Their custom functions are tied to the unspecialized native
function signature, while callable specialization removes callable runtime
parameters from the emitted signature.

## Testing Strategy

`warp/tests/test_func.py` covers:

- bare `typing.Callable` and `collections.abc.Callable` runtime calls,
- parameterized `Callable[[float], float]` runtime calls,
- generic user functions that combine `Callable` with `Any`,
- default callable arguments,
- keyword callable arguments,
- nested user-defined function calls,
- callable targets affecting module hashes,
- callable targets updating cross-module dependents,
- return annotation preservation on specialized functions,
- explicit rejection for built-in callable targets,
- explicit rejection for callable-specialized functions with custom grad or
replay functions.

Local verification should include the focused `TestFunc` suite and pre-commit
over the changed files. CI provides the full Python 3.10 through 3.14 matrix.
52 changes: 52 additions & 0 deletions docs/user_guide/basics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,58 @@ User functions may also be overloaded by defining multiple function signatures w
def custom(x: wp.vec3):
return x + wp.vec3(1.0, 0.0, 0.0)

.. _callable-parameters:

Callable Parameters
^^^^^^^^^^^^^^^^^^^

User functions can accept another user-defined Warp function by annotating the
parameter as ``Callable`` from ``collections.abc`` or ``typing``.
The callable target is chosen where the user function is called and can be
invoked directly inside the user function body:

.. code-block:: python

from collections.abc import Callable

import warp as wp

@wp.func
def square(x: float):
return x * x


@wp.func
def cube(x: float):
return x * x * x


@wp.func
def apply(f: Callable, x: float):
return f(x)


@wp.kernel
def apply_kernel(
values: wp.array[float],
square_out: wp.array[float],
cube_out: wp.array[float],
):
i = wp.tid()
square_out[i] = apply(square, values[i])
cube_out[i] = apply(cube, values[i])

Callable parameters may also use defaults and keyword arguments:

.. code-block:: python

@wp.func
def apply_default(f: Callable = square, x: float = 0.0):
return f(x)

Pass only user-defined :func:`@wp.func <warp.func>` functions as callable targets.
See :doc:`limitations` for unsupported callable targets and other restrictions.

Tiles may also be passed to user functions. The function signature tile argument should include
dtype and shape parameters to match the tile type intended to be used in the function. For example:

Expand Down
7 changes: 7 additions & 0 deletions docs/user_guide/limitations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ Kernels and User Functions
(e.g., ``wp.float64(wp.PI)`` or ``wp.int64(large_value)``).
* Python ``IntFlag`` values behave like raw integers in Warp kernels: bitwise negation (``~``)
produces the integer negation, not a masked combination of flags as in standard Python ``IntFlag`` behavior.
* :ref:`Callable parameters <callable-parameters>` in user functions only support direct inline calls with
user-defined :func:`@wp.func <warp.func>` targets.
Arbitrary Python callables and built-in Warp functions such as ``wp.sin`` or ``wp.add`` are not supported.
Assigning callable targets to local variables inside kernels or user functions is not supported.
User functions with ``Callable`` parameters also cannot define custom gradient or replay functions.
Type parameters in annotations such as ``Callable[[float], float]`` are accepted but are not validated against
the target function signature.

A limitation of Warp is that each dimension of the grid used to launch a kernel must be representable as a 32-bit
signed integer. Therefore, no single dimension of a grid should exceed :math:`2^{31}-1`.
Expand Down
Loading
Loading