-
Notifications
You must be signed in to change notification settings - Fork 521
Support Callable parameters in user-defined functions #1454
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
shi-eric
wants to merge
2
commits into
NVIDIA:main
Choose a base branch
from
shi-eric:shi-eric/callable-func-design
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+788
−13
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,144 @@ | ||
| # Callable Function Parameters | ||
|
|
||
| **Status**: Implemented | ||
|
|
||
| **Issue**: [GH-1424](https://github.com/NVIDIA/warp/issues/1424) | ||
|
|
||
| ## Motivation | ||
|
|
||
| Warp users can write higher-order Python helpers that accept a callable and | ||
| apply it to values, but the same pattern did not work inside user-defined | ||
| `@wp.func` code. A function parameter annotated as `Callable` was not matched | ||
| consistently during Warp overload resolution, especially across the two standard | ||
| library import paths: | ||
|
|
||
| ```python | ||
| from typing import Callable as TypingCallable | ||
| from collections.abc import Callable as AbcCallable | ||
| ``` | ||
|
|
||
| The aliases are illustrative; Warp treats both origins as the same type-erased | ||
| `Callable` marker. This design implements the user-defined function portion of | ||
| GH-1424 as a first pass: support direct inline calls such as | ||
| `apply(double_it, 3.0)` from a kernel or another `@wp.func`, without depending | ||
| on local function-variable assignment behavior. | ||
|
|
||
| ## Requirements | ||
|
|
||
| | ID | Requirement | Priority | Notes | | ||
| | --- | ----------- | -------- | ----- | | ||
| | R1 | Recognize `typing.Callable` and `collections.abc.Callable` as callable annotations. | Must | Includes bare and parameterized forms. | | ||
| | R2 | Allow user-defined `@wp.func` objects to match `Callable` parameters. | Must | Callable values are type-erased. | | ||
| | R3 | Preserve the existing `"c"` type code for callable annotations. | Must | Used by module hashing and overload keys. | | ||
| | R4 | Keep annotation recognition compatible with Python 3.10 through 3.14. | Must | Avoid private `typing` internals. | | ||
| | R5 | Include callable argument and default targets in module hashes and dependencies. | Must | Prevent stale compiled modules when callable targets change. | | ||
| | R6 | Reject unsupported callable specializations explicitly. | Must | Built-ins are deferred; custom grad/replay functions are first-pass non-goals. | | ||
|
|
||
| **Non-goals**: | ||
|
|
||
| - Validate parameterized callable signatures such as | ||
| `Callable[[float], float]`. | ||
| - Implement runtime dispatch for arbitrary Python callables. | ||
| - Support kernel-local function variable assignment. | ||
| - Support built-in Warp functions such as `wp.sin` or `wp.add` as callable | ||
| arguments in this first pass. They are rejected until built-in callable | ||
| identity can participate safely in specialization hashing and dependency | ||
| tracking. | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| - Support callable-specialized functions that have custom gradient or replay | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| functions. | ||
|
|
||
| Parameterized callable annotations are recognized as callable markers, but | ||
| their argument and return types remain unchecked. | ||
|
|
||
| ## Design | ||
|
|
||
| ### Approach | ||
|
|
||
| The implementation treats `Callable` parameters as compile-time function | ||
| references. A callable parameter never becomes a runtime C++ function pointer or | ||
| kernel parameter. Instead, code generation specializes the user-defined function | ||
| for the concrete user-defined Warp function passed at each call site. | ||
|
|
||
| This means: | ||
|
|
||
| - `apply(double_it, x)` and `apply(triple_it, x)` produce separate specialized | ||
| native functions. | ||
| - Callable parameters are bound into the specialized function's codegen symbol | ||
| table. | ||
| - Callable parameters are omitted from the emitted forward and reverse C++ | ||
| function signatures. | ||
| - Runtime calls pass only the non-callable arguments. | ||
|
|
||
| The callable annotation predicate lives in `warp/_src/types.py` as | ||
| `is_callable_annotation(annotation)`. It returns true for: | ||
|
|
||
| - bare `typing.Callable`, | ||
| - bare `collections.abc.Callable`, | ||
| - `typing.Callable[...]`, | ||
| - `collections.abc.Callable[...]`. | ||
|
|
||
| The helper uses `typing.get_origin()` plus the canonical | ||
| `collections.abc.Callable` object, which covers the supported Python versions | ||
| without relying on private implementation details. | ||
|
|
||
| ### Alternatives Considered | ||
|
|
||
| One option was to normalize annotations globally during `@wp.func` registration. | ||
| That would make all `Callable` forms identical up front, but it would also | ||
| change the raw annotations stored on every function and could affect unrelated | ||
| signature handling. A small predicate keeps the behavior localized. | ||
|
|
||
| Another option was to support built-in Warp functions as callable values. | ||
| GH-1424 includes built-ins, but this first pass narrows support to user-defined | ||
| targets. Built-ins require hashing and dependency behavior for callable | ||
| identity, so they are rejected explicitly to avoid partial support that could | ||
| reuse stale cached modules. | ||
|
|
||
| ### Key Implementation Details | ||
|
|
||
| `get_type_code()` returns `"c"` for callable annotations before the generic | ||
| type branches. This keeps module hash type-code behavior stable for bare and | ||
| parameterized `Callable` forms. | ||
|
|
||
| `func_match_args()` treats a `warp._src.context.Function` value as compatible | ||
| with any callable annotation. Non-function values continue to fail normal | ||
| overload resolution. | ||
|
|
||
| During `Adjoint.add_call()`, default arguments are applied first. Callable | ||
| arguments and callable defaults are then collected. Built-in function values are | ||
| rejected, and user-defined function values trigger creation of a specialized | ||
| `Function` clone. The clone receives: | ||
|
|
||
| - a hash-suffixed native function name, | ||
| - a fresh `Adjoint` with the original annotations, including the return | ||
| annotation, | ||
| - `callable_arg_values` used to bind callable parameter names during codegen. | ||
|
|
||
| Module hashing and dependency discovery both inspect callable arguments and | ||
| callable defaults. User-defined callable targets are included in the referenced | ||
| function set for hashes and in module references/dependents for invalidation. | ||
|
|
||
| Callable-specialized functions with custom gradients or custom replay functions | ||
| are rejected. Their custom functions are tied to the unspecialized native | ||
| function signature, while callable specialization removes callable runtime | ||
| parameters from the emitted signature. | ||
|
|
||
| ## Testing Strategy | ||
|
|
||
| `warp/tests/test_func.py` covers: | ||
|
|
||
| - bare `typing.Callable` and `collections.abc.Callable` runtime calls, | ||
| - parameterized `Callable[[float], float]` runtime calls, | ||
| - generic user functions that combine `Callable` with `Any`, | ||
| - default callable arguments, | ||
| - keyword callable arguments, | ||
| - nested user-defined function calls, | ||
| - callable targets affecting module hashes, | ||
| - callable targets updating cross-module dependents, | ||
| - return annotation preservation on specialized functions, | ||
| - explicit rejection for built-in callable targets, | ||
| - explicit rejection for callable-specialized functions with custom grad or | ||
| replay functions. | ||
|
|
||
| Local verification should include the focused `TestFunc` suite and pre-commit | ||
| over the changed files. CI provides the full Python 3.10 through 3.14 matrix. | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.