Skip to content
143 changes: 101 additions & 42 deletions warp/_src/jax_experimental/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,16 @@ def __init__(

# register the callback
FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
self.callback_func = FFI_CCALLFUNC(self.ffi_callback)
ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")

self.callback_func_cuda = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="CUDA"))
ffi_ccall_address_cuda = ctypes.cast(self.callback_func_cuda, ctypes.c_void_p)
ffi_capsule_cuda = jax.ffi.pycapsule(ffi_ccall_address_cuda.value)
jax.ffi.register_ffi_target(self.name, ffi_capsule_cuda, platform="CUDA")

self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host"))
ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value)
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host")
Comment on lines +224 to +227
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 Self-referential use-before-assignment in FfiKernel.__init__: ffi_capsule_host is passed as its own argument to pycapsule before the name is ever assigned. Because ffi_capsule_host is a local variable of __init__, this will raise UnboundLocalError (a subclass of NameError) every time an FfiKernel is instantiated, so construction aborts and the Host platform target is never registered. The address variable ffi_ccall_address_host should be used here instead — consistent with how the CUDA path is written above.

Suggested change
self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host"))
ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host")
self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host"))
ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value)
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host")

Comment thread
coderabbitai[bot] marked this conversation as resolved.

def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None):
num_inputs = len(args)
Expand Down Expand Up @@ -332,7 +338,7 @@ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None):

return call(*args, launch_id=launch_id)

def ffi_callback(self, call_frame):
def ffi_callback(self, call_frame, platform="CUDA"):
try:
# On the first call, XLA runtime will query the API version and traits
# metadata using the |extension| field. Let us respond to that query
Expand All @@ -344,10 +350,11 @@ def ffi_callback(self, call_frame):
metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
metadata_ext.contents.metadata.contents.api_version.major_version = 0
metadata_ext.contents.metadata.contents.api_version.minor_version = 1
# Turn on CUDA graphs for this handler.
metadata_ext.contents.metadata.contents.traits = (
XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
)
# Turn on CUDA graphs for this handler if on CUDA platform.
if platform == "CUDA":
metadata_ext.contents.metadata.contents.traits = (
XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
)
return None

# Lock is required to prevent race conditions when callback is invoked
Expand Down Expand Up @@ -427,25 +434,37 @@ def ffi_callback(self, call_frame):
kernel_params[0] = ctypes.addressof(launch_bounds)

# get device and stream
device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents))
stream = get_stream_from_callframe(call_frame.contents)
if platform == "CUDA":
device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents))
stream = get_stream_from_callframe(call_frame.contents)
else:
device = wp.get_device("cpu")
stream = None

# get kernel hooks
hooks = self.kernel.module.get_kernel_hooks(self.kernel, device)
assert hooks.forward, "Failed to find kernel entry point"

# launch the kernel
wp._src.context.runtime.core.wp_cuda_launch_kernel(
device.context,
hooks.forward,
launch_bounds.size,
0,
256,
hooks.forward_smem_bytes,
kernel_params,
stream,
None, # apic_info
)
if device.is_cuda:
wp._src.context.runtime.core.wp_cuda_launch_kernel(
device.context,
hooks.forward,
launch_bounds.size,
0,
256,
hooks.forward_smem_bytes,
kernel_params,
stream,
None, # apic_info
)
else:
wp._src.context.runtime.core.wp_cpu_launch_kernel(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This Host path is using the CUDA launch ABI for wp_cpu_launch_kernel. The CPU binding expects (func, bounds, args, adj_args, apic_info), with kernel args packed into the CPU args struct. As written, Host jax_kernel raises TypeError: this function takes at least 5 arguments (4 given) instead of launching.

device.context,
hooks.forward,
launch_bounds.size,
kernel_params,
)
Comment on lines +474 to +480
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 Wrong arguments passed to wp_cpu_launch_kernel

