
Add JAX FFI Host support #1446

Draft
loney7 wants to merge 11 commits into NVIDIA:main from loney7:loney7/ffi-host-support

Conversation

@loney7 loney7 commented May 8, 2026

Description

This PR adds support for running JAX FFI callbacks on the CPU (Host) in addition to CUDA.

Changes:

  • Registered FFI targets for both "CUDA" and "Host" platforms in register_ffi_callback.
  • Handled the "Host" platform case in ffi_callback by using the CPU device and bypassing CUDA-specific features like streams and graphs.
  • Updated FfiCallable to reconstruct arguments and execute the function on the CPU when running on the Host platform.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Test plan

You can verify these changes by running the JAX interop tests which include FFI tests. Ensure they pass on both CPU and GPU (if available).

uv run warp/tests/interop/test_jax.py


## Summary by CodeRabbit

* **New Features**
  * Separate CUDA and CPU execution paths for FFI callbacks with device-aware execution scoping.
  * CUDA graph compatibility enabled only for CUDA execution.
  * CPU/host path can take a direct Python execution route for faster host calls.
  * Separate CUDA and Host callback registrations for JAX interop.

* **Tests**
  * Added CPU/host FFI tests covering kernels and callables (add, sincos, in/out args, scale).

This allows JAX FFI callbacks to run on CPU (Host) in addition to CUDA.

Signed-off-by: Ankit Jain <kitsrish@google.com>
@copy-pr-bot

copy-pr-bot Bot commented May 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented May 8, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

This pull request extends JAX experimental FFI callbacks to support both CUDA and Host (CPU) execution paths. Callbacks now accept a platform parameter, register separate FFI targets for each platform, and conditionally dispatch to platform-appropriate device and kernel launch logic. CUDA-specific traits and graph capture modes are guarded by platform checks.

Changes

CUDA and Host FFI Execution

| Layer / File(s) | Summary |
|---|---|
| Callback Protocol & Platform Parameter: warp/_src/jax_experimental/ffi.py | FfiKernel.ffi_callback(), FfiCallable.ffi_callback(), and register_ffi_callback() callbacks gain platform="CUDA" parameter. CUDA graph compatibility traits are now enabled only when platform=="CUDA". |
| Dual-Platform FFI Registration: warp/_src/jax_experimental/ffi.py | FfiKernel and FfiCallable register separate FFI capsules for both CUDA and Host platforms, each passing the appropriate platform argument. register_ffi_callback() stores capsules under distinct registry keys (_cuda and _host suffixes). |
| FfiKernel Platform-Conditional Launch: warp/_src/jax_experimental/ffi.py | FfiKernel execution branches by platform: CUDA selects CUDA device and stream, calls wp_cuda_launch_kernel; Host uses CPU device with no stream, calls wp_cpu_launch_kernel. |
| FfiCallable Host Execution Short-Circuit: warp/_src/jax_experimental/ffi.py | When platform=="Host", FfiCallable reconstructs Warp arrays on the CPU device from the call frame and directly invokes the wrapped Python function, bypassing all CUDA graph logic. |
| FfiCallable CUDA Execution & Graph Capture: warp/_src/jax_experimental/ffi.py | CUDA graph modes are guarded by device.is_cuda. Execution scopes conditionally use ScopedStream (when stream present) or ScopedDevice, restricting graph capture and replay to CUDA devices only. |
| CPU Host Tests: warp/tests/interop/test_jax.py | New CPU-only jax_kernel and jax_callable tests added and registered in TestJax with devices=None. |

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 5.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title 'Add JAX FFI Host support' directly and clearly summarizes the main change: adding Host/CPU platform support to JAX FFI callbacks alongside existing CUDA support. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@greptile-apps

greptile-apps Bot commented May 8, 2026

Greptile Summary

This PR extends JAX FFI interop to support CPU (Host) execution in addition to CUDA, by registering separate FFI targets for both platforms in FfiKernel, FfiCallable, and register_ffi_callback, and routing Host-platform callbacks through a CPU-specific code path.

  • FfiKernel: Registers dual CUDA/Host callbacks; Host path selects the CPU device, builds a ctypes ArgsStruct, and calls wp_cpu_launch_kernel instead of wp_cuda_launch_kernel.
  • FfiCallable: Registers dual callbacks; Host path early-returns after calling _reconstruct_args (now a factored-out helper) with the CPU device and wp.ScopedDevice, bypassing all graph-capture logic.
  • register_ffi_callback: Registers dual callbacks and stores them under distinct registry keys ({name}_cuda / {name}_host); adds six new CPU-focused integration tests.
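The dual registration summarized in the bullets above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the PR's code: `FAKE_TARGETS`, `CALLBACK_REGISTRY`, `register_target`, and `register_dual` are hypothetical stand-ins for the real FFI target table and Warp-side registry. Only the `platform` keyword and the `{name}_cuda` / `{name}_host` key convention come from the review text.

```python
# Hypothetical stand-ins for the FFI target table and the Warp-side registry.
FAKE_TARGETS = {}        # (name, platform) -> callback
CALLBACK_REGISTRY = {}   # "{name}_cuda" / "{name}_host" -> callback


def register_target(name, callback, platform):
    """Stand-in for a per-platform FFI target registration call."""
    FAKE_TARGETS[(name, platform)] = callback


def register_dual(name, ffi_callback):
    """Register one callback per platform, binding `platform` at definition time."""
    for platform in ("CUDA", "Host"):
        # The default argument captures the current loop value; a bare closure
        # would only ever see the final value of `platform`.
        def bound(call_frame, _platform=platform):
            return ffi_callback(call_frame, platform=_platform)

        register_target(name, bound, platform)
        # Store each callback under its own key so the two platforms
        # do not overwrite each other and both stay referenced.
        CALLBACK_REGISTRY[f"{name}_{platform.lower()}"] = bound


def ffi_callback(call_frame, platform="CUDA"):
    # Toy callback that reports which platform path was taken.
    return f"{platform}:{call_frame}"


register_dual("my_op", ffi_callback)
```

Binding the platform through a default argument is one way to avoid the late-binding closure pitfall; the PR itself uses one explicit lambda per platform, which sidesteps the issue differently.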

Confidence Score: 3/5

The Host execution path in register_ffi_callback will attempt to call a CUDA-only XLA stream API on every Host invocation, which can crash or silently corrupt state before any user function is reached.

The NameError bugs from the previous review round are correctly fixed. However, register_ffi_callback's inner ffi_callback still constructs an ExecutionContext unconditionally on every call, even on the Host platform, and ExecutionContext.__init__ always invokes get_stream_from_callframe, which calls the CUDA-only XLA_FFI_Stream_Get. On a Host-platform invocation there is no CUDA stream in the call frame; this is a present defect on every Host call made through register_ffi_callback.

Files to review: warp/_src/jax_experimental/ffi.py, specifically the register_ffi_callback inner ffi_callback where ExecutionContext is constructed without a platform guard, and warp/_src/jax_experimental/xla_ffi.py, where ExecutionContext.__init__ calls get_stream_from_callframe unconditionally.

Important Files Changed

| Filename | Overview |
|---|---|
| warp/_src/jax_experimental/ffi.py | Adds Host platform support to FfiKernel, FfiCallable, and register_ffi_callback; fixes self-referential NameError bugs from the prior round; ExecutionContext still unconditionally calls the CUDA-only get_stream_from_callframe for register_ffi_callback Host calls |
| warp/tests/interop/test_jax.py | Adds six Host-platform FFI tests (kernel add/sincos/in-out/scale, callable scale/in-out); always registered regardless of CUDA availability; test structure and assertions are correct |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[JAX dispatches FFI call] --> B{Platform?}
    B -->|CUDA| C[ffi_callback platform=CUDA]
    B -->|Host| D[ffi_callback platform=Host]

    C --> E[get_device_ordinal_from_callframe]
    C --> F[get_stream_from_callframe]
    E & F --> G{FfiKernel or FfiCallable?}

    D --> H{FfiKernel or FfiCallable?}

    G -->|FfiKernel| I[wp_cuda_launch_kernel]
    G -->|FfiCallable| J[Graph mode logic → func]

    H -->|FfiKernel| K[Build ArgsStruct → wp_cpu_launch_kernel]
    H -->|FfiCallable| L[_reconstruct_args → ScopedDevice → func]

    subgraph register_ffi_callback
        D2[Host platform] --> M[ExecutionContext ⚠️ calls get_stream_from_callframe]
        M --> N[func inputs outputs attrs ctx]
    end

Comments Outside Diff (1)

  1. warp/_src/jax_experimental/ffi.py, line 1750 (link)

    P1 ExecutionContext unconditionally calls get_stream_from_callframe on the Host platform

    ExecutionContext.__init__ (in xla_ffi.py) always calls get_stream_from_callframe, which invokes api.contents.XLA_FFI_Stream_Get — a CUDA-only XLA FFI API. When ffi_callback is dispatched from the "Host" platform, there is no CUDA stream in the call frame; the Host XLA runtime may not implement XLA_FFI_Stream_Get at all, leading to a null-function-pointer dereference or a garbage ctx.stream value. Now that platform is a parameter of ffi_callback, ExecutionContext construction (or at least the ctx.stream assignment inside it) should be guarded on platform == "CUDA" before being passed to func.
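One possible shape for the guard suggested above is sketched here. It is a simplified illustration, not the actual `ExecutionContext` from `xla_ffi.py`: the class name `ExecutionContextSketch`, its constructor signature, and `fake_get_stream` are all hypothetical, and the stream lookup is injected so the example runs without JAX or CUDA.

```python
class ExecutionContextSketch:
    """Hypothetical context object; `stream` stays None off CUDA."""

    def __init__(self, call_frame, platform, get_stream):
        self.platform = platform
        # Only query the stream when the call frame comes from the CUDA backend;
        # on Host the lookup is skipped entirely.
        self.stream = get_stream(call_frame) if platform == "CUDA" else None


calls = []


def fake_get_stream(call_frame):
    # Stand-in for a CUDA-only stream lookup; records each invocation.
    calls.append(call_frame)
    return 0xBEEF


host_ctx = ExecutionContextSketch("frame-h", "Host", fake_get_stream)
cuda_ctx = ExecutionContextSketch("frame-c", "CUDA", fake_get_stream)
```

With this shape, user callbacks that read `ctx.stream` would need to tolerate `None` on the Host platform.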

Reviews (6): Last reviewed commit: "Address PR review comments for JAX FFI H..." | Re-trigger Greptile

Comment on lines +226 to +229
self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host"))
ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host")

P0 Self-referential NameError in FfiKernel.__init__: ffi_capsule_host is passed as its own argument to pycapsule before the name is ever assigned. This will raise NameError: name 'ffi_capsule_host' is not defined every time an FfiKernel is instantiated, making the entire Host platform registration dead code. The address variable ffi_ccall_address_host should be used here instead — consistent with how the CUDA path is written above.

Suggested change
- self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host"))
- ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p)
- ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
- jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host")
+ self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host"))
+ ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p)
+ ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value)
+ jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host")
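The failure mode flagged in this comment can be reproduced in isolation with `ctypes` alone. The sketch below is illustrative only: `make_capsule`, `handler`, `register_buggy`, and `register_fixed` are hypothetical names, and `make_capsule` merely stands in for a capsule constructor. Note that inside a function body Python raises `UnboundLocalError`, which is a subclass of `NameError`.

```python
import ctypes

# Callback prototype: takes one opaque pointer, returns nothing.
CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_void_p)


def make_capsule(address):
    """Hypothetical stand-in for a capsule constructor; just wraps the address."""
    return {"address": address}


def handler(call_frame):
    pass  # a real callback would decode the call frame here


def register_buggy():
    callback = CALLBACK_TYPE(handler)
    address = ctypes.cast(callback, ctypes.c_void_p)
    capsule = make_capsule(capsule.value)  # bug: `capsule` is read before assignment
    return callback, capsule


def register_fixed():
    callback = CALLBACK_TYPE(handler)
    address = ctypes.cast(callback, ctypes.c_void_p)
    capsule = make_capsule(address.value)  # use the address computed one line above
    return callback, capsule


try:
    register_buggy()
    buggy_error = None
except NameError as exc:  # UnboundLocalError is a subclass of NameError
    buggy_error = exc

# Keep `callback` referenced for as long as the address may be called.
callback, capsule = register_fixed()
```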

Comment on lines +1775 to +1777
ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host")

P0 Same self-referential NameError in register_ffi_callback: ffi_capsule_host is referenced before it is assigned. This causes every call to register_ffi_callback to raise NameError: name 'ffi_capsule_host' is not defined, so no Host target is ever registered. The fix mirrors the working CUDA block two lines above.

Suggested change
- ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
- ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
- jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host")
+ ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
+ ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value)
+ jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host")


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@warp/_src/jax_experimental/ffi.py`:
- Around line 1772-1777: In register_ffi_callback, the Host ffi capsule is
constructed from the wrong variable (ffi_capsule_host) causing a NameError;
change the construction to use the host ccall address value by calling
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then register that capsule
with jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host") so the
Host path mirrors the CUDA path (refer to ffi_ccall_address_host,
ffi_capsule_host, register_ffi_target).
- Around line 629-632: The code assigns ffi_ccall_address_host then creates
ffi_capsule_host but mistakenly uses ffi_capsule_host.value (self-referential
NameError); change the capsule creation to use the previously computed address
value (ffi_ccall_address_host.value) so the lines around callback_func_host,
ffi_ccall_address_host, ffi_capsule_host and the
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") call
use ffi_ccall_address_host.value when constructing the pycapsule for the Host
callback that wraps FFI_CCALLFUNC and calls self.ffi_callback.
- Around line 226-229: The Host FFI registration references an undefined
variable: replace the erroneous creation of ffi_capsule_host (currently using
ffi_capsule_host.value) with a capsule built from the c_void_p address you just
made; specifically, after creating callback_func_host and
ffi_ccall_address_host, set ffi_capsule_host =
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then call
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") so the
capsule is created from the correct address (symbols: callback_func_host,
FFI_CCALLFUNC, ffi_ccall_address_host, ffi_capsule_host,
jax.ffi.register_ffi_target, self.name, self.ffi_callback).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: d2c93d0f-cbef-4891-b361-773ad8140f6c

📥 Commits

Reviewing files that changed from the base of the PR and between 54327e3 and 09569ba.

📒 Files selected for processing (1)
  • warp/_src/jax_experimental/ffi.py

Ankit Jain added 4 commits May 9, 2026 00:50
Signed-off-by: Ankit Jain <kitsrish@google.com>
Applied the patch from the original CL as requested.

Signed-off-by: Ankit Jain <kitsrish@google.com>
Fix a typo where ffi_capsule_host used its own unassigned value instead of ffi_ccall_address_host.value.

Signed-off-by: Ankit Jain <kitsrish@google.com>
Restore files to their state in main branch to keep this PR focused on JAX FFI changes.

Signed-off-by: Ankit Jain <kitsrish@google.com>
Comment on lines +464 to +469
wp._src.context.runtime.core.wp_cpu_launch_kernel(
device.context,
hooks.forward,
launch_bounds.size,
kernel_params,
)

P0 Wrong arguments passed to wp_cpu_launch_kernel

The call is missing one argument and the arguments are in the wrong positions. The registered ctypes signature (in context.py) is (func, bounds, args, adj_args, apic_info), but the call here passes device.context as func, hooks.forward as bounds, launch_bounds.size (an integer, not a pointer) as args, and kernel_params as adj_args, with apic_info omitted entirely. The correct call should place the casted hooks.forward function pointer as the first argument, a reference to launch_bounds as the second, and kernel_params as the third args pointer, with None for adj_args and apic_info. As written, this will pass garbage pointers to the native kernel launcher, causing a crash or silent memory corruption on every Host-platform FfiKernel call.

@loney7
Author

loney7 commented May 8, 2026

pre-commit.ci autofix

pre-commit-ci Bot and others added 3 commits May 8, 2026 23:56
Add four tests that exercise the Host platform path for FFI:
- test_ffi_jax_kernel_add_host: basic two-input one-output kernel
- test_ffi_jax_kernel_sincos_host: one-input two-output kernel
- test_ffi_jax_kernel_in_out_host: in-out argument handling
- test_ffi_jax_callable_scale_constant_host: jax_callable with scalar constant

Signed-off-by: Ankit Jain <kitsrish@google.com>
Add test coverage for jax_kernel and jax_callable running on the CPU
Host platform. Tests cover basic add, sincos (multi-output), in-out
args, scalar constant args, and callable variants.

Signed-off-by: Ankit Jain <kitsrish@google.com>

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@warp/tests/interop/test_jax.py`:
- Around line 2422-2450: The host-only test registrations for TestJax (the
add_function_test calls registering test_ffi_jax_kernel_host_add,
test_ffi_jax_kernel_host_sincos, test_ffi_jax_kernel_host_in_out,
test_ffi_jax_kernel_host_scale_vec_constant,
test_ffi_jax_callable_host_scale_constant, and
test_ffi_jax_callable_host_in_out) are currently inside the
jax_compatible_cuda_devices conditional; move these specific add_function_test
calls out of that CUDA-only if block so they are always registered on CPU-only
setups, keeping the existing devices=None argument and leaving CUDA/GPU-specific
registrations inside the original conditional.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: fcfb2e10-e509-41b9-b675-c8960056b6be

📥 Commits

Reviewing files that changed from the base of the PR and between d39169c and f6048f4.

📒 Files selected for processing (1)
  • warp/tests/interop/test_jax.py

@loney7
Author

loney7 commented May 9, 2026

pre-commit.ci autofix

@shi-eric shi-eric requested a review from nvlukasz May 12, 2026 16:45
Comment thread warp/tests/interop/test_jax.py Outdated
add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)

# ffi Host (CPU) tests
add_function_test(TestJax, "test_ffi_jax_kernel_host_add", test_ffi_jax_kernel_host_add, devices=None)
Contributor


These Host tests only use jax.devices("cpu"), but they are registered inside the CUDA JAX availability gate. Please move them under CPU JAX availability so CPU-only CI covers the Host backend.

None, # apic_info
)
else:
wp._src.context.runtime.core.wp_cpu_launch_kernel(
Contributor


This Host path is using the CUDA launch ABI for wp_cpu_launch_kernel. The CPU binding expects (func, bounds, args, adj_args, apic_info), with kernel args packed into the CPU args struct. As written, Host jax_kernel raises TypeError: this function takes at least 5 arguments (4 given) instead of launching.
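The `TypeError` described above is the generic behavior of a `ctypes` function with declared argument types when it is called with too few arguments. The sketch below reproduces that with a pure-Python stand-in: `LAUNCH_PROTO`, `_fake_launch`, and `fake_cpu_launch` are hypothetical, and only the five parameter names are taken from the review comment. It does not call into Warp's native library.

```python
import ctypes

# Five pointer-sized parameters, mirroring (func, bounds, args, adj_args, apic_info).
LAUNCH_PROTO = ctypes.CFUNCTYPE(
    None,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
)

received = []


def _fake_launch(func, bounds, args, adj_args, apic_info):
    # Record what the "launcher" was given instead of running a kernel.
    received.append((func, bounds, args, adj_args, apic_info))


fake_cpu_launch = LAUNCH_PROTO(_fake_launch)

try:
    # Four arguments against a five-parameter prototype: rejected by ctypes
    # before the underlying function is ever entered.
    fake_cpu_launch(1, 2, 3, 4)
    arity_error = None
except TypeError as exc:
    arity_error = exc

# Five arguments, with None (a null pointer) for adj_args and apic_info.
fake_cpu_launch(1, 2, 3, None, None)
```

The arity check only catches a wrong argument count; arguments passed in the wrong positions but with compatible types would still reach native code, which is why matching the declared signature exactly matters.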

@shi-eric
Contributor

Please add a CHANGELOG.md entry under Unreleased for this JAX Host FFI behavior change.

@shi-eric
Contributor

Please squash this PR down to a single coherent commit before merge.

@shi-eric
Contributor

Please rebase onto current main and move the fix/tests to the promoted JAX code paths. Commit 604a8961df6d40ea64ff1e740b23581e4c72c96f promoted the JAX code from jax_experimental to jax after this PR was opened, so the final diff should target the current locations.

Contributor

@nvlukasz nvlukasz left a comment


Thank you for the contribution. Please address outstanding comments or let us know if you are unable to do so.

assert num_inputs == self.num_inputs
assert num_outputs == self.num_outputs

if platform == "Host":
Contributor


Can we deduplicate this argument reconstruction code? Perhaps extract to a helper function.

Comment thread warp/_src/jax_experimental/ffi.py Outdated
# call the Python function with reconstructed arguments
with wp.ScopedStream(stream, sync_enter=False):
if stream.is_capturing:
with wp.ScopedStream(stream, sync_enter=False) if stream else wp.ScopedDevice(device):
Contributor


No need to check stream here, since the host path returns early above. If we get here, the stream is not None.

Comment thread warp/_src/jax_experimental/ffi.py Outdated
call_desc.capture = capture

- elif self.graph_mode == GraphMode.WARP:
+ elif self.graph_mode == GraphMode.WARP and device.is_cuda:
Contributor


No need to check device.is_cuda here, since the host part returns early above. Same comment on a few more lines below.
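The control flow these review comments rely on, an early return for Host plus a shared argument-reconstruction helper, can be sketched as below. This is a toy model under stated assumptions: `reconstruct_args`, the dictionary-based call frame, the device strings, and `user_func` are hypothetical and do not reflect Warp's actual data structures.

```python
def reconstruct_args(call_frame, device):
    """Hypothetical shared helper: pair each raw buffer with its target device."""
    return [(device, buf) for buf in call_frame["buffers"]]


def ffi_callback(call_frame, func, platform="CUDA"):
    if platform == "Host":
        # Host path: no stream and no graph capture; run directly and return early.
        return func(reconstruct_args(call_frame, "cpu"), stream=None)

    # From here on the platform is CUDA, so a stream is always present and
    # no `if stream` or `device.is_cuda` guard is needed.
    stream = call_frame["stream"]
    return func(reconstruct_args(call_frame, "cuda:0"), stream=stream)


def user_func(args, stream):
    # Toy user function: report the devices seen and the stream received.
    return [dev for dev, _ in args], stream


host_result = ffi_callback({"buffers": [1, 2]}, user_func, platform="Host")
cuda_result = ffi_callback({"buffers": [1], "stream": 7}, user_func)
```

Because both branches call the same helper, a fix to argument reconstruction applies to the Host and CUDA paths at once.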

@loney7
Author

loney7 commented May 19, 2026

> Thank you for the contribution. Please address outstanding comments or let us know if you are unable to do so.

@nvlukasz, please allow me until the end of this week to address the outstanding comments. I apologise for the delay.

- Fix wp_cpu_launch_kernel call to use correct 5-arg signature
  (func, bounds, args, adj_args, apic_info) and build a proper CPU
  args struct instead of passing GPU-style kernel_params array
- Extract _reconstruct_args helper in FfiCallable to deduplicate
  argument reconstruction between Host and CUDA callback paths
- Remove unnecessary stream/device.is_cuda guards in CUDA path since
  the Host path returns early, making stream always non-None
- Move Host CPU test registrations outside jax_compatible_cuda_devices
  gate so they run on CPU-only CI setups

Signed-off-by: Ankit Jain <kitsrish@google.com>