Skip to content

Support Callable parameters in user-defined functions#1454

Open
shi-eric wants to merge 2 commits into
NVIDIA:mainfrom
shi-eric:shi-eric/callable-func-design
Open

Support Callable parameters in user-defined functions#1454
shi-eric wants to merge 2 commits into
NVIDIA:mainfrom
shi-eric:shi-eric/callable-func-design

Conversation

@shi-eric
Copy link
Copy Markdown
Contributor

@shi-eric shi-eric commented May 11, 2026

Description

Refs #1424.

This PR adds first-pass support for Callable-typed parameters in user-defined @wp.func functions. User-defined Warp functions can now be passed from kernels or other functions, including through defaults, keyword arguments, parameterized Callable[[...], ...] annotations, nested calls, and generic helpers that combine Callable with Any.

The implementation keeps callable values as compile-time function references: user functions are specialized for the concrete callable target, callable parameters are omitted from generated runtime signatures, and callable targets participate in module hashing and dependency invalidation. Built-in Warp functions and callable-specialized functions with custom grad/replay are rejected explicitly for this first pass.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Test plan

  • uv run --extra dev -m warp.tests -s autodetect -k TestFunc
  • uvx pre-commit run --files warp/_src/types.py warp/_src/codegen.py warp/_src/context.py warp/tests/test_func.py CHANGELOG.md design/callable-func-parameters.md
  • uv run --python <py> --extra dev warp/tests/test_func.py TestFunc.test_callable_func_parameter_cpu -v for Python 3.10, 3.11, 3.12, 3.13, and 3.14

Bug fix

Without this PR, the call to apply(double_it, 3.0) fails during Warp overload resolution/codegen.

from typing import Callable

import warp as wp


@wp.func
def double_it(x: float):
    return x * 2.0


@wp.func
def apply(g: Callable, x: float):
    return g(x)


@wp.kernel
def k(out: wp.array(dtype=float)):
    out[0] = apply(double_it, 3.0)

New feature / enhancement

Callable markers from both typing and collections.abc are accepted. Parameterized forms are recognized as type-erased callable markers; their signatures are not validated in this first pass.

from collections.abc import Callable

import warp as wp


@wp.func
def triple_it(x: float):
    return x * 3.0


@wp.func
def apply(g: Callable[[float], float], x: float):
    return g(x)

Summary by CodeRabbit

  • New Features

    • Support passing user-defined Warp functions as callable parameters (parameterized & unparameterized), with per-call specialization, callable defaults, and callable parameters omitted from emitted native signatures.
  • Documentation

    • Added design doc detailing callable-parameter semantics, overload rules, compatibility goals, hashing/invalidation, and supported/rejected cases.
  • Tests

    • New kernels and unit tests covering callable parameters, defaults, specialization, module-hash effects, return-annotation preservation, dependency invalidation, and rejection cases.
  • Chore

    • Updated changelog to document callable-parameter support.

Review Change Stack

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 11, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 11, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR implements support for Callable-typed parameters in user-defined @wp.func functions. Callable parameters are treated as compile-time function references that trigger per-call-site specialization; callable arguments/defaults are extracted and hashed for dependency tracking, bound into adjoint symbols, and omitted from emitted C++ signatures.

Changes

Callable Parameter Support