The call is missing one argument and the arguments are in the wrong positions. The registered ctypes signature (in context.py) is (func, bounds, args, adj_args, apic_info), but the call here passes device.context as func, hooks.forward as bounds, launch_bounds.size (an integer, not a pointer) as args, and kernel_params as adj_args, with apic_info omitted entirely. The correct call should place the casted hooks.forward function pointer as the first argument, a reference to launch_bounds as the second, and kernel_params as the third args pointer, with None for adj_args and apic_info. As written, this will pass garbage pointers to the native kernel launcher, causing a crash or silent memory corruption on every Host-platform FfiKernel call.


except Exception as e:
print(traceback.format_exc())
Expand Down Expand Up @@ -599,10 +618,16 @@ def __init__(

# register the callback
FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
self.callback_func = FFI_CCALLFUNC(self.ffi_callback)
ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")

self.callback_func_cuda = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="CUDA"))
ffi_ccall_address_cuda = ctypes.cast(self.callback_func_cuda, ctypes.c_void_p)
ffi_capsule_cuda = jax.ffi.pycapsule(ffi_ccall_address_cuda.value)
jax.ffi.register_ffi_target(self.name, ffi_capsule_cuda, platform="CUDA")

self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host"))
ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value)
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host")
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def __call__(self, *args, output_dims=None, vmap_method=None):
num_inputs = len(args)
Expand Down Expand Up @@ -703,7 +728,7 @@ def __call__(self, *args, output_dims=None, vmap_method=None):
self.call_id += 1
return call(*args, call_id=call_id)

def ffi_callback(self, call_frame):
def ffi_callback(self, call_frame, platform="CUDA"):
try:
# On the first call, XLA runtime will query the API version and traits
# metadata using the |extension| field. Let us respond to that query
Expand All @@ -715,8 +740,8 @@ def ffi_callback(self, call_frame):
metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
metadata_ext.contents.metadata.contents.api_version.major_version = 0
metadata_ext.contents.metadata.contents.api_version.minor_version = 1
# Turn on CUDA graphs for this handler.
if self.graph_mode is GraphMode.JAX:
# Turn on CUDA graphs for this handler if on CUDA platform.
if self.graph_mode is GraphMode.JAX and platform == "CUDA":
metadata_ext.contents.metadata.contents.traits = (
XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
)
Expand All @@ -743,6 +768,35 @@ def ffi_callback(self, call_frame):
assert num_inputs == self.num_inputs
assert num_outputs == self.num_outputs

if platform == "Host":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we deduplicate this argument reconstruction code? Perhaps extract to a helper function.

device = wp.get_device("cpu")
# reconstruct the argument list
arg_list = []

# input and in-out args
for i, arg in enumerate(self.input_args):
if arg.is_array:
buffer = inputs[i].contents
shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim)
arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
arg_list.append(arr)
else:
# scalar argument, get stashed value
value = call_desc.static_inputs[arg.name]
arg_list.append(value)

# pure output args (skip in-out FFI buffers)
for i, arg in enumerate(self.output_args):
buffer = outputs[i + self.num_in_out].contents
shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim)
arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
arg_list.append(arr)

# call the Python function with reconstructed arguments
with wp.ScopedDevice(device):
self.func(*arg_list)
return

cuda_stream = get_stream_from_callframe(call_frame.contents)
device_ordinal = get_device_ordinal_from_callframe(call_frame.contents)

Expand Down Expand Up @@ -875,16 +929,16 @@ def ffi_callback(self, call_frame):
arg_list.append(arr)

# call the Python function with reconstructed arguments
with wp.ScopedStream(stream, sync_enter=False):
if stream.is_capturing:
with wp.ScopedStream(stream, sync_enter=False) if stream else wp.ScopedDevice(device):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to check stream here, since the host path returns early above. If we get here, the stream is not None.

if stream and stream.is_capturing:
# capturing with JAX
with wp.ScopedCapture(external=True) as capture:
self.func(*arg_list)

# keep a reference to the capture object to prevent required modules getting unloaded
call_desc.capture = capture

