Support Callable parameters in user-defined functions#1454
Conversation
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughThis PR implements support for ChangesCallable Parameter Support
Sequence Diagram(s)sequenceDiagram
participant Kernel as Kernel (caller)
participant Codegen as Codegen (add_call / func_match_args)
participant Types as Types (is_callable_annotation)
participant Specialize as Specialization (specialize_callable_func / get_callable_arg_values)
participant Adjoint as Adjoint.build
participant CppGen as C++ Generation
Kernel->>Codegen: call apply(double_it, 3.0)
Codegen->>Types: detect callable parameter via is_callable_annotation()
Codegen->>Codegen: bind args, match overload (callable-aware)
Codegen->>Specialize: extract callable args (get_callable_arg_values)
Specialize->>Specialize: create/cache specialized clone (hash suffix)
Codegen->>Adjoint: Adjoint.build(callable_arg_values)
Adjoint->>Adjoint: populate adj.symbols with concrete Function values
Codegen->>CppGen: emit call (skip callable params in forward/reverse signatures)
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (2)
warp/_src/codegen.py (1)
945-992: 💤 Low valueConsider documenting the specialization key stability for caching.
The specialization cache key at lines 952-954 uses a tuple of
(name, callable_func)pairs, relying onFunctionobject identity for cache hits. If the same logical function is recreated (e.g., module reload), cache misses will occur, generating duplicate specializations.This is likely acceptable for correctness (each specialization is valid), but consider adding a brief comment explaining this design choice for future maintainers.
Also, the check at lines 946-950 correctly rejects functions with custom grad/replay—good defensive validation.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@warp/_src/codegen.py` around lines 945 - 992, The specialization cache key in specialize_callable_func currently builds specialization_key from (name, callable_arg_values[name]) and therefore relies on Function object identity for cache hits; add a brief explanatory comment above the specialization_key creation stating that this is an intentional design choice (it may produce cache misses if the same logical function object is recreated, e.g., on module reload), why that is acceptable for correctness, and note that using callable_func.key/native_func could be an alternative for a stable key if desired in the future; reference specialize_callable_func, specialization_key, and func._callable_specializations in the comment so maintainers can find the code easily.CHANGELOG.md (1)
32-33: ⚡ Quick winUse imperative “Add …” phrasing for this Unreleased entry.
Line 32 currently starts with “Support …”; the changelog convention here asks for imperative present tense.
Suggested wording
-- Support passing user-defined Warp functions to `Callable` parameters in `@wp.func` functions +- Add support for passing user-defined Warp functions to `Callable` parameters in `@wp.func` functionsAs per coding guidelines:
CHANGELOG.mdentries inUnreleasedshould use imperative present tense (“Add X”).🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@CHANGELOG.md` around lines 32 - 33, Update the Unreleased changelog entry that currently reads "Support passing user-defined Warp functions to `Callable` parameters in `@wp.func` functions" to use imperative present-tense phrasing; change it to something like "Add support for passing user-defined Warp functions to `Callable` parameters in `@wp.func` functions" so the entry follows the repository's changelog convention.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@design/callable-func-parameters.md`:
- Around line 41-43: The document's non-goal wording conflicts with the linked
GH-1424 objective about accepting built-in callables; update the text around the
callable targets section to reconcile them by either (A) removing or rephrasing
the line that excludes built-ins and explicitly stating that built-in Warp
functions like wp.sin and wp.add are supported as callable arguments, or (B) if
built-ins truly remain out of scope, change the GH-1424 reference/closure
language to reflect a narrower objective; ensure references to "wp.sin" and
"wp.add" and the GH-1424 issue ID are corrected so the scope described matches
the linked issue.
- Around line 15-18: The snippet imports Callable twice which shadows and
confuses the reader; update the examples so they don't clobber the same name by
either (a) using only one import (prefer collections.abc.Callable for modern
code) and removing the other, (b) aliasing one import (e.g., import Callable as
TypingCallable) to show the difference, or (c) present two separate, clearly
labeled snippets instead of both lines together—adjust the lines containing
"from typing import Callable" and "from collections.abc import Callable"
accordingly.
---
Nitpick comments:
In `@CHANGELOG.md`:
- Around line 32-33: Update the Unreleased changelog entry that currently reads
"Support passing user-defined Warp functions to `Callable` parameters in
`@wp.func` functions" to use imperative present-tense phrasing; change it to
something like "Add support for passing user-defined Warp functions to
`Callable` parameters in `@wp.func` functions" so the entry follows the
repository's changelog convention.
In `@warp/_src/codegen.py`:
- Around line 945-992: The specialization cache key in specialize_callable_func
currently builds specialization_key from (name, callable_arg_values[name]) and
therefore relies on Function object identity for cache hits; add a brief
explanatory comment above the specialization_key creation stating that this is
an intentional design choice (it may produce cache misses if the same logical
function object is recreated, e.g., on module reload), why that is acceptable
for correctness, and note that using callable_func.key/native_func could be an
alternative for a stable key if desired in the future; reference
specialize_callable_func, specialization_key, and func._callable_specializations
in the comment so maintainers can find the code easily.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Enterprise
Run ID: 9019ac3f-bc6a-4ec7-b04f-ab04997190aa
📒 Files selected for processing (6)
CHANGELOG.mddesign/callable-func-parameters.mdwarp/_src/codegen.pywarp/_src/context.pywarp/_src/types.pywarp/tests/test_func.py
230bf7f to
f6a270d
Compare
af680cc to
99a8d88
Compare
There was a problem hiding this comment.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
warp/_src/codegen.py (1)
4513-4539:⚠️ Potential issue | 🟠 Major | ⚡ Quick winTrack callable-parameter specializations in
get_references()too.This only finds callable targets when the AST argument itself resolves statically. For specialized higher-order functions, forwarded names like
inner(f, x)and direct calls likef(x)are backed byadj.callable_arg_values, not by a global/static lookup, so the concrete callable never gets added tofunctions. That leaves module hashing/dependency invalidation stale when the passed function changes.Suggested direction
def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp._src.context.Function, Any]]: """Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions.""" local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed constants: dict[str, Any] = {} types: dict[Struct | type, Any] = {} functions: dict[warp._src.context.Function, Any] = {} + callable_arg_values = getattr(adj, "callable_arg_values", None) or {} for node in ast.walk(adj.tree): ... elif isinstance(node, ast.Call): func, _ = adj.resolve_static_expression(node.func, eval_types=False) + if func is None and isinstance(node.func, ast.Name): + func = callable_arg_values.get(node.func.id) + + if isinstance(func, warp._src.context.Function) and not func.is_builtin(): + functions[func] = None + try: bound_args = func.signature.bind(*node.args, **{kw.arg: kw.value for kw in node.keywords}) except TypeError: bound_arg_nodes = {} else: ... for arg_name, arg_node in bound_arg_nodes.items(): if not warp._src.types.is_callable_annotation(func.input_types.get(arg_name)): continue if isinstance(arg_node, warp._src.context.Function): callable_func = arg_node + elif isinstance(arg_node, ast.Name): + callable_func = callable_arg_values.get(arg_node.id) else: callable_func, _ = adj.resolve_static_expression(arg_node, eval_types=False) if isinstance(callable_func, warp._src.context.Function) and not callable_func.is_builtin(): functions[callable_func] = None🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@warp/_src/codegen.py` around lines 4513 - 4539, get_references() currently only discovers callable arguments when the AST arg resolves statically via adj.resolve_static_expression, so specialized higher-order calls backed by adj.callable_arg_values (e.g., when func is forwarded into inner/f and invoked) are missed; update the Call-handling branch in get_references() (the block that binds arguments with func.signature.bind, apply_defaults, and iterates bound_arg_nodes) to also check adj.callable_arg_values for the bound argument node (and for raw arg names where resolution failed) and add any concrete warp._src.context.Function entries there into the functions dict (same guard: isinstance(..., Function) and not is_builtin()). Ensure you reference adj.callable_arg_values lookup when bound_arg_nodes contains nodes that didn't resolve statically or that are names/attributes representing forwarded callables so the concrete callable specializations are included.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@design/callable-func-parameters.md`:
- Around line 43-46: The PR currently conflicts with the design non-goal that
excludes built-in Warp callables (e.g., wp.sin, wp.add) from being accepted as
callable parameters for user-defined `@wp.func` per callable-func-parameters.md;
either update the PR metadata so it does not close GH-1424 (change "Closes
`#1424`" to "Addresses/Related to/Partial implementation of `#1424`" in the PR
description/commit message) or implement full support for built-in callables
(update the `@wp.func` callable-argument handling, specialization hashing, and
dependency tracking code paths that validate callables so built-ins are
recognized and included) and then remove/update the non-goal text and document
the addition.
In `@warp/_src/codegen.py`:
- Around line 972-975: The specialized callable currently clears its value_func
(specialized_func.value_func = None) which breaks downstream use where
add_call() invokes func.value_func(...); restore the original value_func on the
specialized copy instead of nulling it (e.g., assign specialized_func.value_func
= func.value_func or simply remove the line that sets it to None) so the
specialized_func keeps the callable-return resolver used later by add_call().
In `@warp/_src/context.py`:
- Around line 491-506: The overload_annotations map is being populated with the
template annotation (template_type) for callable parameters, causing all
Callable args to collapse to the same specialization; instead, when
warp._src.types.is_callable_annotation(template_type) and the runtime type
(arg_type or default_type) is a Function, store the concrete Function object
into overload_annotations (use arg_type for positional args and default_type for
defaults) so specializations keyed later (lines ~514-516) use the actual
Function instance rather than the generic annotation; update the branches that
currently assign template_type to assign arg_type or default_type accordingly
and keep using get_arg_type/strip_reference and is_callable_annotation checks as
present.
---
Outside diff comments:
In `@warp/_src/codegen.py`:
- Around line 4513-4539: get_references() currently only discovers callable
arguments when the AST arg resolves statically via
adj.resolve_static_expression, so specialized higher-order calls backed by
adj.callable_arg_values (e.g., when func is forwarded into inner/f and invoked)
are missed; update the Call-handling branch in get_references() (the block that
binds arguments with func.signature.bind, apply_defaults, and iterates
bound_arg_nodes) to also check adj.callable_arg_values for the bound argument
node (and for raw arg names where resolution failed) and add any concrete
warp._src.context.Function entries there into the functions dict (same guard:
isinstance(..., Function) and not is_builtin()). Ensure you reference
adj.callable_arg_values lookup when bound_arg_nodes contains nodes that didn't
resolve statically or that are names/attributes representing forwarded callables
so the concrete callable specializations are included.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Enterprise
Run ID: 075e71fa-9c64-4710-9b9b-afd2eee2f986
📒 Files selected for processing (7)
CHANGELOG.mddesign/callable-func-parameters.mddocs/user_guide/limitations.rstwarp/_src/codegen.pywarp/_src/context.pywarp/_src/types.pywarp/tests/test_func.py
✅ Files skipped from review due to trivial changes (2)
- docs/user_guide/limitations.rst
- CHANGELOG.md
🚧 Files skipped from review as they are similar to previous changes (2)
- warp/_src/types.py
- warp/tests/test_func.py
5414529 to
b574078
Compare
There was a problem hiding this comment.
♻️ Duplicate comments (1)
warp/_src/context.py (1)
493-506:⚠️ Potential issue | 🟠 Major | ⚡ Quick winSpecialize callable overload keys by the concrete
Function.Line 496 and Line 506 still write the annotation back into
overload_annotations. Since Lines 516-518 key the instantiated overload from that map, everyCallableargument collapses to the same specialization and a later call can reuse the wrong overload/body/hash.Proposed fix
overload_annotations = {} for name, arg_type, template_type in zip(arg_names, arg_types, template_types, strict=False): if warp._src.types.is_callable_annotation(template_type) and isinstance(arg_type, Function): - overload_annotations[name] = template_type + overload_annotations[name] = arg_type else: overload_annotations[name] = arg_type @@ template_type = f.input_types[k] default_type = warp._src.codegen.strip_reference(warp._src.codegen.get_arg_type(d)) if warp._src.types.is_callable_annotation(template_type) and isinstance(default_type, Function): - overload_annotations[k] = template_type + overload_annotations[k] = default_type else: overload_annotations[k] = default_typeAlso applies to: 516-518
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@warp/_src/context.py` around lines 493 - 506, The overload_annotations map currently stores the generic callable annotation (template_type) when a parameter is declared as Callable, causing all callable params to collapse to the same specialization; instead, when warp._src.types.is_callable_annotation(template_type) and the runtime value is a concrete Function, assign the concrete Function type (arg_type for parameters, default_type for defaults) into overload_annotations[name] so each callable parameter is specialized by its actual Function; update both the zip loop that handles arg_types (overload_annotations[name] = arg_type) and the defaults loop that handles f.defaults (overload_annotations[k] = default_type) and ensure the later instantiation that keys overloads uses those concrete Function entries.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Duplicate comments:
In `@warp/_src/context.py`:
- Around line 493-506: The overload_annotations map currently stores the generic
callable annotation (template_type) when a parameter is declared as Callable,
causing all callable params to collapse to the same specialization; instead,
when warp._src.types.is_callable_annotation(template_type) and the runtime value
is a concrete Function, assign the concrete Function type (arg_type for
parameters, default_type for defaults) into overload_annotations[name] so each
callable parameter is specialized by its actual Function; update both the zip
loop that handles arg_types (overload_annotations[name] = arg_type) and the
defaults loop that handles f.defaults (overload_annotations[k] = default_type)
and ensure the later instantiation that keys overloads uses those concrete
Function entries.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Enterprise
Run ID: 57311a82-690c-4a18-9a42-229586f4bef2
📒 Files selected for processing (7)
CHANGELOG.mddesign/callable-func-parameters.mddocs/user_guide/limitations.rstwarp/_src/codegen.pywarp/_src/context.pywarp/_src/types.pywarp/tests/test_func.py
✅ Files skipped from review due to trivial changes (3)
- docs/user_guide/limitations.rst
- CHANGELOG.md
- design/callable-func-parameters.md
🚧 Files skipped from review as they are similar to previous changes (3)
- warp/_src/types.py
- warp/_src/codegen.py
- warp/tests/test_func.py
b574078 to
72ab41a
Compare
Greptile SummaryThis PR adds first-pass support for
Confidence Score: 5/5Safe to merge. The callable specialization path is well-isolated, all rejection cases are handled explicitly, and module dependency tracking is correctly extended for callable targets. The implementation is careful and consistently applied across codegen, context, and types. Callable values never escape to runtime C++ parameters, specialization caches are properly isolated, and both module hashing and dependency invalidation include callable targets. The two observations are theoretical edge cases that do not affect correctness under any realistic input today. No files require special attention. The specialize_callable_func hash and bind_call_arg_nodes error handling in warp/_src/codegen.py have minor hardening opportunities but no present defects. Important Files Changed
Sequence DiagramsequenceDiagram
participant Kernel
participant AC as Adjoint.add_call
participant GCV as get_callable_arg_values
participant SCF as specialize_callable_func
participant AB as Adjoint.build
Kernel->>AC: apply(double_it, 3.0)
AC->>GCV: bound_args
GCV-->>AC: callable_arg_values
AC->>SCF: "func=apply, callable_arg_values"
SCF->>SCF: compute SHA-256 suffix
SCF->>SCF: shallowcopy func to specialized_func
SCF->>SCF: pop _callable_specializations from clone
SCF-->>AC: specialized_func cached
AC->>AB: build(builder)
AB->>AB: inject callable symbols
AB-->>Kernel: callable arg omitted from C++ signature
Reviews (4): Last reviewed commit: "Split Callable func tests (GH-1424)" | Re-trigger Greptile |
f5b09ce to
ce230dc
Compare
Signed-off-by: Eric Shi <ershi@nvidia.com>
Move callable parameter coverage into its own test module so the generic function test file remains focused and easier to maintain. Signed-off-by: Eric Shi <ershi@nvidia.com>
ce230dc to
2055cb9
Compare
Description
Refs #1424.
This PR adds first-pass support for
Callable-typed parameters in user-defined@wp.funcfunctions. User-defined Warp functions can now be passed from kernels or other functions, including through defaults, keyword arguments, parameterizedCallable[[...], ...]annotations, nested calls, and generic helpers that combineCallablewithAny.The implementation keeps callable values as compile-time function references: user functions are specialized for the concrete callable target, callable parameters are omitted from generated runtime signatures, and callable targets participate in module hashing and dependency invalidation. Built-in Warp functions and callable-specialized functions with custom grad/replay are rejected explicitly for this first pass.
Checklist
Test plan
uv run --extra dev -m warp.tests -s autodetect -k TestFuncuvx pre-commit run --files warp/_src/types.py warp/_src/codegen.py warp/_src/context.py warp/tests/test_func.py CHANGELOG.md design/callable-func-parameters.mduv run --python <py> --extra dev warp/tests/test_func.py TestFunc.test_callable_func_parameter_cpu -vfor Python 3.10, 3.11, 3.12, 3.13, and 3.14Bug fix
Without this PR, the call to
apply(double_it, 3.0)fails during Warp overload resolution/codegen.New feature / enhancement
Callable markers from both
typingandcollections.abcare accepted. Parameterized forms are recognized as type-erased callable markers; their signatures are not validated in this first pass.Summary by CodeRabbit
New Features
Documentation
Tests
Chore