Layer / File(s) Summary
Design Documentation
design/callable-func-parameters.md
Design doc added: semantics, recognition rules, specialization, hashing, rejection cases, and test plan.
Callable Type Detection
warp/_src/types.py
is_callable_annotation() added; get_type_code() uses it to emit "c" for typing/collections.abc callables.
Overload Instantiation
warp/_src/context.py
Function.get_overload() preserved callable-template params and computes instantiated overload signature keys from callable-aware annotations (handles defaults).
Overload Matching
warp/_src/codegen.py
func_match_args() updated to match callable-annotated params when bound argument is a Warp Function.
Specialization Helpers
warp/_src/codegen.py
Added get_callable_arg_values(), get_default_arg_value(), specialize_callable_func(); shallow-copy/cache specialized function variants with hashed native_func suffixes.
Adjoint & Call Integration
warp/_src/codegen.py
Adjoint.build() accepts callable_arg_values; adj.symbols bind callable params to concrete Function values. add_call() applies callable-aware defaults, triggers specialization, and omits callable params from runtime dispatch.
Codegen Signatures / Dispatch
warp/_src/codegen.py
Forward and reverse signature generation skip callable-annotated parameters so callables are not emitted in the C++ ABI; dispatch filtering removes callable args from C++ calls.
Dependency Tracking
warp/_src/context.py, warp/_src/codegen.py
Reference discovery resolves callable-typed arguments to concrete Function targets and records their modules for hashing and invalidation.
Tests
warp/tests/test_func.py
Adds helpers, kernels, and unit tests covering typing/collections.abc callables, parameterized/generic/default/nested usage, module hash invalidation, return-annotation checks, dependency propagation, and rejection cases (builtins/custom grad/replay).
Docs / Changelog
docs/user_guide/limitations.rst, CHANGELOG.md
Adds limitation note and changelog entry (GH-1424).

Sequence Diagram(s)

sequenceDiagram
    participant Kernel as Kernel (caller)
    participant Codegen as Codegen (add_call / func_match_args)
    participant Types as Types (is_callable_annotation)
    participant Specialize as Specialization (specialize_callable_func / get_callable_arg_values)
    participant Adjoint as Adjoint.build
    participant CppGen as C++ Generation

    Kernel->>Codegen: call apply(double_it, 3.0)
    Codegen->>Types: detect callable parameter via is_callable_annotation()
    Codegen->>Codegen: bind args, match overload (callable-aware)
    Codegen->>Specialize: extract callable args (get_callable_arg_values)
    Specialize->>Specialize: create/cache specialized clone (hash suffix)
    Codegen->>Adjoint: Adjoint.build(callable_arg_values)
    Adjoint->>Adjoint: populate adj.symbols with concrete Function values
    Codegen->>CppGen: emit call (skip callable params in forward/reverse signatures)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 11.54% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title 'Support Callable parameters in user-defined functions' directly and clearly summarizes the main change in the PR, matching the primary objective of enabling Callable-typed parameters in @wp.func functions.
Linked Issues check ✅ Passed The PR comprehensively addresses all coding requirements from issue #1424: unified Callable annotation handling, overload resolution fixes, type code generation ('c' code), callable-aware specialization, module hashing/dependency tracking, and comprehensive test coverage with documented limitations.
Out of Scope Changes check ✅ Passed All changes are directly related to implementing Callable parameter support per GH-1424: type detection, overload matching, code generation, callable specialization, dependency tracking, tests, and documentation of limitations. No unrelated changes detected.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
warp/_src/codegen.py (1)

945-992: 💤 Low value

Consider documenting the specialization key stability for caching.

The specialization cache key at lines 952-954 uses a tuple of (name, callable_func) pairs, relying on Function object identity for cache hits. If the same logical function is recreated (e.g., module reload), cache misses will occur, generating duplicate specializations.

This is likely acceptable for correctness (each specialization is valid), but consider adding a brief comment explaining this design choice for future maintainers.