elif self.graph_mode == GraphMode.WARP:
elif self.graph_mode == GraphMode.WARP and device.is_cuda:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to check device.is_cuda here, since the host path returns early above. Same comment on a few more lines below.

# capturing with WARP
with wp.ScopedCapture() as capture:
self.func(*arg_list)
Expand All @@ -897,7 +951,7 @@ def ffi_callback(self, call_frame):
if self._graph_cache_max is not None and len(self.captures) > self._graph_cache_max:
self.captures.popitem(last=False)

elif self.graph_mode == GraphMode.WARP_STAGED_EX:
elif self.graph_mode == GraphMode.WARP_STAGED_EX and device.is_cuda:
# capturing with WARP using staging buffers and memcopies done outside of the graph
wp_memcpy_batch = wp._src.context.runtime.core.wp_memcpy_batch

Expand Down Expand Up @@ -940,7 +994,7 @@ def ffi_callback(self, call_frame):
# TODO: we should have a way of freeing this
call_desc.capture = capture

elif self.graph_mode == GraphMode.WARP_STAGED:
elif self.graph_mode == GraphMode.WARP_STAGED and device.is_cuda:
# capturing with WARP using staging buffers and memcopies done inside of the graph
wp_cuda_graph_insert_memcpy_batch = (
wp._src.context.runtime.core.wp_cuda_graph_insert_memcpy_batch
Expand Down Expand Up @@ -1018,7 +1072,7 @@ def ffi_callback(self, call_frame):
call_desc.capture = capture

else:
# not capturing
# not capturing or on CPU
self.func(*arg_list)

except Exception as e:
Expand Down Expand Up @@ -1663,7 +1717,7 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = Tr

# TODO check that the name is not already registered

def ffi_callback(call_frame):
def ffi_callback(call_frame, platform="CUDA"):
try:
extension = call_frame.contents.extension_start
# On the first call, XLA runtime will query the API version and traits
Expand All @@ -1675,7 +1729,7 @@ def ffi_callback(call_frame):
metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
metadata_ext.contents.metadata.contents.api_version.major_version = 0
metadata_ext.contents.metadata.contents.api_version.minor_version = 1
if graph_compatible:
if graph_compatible and platform == "CUDA":
# Turn on CUDA graphs for this handler.
metadata_ext.contents.metadata.contents.traits = (
XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
Expand Down Expand Up @@ -1708,12 +1762,17 @@ def ffi_callback(call_frame):
return None

FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
callback_func = FFI_CCALLFUNC(ffi_callback)
callback_func_cuda = FFI_CCALLFUNC(lambda call_frame: ffi_callback(call_frame, platform="CUDA"))
callback_func_host = FFI_CCALLFUNC(lambda call_frame: ffi_callback(call_frame, platform="Host"))
with _FFI_REGISTRY_LOCK:
_FFI_CALLBACK_REGISTRY[name] = callback_func
ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p)
ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA")
_FFI_CALLBACK_REGISTRY[f"{name}_cuda"] = callback_func_cuda
_FFI_CALLBACK_REGISTRY[f"{name}_host"] = callback_func_host
ffi_ccall_address_cuda = ctypes.cast(callback_func_cuda, ctypes.c_void_p)
ffi_capsule_cuda = jax.ffi.pycapsule(ffi_ccall_address_cuda.value)
jax.ffi.register_ffi_target(name, ffi_capsule_cuda, platform="CUDA")
ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value)
jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host")
Comment on lines +1771 to +1773
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 Same self-referential use-before-assignment in register_ffi_callback: ffi_capsule_host is referenced before it is assigned. Because it is a local variable of the function, every call to register_ffi_callback raises UnboundLocalError (a subclass of NameError), so no Host target is ever registered. The fix mirrors the working CUDA block two lines above.

Suggested change
ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host")
ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value)
jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host")

Comment thread
coderabbitai[bot] marked this conversation as resolved.


###############################################################################
Expand Down
Loading