Also, the check at lines 946-950 correctly rejects functions with custom grad/replay—good defensive validation.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@warp/_src/codegen.py` around lines 945 - 992, The specialization cache key in
specialize_callable_func currently builds specialization_key from (name,
callable_arg_values[name]) and therefore relies on Function object identity for
cache hits; add a brief explanatory comment above the specialization_key
creation stating that this is an intentional design choice (it may produce cache
misses if the same logical function object is recreated, e.g., on module
reload), why that is acceptable for correctness, and note that using
callable_func.key/native_func could be an alternative for a stable key if
desired in the future; reference specialize_callable_func, specialization_key,
and func._callable_specializations in the comment so maintainers can find the
code easily.
CHANGELOG.md (1)

32-33: ⚡ Quick win

Use imperative “Add …” phrasing for this Unreleased entry.

Line 32 currently starts with “Support …”; the changelog convention here asks for imperative present tense.

Suggested wording
-- Support passing user-defined Warp functions to `Callable` parameters in `@wp.func` functions
+- Add support for passing user-defined Warp functions to `Callable` parameters in `@wp.func` functions

As per coding guidelines: CHANGELOG.md entries in Unreleased should use imperative present tense (“Add X”).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@CHANGELOG.md` around lines 32 - 33, Update the Unreleased changelog entry
that currently reads "Support passing user-defined Warp functions to `Callable`
parameters in `@wp.func` functions" to use imperative present-tense phrasing;
change it to something like "Add support for passing user-defined Warp functions
to `Callable` parameters in `@wp.func` functions" so the entry follows the
repository's changelog convention.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@design/callable-func-parameters.md`:
- Around line 41-43: The document's non-goal wording conflicts with the linked
GH-1424 objective about accepting built-in callables; update the text around the
callable targets section to reconcile them by either (A) removing or rephrasing
the line that excludes built-ins and explicitly stating that built-in Warp
functions like wp.sin and wp.add are supported as callable arguments, or (B) if
built-ins truly remain out of scope, change the GH-1424 reference/closure
language to reflect a narrower objective; ensure references to "wp.sin" and
"wp.add" and the GH-1424 issue ID are corrected so the scope described matches
the linked issue.
- Around line 15-18: The snippet imports Callable twice which shadows and
confuses the reader; update the examples so they don't clobber the same name by
either (a) using only one import (prefer collections.abc.Callable for modern
code) and removing the other, (b) aliasing one import (e.g., import Callable as
TypingCallable) to show the difference, or (c) present two separate, clearly
labeled snippets instead of both lines together—adjust the lines containing
"from typing import Callable" and "from collections.abc import Callable"
accordingly.

---

Nitpick comments:
In `@CHANGELOG.md`:
- Around line 32-33: Update the Unreleased changelog entry that currently reads
"Support passing user-defined Warp functions to `Callable` parameters in
`@wp.func` functions" to use imperative present-tense phrasing; change it to
something like "Add support for passing user-defined Warp functions to
`Callable` parameters in `@wp.func` functions" so the entry follows the
repository's changelog convention.

In `@warp/_src/codegen.py`:
- Around line 945-992: The specialization cache key in specialize_callable_func
currently builds specialization_key from (name, callable_arg_values[name]) and
therefore relies on Function object identity for cache hits; add a brief
explanatory comment above the specialization_key creation stating that this is
an intentional design choice (it may produce cache misses if the same logical
function object is recreated, e.g., on module reload), why that is acceptable
for correctness, and note that using callable_func.key/native_func could be an
alternative for a stable key if desired in the future; reference
specialize_callable_func, specialization_key, and func._callable_specializations
in the comment so maintainers can find the code easily.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: 9019ac3f-bc6a-4ec7-b04f-ab04997190aa

📥 Commits

Reviewing files that changed from the base of the PR and between bd91b99 and 230bf7f.

📒 Files selected for processing (6)
  • CHANGELOG.md
  • design/callable-func-parameters.md
  • warp/_src/codegen.py
  • warp/_src/context.py
  • warp/_src/types.py
  • warp/tests/test_func.py

Comment thread design/callable-func-parameters.md
Comment thread design/callable-func-parameters.md
@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch from 230bf7f to f6a270d Compare May 11, 2026 03:26
@shi-eric shi-eric marked this pull request as ready for review May 11, 2026 03:29
@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch 2 times, most recently from af680cc to 99a8d88 Compare May 11, 2026 03:40
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
warp/_src/codegen.py (1)

4513-4539: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Track callable-parameter specializations in get_references() too.

This only finds callable targets when the AST argument itself resolves statically. For specialized higher-order functions, forwarded names like inner(f, x) and direct calls like f(x) are backed by adj.callable_arg_values, not by a global/static lookup, so the concrete callable never gets added to functions. That leaves module hashing/dependency invalidation stale when the passed function changes.

Suggested direction
 def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp._src.context.Function, Any]]:
     """Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions."""

     local_variables = set()  # Track local variables appearing on the LHS so we know when variables are shadowed

     constants: dict[str, Any] = {}
     types: dict[Struct | type, Any] = {}
     functions: dict[warp._src.context.Function, Any] = {}
+    callable_arg_values = getattr(adj, "callable_arg_values", None) or {}

     for node in ast.walk(adj.tree):
         ...
         elif isinstance(node, ast.Call):
             func, _ = adj.resolve_static_expression(node.func, eval_types=False)
+            if func is None and isinstance(node.func, ast.Name):
+                func = callable_arg_values.get(node.func.id)
+
+            if isinstance(func, warp._src.context.Function) and not func.is_builtin():
+                functions[func] = None
+
                 try:
                     bound_args = func.signature.bind(*node.args, **{kw.arg: kw.value for kw in node.keywords})
                 except TypeError:
                     bound_arg_nodes = {}
                 else:
                     ...
                 for arg_name, arg_node in bound_arg_nodes.items():
                     if not warp._src.types.is_callable_annotation(func.input_types.get(arg_name)):
                         continue

                     if isinstance(arg_node, warp._src.context.Function):
                         callable_func = arg_node
+                    elif isinstance(arg_node, ast.Name):
+                        callable_func = callable_arg_values.get(arg_node.id)
                     else:
                         callable_func, _ = adj.resolve_static_expression(arg_node, eval_types=False)

                     if isinstance(callable_func, warp._src.context.Function) and not callable_func.is_builtin():
                         functions[callable_func] = None
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@warp/_src/codegen.py` around lines 4513 - 4539, get_references() currently
only discovers callable arguments when the AST arg resolves statically via
adj.resolve_static_expression, so specialized higher-order calls backed by
adj.callable_arg_values (e.g., when func is forwarded into inner/f and invoked)
are missed; update the Call-handling branch in get_references() (the block that
binds arguments with func.signature.bind, apply_defaults, and iterates
bound_arg_nodes) to also check adj.callable_arg_values for the bound argument
node (and for raw arg names where resolution failed) and add any concrete
warp._src.context.Function entries there into the functions dict (same guard:
isinstance(..., Function) and not is_builtin()). Ensure you reference
adj.callable_arg_values lookup when bound_arg_nodes contains nodes that didn't
resolve statically or that are names/attributes representing forwarded callables
so the concrete callable specializations are included.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@design/callable-func-parameters.md`:
- Around line 43-46: The PR currently conflicts with the design non-goal that
excludes built-in Warp callables (e.g., wp.sin, wp.add) from being accepted as
callable parameters for user-defined `@wp.func` per callable-func-parameters.md;
either update the PR metadata so it does not close GH-1424 (change "Closes
`#1424`" to "Addresses/Related to/Partial implementation of `#1424`" in the PR
description/commit message) or implement full support for built-in callables
(update the `@wp.func` callable-argument handling, specialization hashing, and
dependency tracking code paths that validate callables so built-ins are
recognized and included) and then remove/update the non-goal text and document
the addition.

In `@warp/_src/codegen.py`:
- Around line 972-975: The specialized callable currently clears its value_func
(specialized_func.value_func = None) which breaks downstream use where
add_call() invokes func.value_func(...); restore the original value_func on the
specialized copy instead of nulling it (e.g., assign specialized_func.value_func
= func.value_func or simply remove the line that sets it to None) so the
specialized_func keeps the callable-return resolver used later by add_call().

In `@warp/_src/context.py`:
- Around line 491-506: The overload_annotations map is being populated with the
template annotation (template_type) for callable parameters, causing all
Callable args to collapse to the same specialization; instead, when
warp._src.types.is_callable_annotation(template_type) and the runtime type
(arg_type or default_type) is a Function, store the concrete Function object
into overload_annotations (use arg_type for positional args and default_type for
defaults) so specializations keyed later (lines ~514-516) use the actual
Function instance rather than the generic annotation; update the branches that
currently assign template_type to assign arg_type or default_type accordingly
and keep using get_arg_type/strip_reference and is_callable_annotation checks as
present.

---

Outside diff comments:
In `@warp/_src/codegen.py`:
- Around line 4513-4539: get_references() currently only discovers callable
arguments when the AST arg resolves statically via
adj.resolve_static_expression, so specialized higher-order calls backed by
adj.callable_arg_values (e.g., when func is forwarded into inner/f and invoked)
are missed; update the Call-handling branch in get_references() (the block that
binds arguments with func.signature.bind, apply_defaults, and iterates
bound_arg_nodes) to also check adj.callable_arg_values for the bound argument
node (and for raw arg names where resolution failed) and add any concrete
warp._src.context.Function entries there into the functions dict (same guard:
isinstance(..., Function) and not is_builtin()). Ensure you reference
adj.callable_arg_values lookup when bound_arg_nodes contains nodes that didn't
resolve statically or that are names/attributes representing forwarded callables
so the concrete callable specializations are included.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: 075e71fa-9c64-4710-9b9b-afd2eee2f986

📥 Commits

Reviewing files that changed from the base of the PR and between af680cc and 99a8d88.

📒 Files selected for processing (7)
  • CHANGELOG.md
  • design/callable-func-parameters.md
  • docs/user_guide/limitations.rst
  • warp/_src/codegen.py
  • warp/_src/context.py
  • warp/_src/types.py
  • warp/tests/test_func.py
✅ Files skipped from review due to trivial changes (2)
  • docs/user_guide/limitations.rst
  • CHANGELOG.md
🚧 Files skipped from review as they are similar to previous changes (2)
  • warp/_src/types.py
  • warp/tests/test_func.py

Comment thread design/callable-func-parameters.md
Comment thread warp/_src/codegen.py
Comment thread warp/_src/context.py
@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch 2 times, most recently from 5414529 to b574078 Compare May 11, 2026 03:50
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (1)
warp/_src/context.py (1)

493-506: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Specialize callable overload keys by the concrete Function.

Line 496 and Line 506 still write the annotation back into overload_annotations. Since Lines 516-518 key the instantiated overload from that map, every Callable argument collapses to the same specialization and a later call can reuse the wrong overload/body/hash.

Proposed fix
                 overload_annotations = {}
                 for name, arg_type, template_type in zip(arg_names, arg_types, template_types, strict=False):
                     if warp._src.types.is_callable_annotation(template_type) and isinstance(arg_type, Function):
-                        overload_annotations[name] = template_type
+                        overload_annotations[name] = arg_type
                     else:
                         overload_annotations[name] = arg_type
@@
                         template_type = f.input_types[k]
                         default_type = warp._src.codegen.strip_reference(warp._src.codegen.get_arg_type(d))
                         if warp._src.types.is_callable_annotation(template_type) and isinstance(default_type, Function):
-                            overload_annotations[k] = template_type
+                            overload_annotations[k] = default_type
                         else:
                             overload_annotations[k] = default_type

Also applies to: 516-518

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@warp/_src/context.py` around lines 493 - 506, The overload_annotations map
currently stores the generic callable annotation (template_type) when a
parameter is declared as Callable, causing all callable params to collapse to
the same specialization; instead, when
warp._src.types.is_callable_annotation(template_type) and the runtime value is a
concrete Function, assign the concrete Function type (arg_type for parameters,
default_type for defaults) into overload_annotations[name] so each callable
parameter is specialized by its actual Function; update both the zip loop that
handles arg_types (overload_annotations[name] = arg_type) and the defaults loop
that handles f.defaults (overload_annotations[k] = default_type) and ensure the
later instantiation that keys overloads uses those concrete Function entries.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Duplicate comments:
In `@warp/_src/context.py`:
- Around line 493-506: The overload_annotations map currently stores the generic
callable annotation (template_type) when a parameter is declared as Callable,
causing all callable params to collapse to the same specialization; instead,
when warp._src.types.is_callable_annotation(template_type) and the runtime value
is a concrete Function, assign the concrete Function type (arg_type for
parameters, default_type for defaults) into overload_annotations[name] so each
callable parameter is specialized by its actual Function; update both the zip
loop that handles arg_types (overload_annotations[name] = arg_type) and the
defaults loop that handles f.defaults (overload_annotations[k] = default_type)
and ensure the later instantiation that keys overloads uses those concrete
Function entries.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: 57311a82-690c-4a18-9a42-229586f4bef2

📥 Commits

Reviewing files that changed from the base of the PR and between 99a8d88 and 5414529.

📒 Files selected for processing (7)
  • CHANGELOG.md
  • design/callable-func-parameters.md
  • docs/user_guide/limitations.rst
  • warp/_src/codegen.py
  • warp/_src/context.py
  • warp/_src/types.py
  • warp/tests/test_func.py
✅ Files skipped from review due to trivial changes (3)
  • docs/user_guide/limitations.rst
  • CHANGELOG.md
  • design/callable-func-parameters.md
🚧 Files skipped from review as they are similar to previous changes (3)
  • warp/_src/types.py
  • warp/_src/codegen.py
  • warp/tests/test_func.py

@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch from b574078 to 72ab41a Compare May 11, 2026 03:55
@greptile-apps
Copy link
Copy Markdown

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR adds first-pass support for Callable-typed parameters in user-defined @wp.func functions, enabling higher-order patterns like apply(double_it, x) from kernels or other functions. The implementation treats callable values as compile-time function references: functions are specialized per concrete callable target, callable params are omitted from emitted C++ signatures, and callable targets participate in module hashing and dependency invalidation.

  • warp/_src/types.py introduces is_callable_annotation() using typing.get_origin + collections.abc.Callable, replacing the incorrect isinstance(arg_type, Callable) catch-all in get_type_code.
  • warp/_src/codegen.py adds specialization helpers (specialize_callable_func, bind_call_arg_nodes, iter_call_callable_arg_targets) and updates Adjoint.build to inject callable symbols and skip callable args from C++ signature generation.
  • warp/_src/context.py updates generic template instantiation and _find_references to correctly track callable-target module dependencies.

Confidence Score: 5/5

Safe to merge. The callable specialization path is well-isolated, all rejection cases are handled explicitly, and module dependency tracking is correctly extended for callable targets.

The implementation is careful and consistently applied across codegen, context, and types. Callable values never escape to runtime C++ parameters, specialization caches are properly isolated, and both module hashing and dependency invalidation include callable targets. The two observations are theoretical edge cases that do not affect correctness under any realistic input today.

No files require special attention. The specialize_callable_func hash and bind_call_arg_nodes error handling in warp/_src/codegen.py have minor hardening opportunities but no present defects.

Important Files Changed

Filename Overview
warp/_src/codegen.py Core codegen changes: adds callable specialization helpers, updates Adjoint.build to inject callable symbols, and skips callable args in C++ signature emission. Minor issues with hash field separators and broad TypeError catch.
warp/_src/context.py Updates generic template instantiation to preserve Callable annotations in overload_annotations and adds callable-target module dependency tracking in _find_references.
warp/_src/types.py Adds is_callable_annotation() predicate and moves callable type-code handling earlier in get_type_code, removing the incorrect isinstance catch-all.
warp/tests/test_func_callable.py New dedicated test module with thorough coverage of typing/abc variants, parameterized forms, generics, defaults, keyword args, nested calls, module hash/dependency effects, and rejection paths.

Sequence Diagram

sequenceDiagram
    participant Kernel
    participant AC as Adjoint.add_call
    participant GCV as get_callable_arg_values
    participant SCF as specialize_callable_func
    participant AB as Adjoint.build

    Kernel->>AC: apply(double_it, 3.0)
    AC->>GCV: bound_args
    GCV-->>AC: callable_arg_values
    AC->>SCF: "func=apply, callable_arg_values"
    SCF->>SCF: compute SHA-256 suffix
    SCF->>SCF: shallowcopy func to specialized_func
    SCF->>SCF: pop _callable_specializations from clone
    SCF-->>AC: specialized_func cached
    AC->>AB: build(builder)
    AB->>AB: inject callable symbols
    AB-->>Kernel: callable arg omitted from C++ signature
Loading

Reviews (4): Last reviewed commit: "Split Callable func tests (GH-1424)" | Re-trigger Greptile

Comment thread warp/_src/codegen.py
Comment thread warp/_src/codegen.py
@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch 2 times, most recently from f5b09ce to ce230dc Compare May 11, 2026 05:07
shi-eric added 2 commits May 14, 2026 07:11
Signed-off-by: Eric Shi <ershi@nvidia.com>
Move callable parameter coverage into its own test module so the generic function test file remains focused and easier to maintain.

Signed-off-by: Eric Shi <ershi@nvidia.com>
@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch from ce230dc to 2055cb9 Compare May 14, 2026 07:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant