From a9b447a7df38c81ff18cd8768133ddb98df506c1 Mon Sep 17 00:00:00 2001 From: Ankit Jain Date: Fri, 8 May 2026 23:33:02 +0100 Subject: [PATCH 01/10] Add JAX FFI Host support This allows JAX FFI callbacks to run on CPU (Host) in addition to CUDA. Signed-off-by: Ankit Jain --- warp/_src/jax_experimental/ffi.py | 147 +++++++++++++++++++++--------- 1 file changed, 104 insertions(+), 43 deletions(-) diff --git a/warp/_src/jax_experimental/ffi.py b/warp/_src/jax_experimental/ffi.py index 99c845feb0..d68c792dbd 100644 --- a/warp/_src/jax_experimental/ffi.py +++ b/warp/_src/jax_experimental/ffi.py @@ -214,11 +214,19 @@ def __init__( self.input_output_aliases = input_output_aliases # register the callback - FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame)) - self.callback_func = FFI_CCALLFUNC(self.ffi_callback) - ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p) - ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) - jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA") + FFI_CCALLFUNC = ctypes.CFUNCTYPE( + ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame) + ) + + self.callback_func_cuda = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="CUDA")) + ffi_ccall_address_cuda = ctypes.cast(self.callback_func_cuda, ctypes.c_void_p) + ffi_capsule_cuda = jax.ffi.pycapsule(ffi_ccall_address_cuda.value) + jax.ffi.register_ffi_target(self.name, ffi_capsule_cuda, platform="CUDA") + + self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host")) + ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p) + ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value) + jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): num_inputs = len(args) @@ -332,7 +340,7 @@ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): return call(*args, launch_id=launch_id) - def ffi_callback(self, call_frame): + def ffi_callback(self, call_frame, platform="CUDA"): try: # On the first call, XLA runtime will query the API version and traits # metadata using the |extension| field. Let us respond to that query @@ -344,10 +352,11 @@ def ffi_callback(self, call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - # Turn on CUDA graphs for this handler. - metadata_ext.contents.metadata.contents.traits = ( - XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE - ) + # Turn on CUDA graphs for this handler if on CUDA platform. + if platform == "CUDA": + metadata_ext.contents.metadata.contents.traits = ( + XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE + ) return None # Lock is required to prevent race conditions when callback is invoked @@ -427,25 +436,37 @@ def ffi_callback(self, call_frame): kernel_params[0] = ctypes.addressof(launch_bounds) # get device and stream - device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents)) - stream = get_stream_from_callframe(call_frame.contents) + if platform == "CUDA": + device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents)) + stream = get_stream_from_callframe(call_frame.contents) + else: + device = wp.get_device("cpu") + stream = None # get kernel hooks hooks = self.kernel.module.get_kernel_hooks(self.kernel, device) assert hooks.forward, "Failed to find kernel entry point" # launch the kernel - wp._src.context.runtime.core.wp_cuda_launch_kernel( - device.context, - hooks.forward, - launch_bounds.size, - 0, - 256, - hooks.forward_smem_bytes, - kernel_params, - stream, - None, # apic_info - ) + if device.is_cuda: + wp._src.context.runtime.core.wp_cuda_launch_kernel( + device.context, + hooks.forward, + launch_bounds.size, + 0, + 256, + hooks.forward_smem_bytes, + kernel_params, + stream, + None, # apic_info + ) + else: + wp._src.context.runtime.core.wp_cpu_launch_kernel( + device.context, + hooks.forward, + launch_bounds.size, + kernel_params, + ) except Exception as e: print(traceback.format_exc()) @@ -599,10 +620,16 @@ def __init__( # register the callback FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame)) - self.callback_func = FFI_CCALLFUNC(self.ffi_callback) - ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p) - ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) - jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA") + + self.callback_func_cuda = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="CUDA")) + ffi_ccall_address_cuda = ctypes.cast(self.callback_func_cuda, ctypes.c_void_p) + ffi_capsule_cuda = jax.ffi.pycapsule(ffi_ccall_address_cuda.value) + jax.ffi.register_ffi_target(self.name, ffi_capsule_cuda, platform="CUDA") + + self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host")) + ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p) + ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value) + jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") def __call__(self, *args, output_dims=None, vmap_method=None): num_inputs = len(args) @@ -703,7 +730,7 @@ def __call__(self, *args, output_dims=None, vmap_method=None): self.call_id += 1 return call(*args, call_id=call_id) - def ffi_callback(self, call_frame): + def ffi_callback(self, call_frame, platform="CUDA"): try: # On the first call, XLA runtime will query the API version and traits # metadata using the |extension| field. Let us respond to that query @@ -715,8 +742,8 @@ def ffi_callback(self, call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - # Turn on CUDA graphs for this handler. - if self.graph_mode is GraphMode.JAX: + # Turn on CUDA graphs for this handler if on CUDA platform. + if self.graph_mode is GraphMode.JAX and platform == "CUDA": metadata_ext.contents.metadata.contents.traits = ( XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE ) @@ -743,6 +770,35 @@ def ffi_callback(self, call_frame): assert num_inputs == self.num_inputs assert num_outputs == self.num_outputs + if platform == "Host": + device = wp.get_device("cpu") + # reconstruct the argument list + arg_list = [] + + # input and in-out args + for i, arg in enumerate(self.input_args): + if arg.is_array: + buffer = inputs[i].contents + shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim) + arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device) + arg_list.append(arr) + else: + # scalar argument, get stashed value + value = call_desc.static_inputs[arg.name] + arg_list.append(value) + + # pure output args (skip in-out FFI buffers) + for i, arg in enumerate(self.output_args): + buffer = outputs[i + self.num_in_out].contents + shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim) + arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device) + arg_list.append(arr) + + # call the Python function with reconstructed arguments + with wp.ScopedDevice(device): + self.func(*arg_list) + return + cuda_stream = get_stream_from_callframe(call_frame.contents) device_ordinal = get_device_ordinal_from_callframe(call_frame.contents) @@ -875,8 +931,8 @@ def ffi_callback(self, call_frame): arg_list.append(arr) # call the Python function with reconstructed arguments - with wp.ScopedStream(stream, sync_enter=False): - if stream.is_capturing: + with wp.ScopedStream(stream, sync_enter=False) if stream else wp.ScopedDevice(device): + if stream and stream.is_capturing: # capturing with JAX with wp.ScopedCapture(external=True) as capture: self.func(*arg_list) @@ -884,7 +940,7 @@ def ffi_callback(self, call_frame): # keep a reference to the capture object to prevent required modules getting unloaded call_desc.capture = capture - elif self.graph_mode == GraphMode.WARP: + elif self.graph_mode == GraphMode.WARP and device.is_cuda: # capturing with WARP with wp.ScopedCapture() as capture: self.func(*arg_list) @@ -897,7 +953,7 @@ def ffi_callback(self, call_frame): if self._graph_cache_max is not None and len(self.captures) > self._graph_cache_max: self.captures.popitem(last=False) - elif self.graph_mode == GraphMode.WARP_STAGED_EX: + elif self.graph_mode == GraphMode.WARP_STAGED_EX and device.is_cuda: # capturing with WARP using staging buffers and memcopies done outside of the graph wp_memcpy_batch = wp._src.context.runtime.core.wp_memcpy_batch @@ -940,7 +996,7 @@ def ffi_callback(self, call_frame): # TODO: we should have a way of freeing this call_desc.capture = capture - elif self.graph_mode == GraphMode.WARP_STAGED: + elif self.graph_mode == GraphMode.WARP_STAGED and device.is_cuda: # capturing with WARP using staging buffers and memcopies done inside of the graph wp_cuda_graph_insert_memcpy_batch = ( wp._src.context.runtime.core.wp_cuda_graph_insert_memcpy_batch @@ -1018,7 +1074,7 @@ def ffi_callback(self, call_frame): call_desc.capture = capture else: - # not capturing + # not capturing or on CPU self.func(*arg_list) except Exception as e: @@ -1663,7 +1719,7 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = Tr # TODO check that the name is not already registered - def ffi_callback(call_frame): + def ffi_callback(call_frame, platform="CUDA"): try: extension = call_frame.contents.extension_start # On the first call, XLA runtime will query the API version and traits @@ -1675,7 +1731,7 @@ def ffi_callback(call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - if graph_compatible: + if graph_compatible and platform == "CUDA": # Turn on CUDA graphs for this handler. metadata_ext.contents.metadata.contents.traits = ( XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE @@ -1708,12 +1764,17 @@ def ffi_callback(call_frame): return None FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame)) - callback_func = FFI_CCALLFUNC(ffi_callback) + callback_func_cuda = FFI_CCALLFUNC(lambda call_frame: ffi_callback(call_frame, platform="CUDA")) + callback_func_host = FFI_CCALLFUNC(lambda call_frame: ffi_callback(call_frame, platform="Host")) with _FFI_REGISTRY_LOCK: - _FFI_CALLBACK_REGISTRY[name] = callback_func - ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p) - ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) - jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA") + _FFI_CALLBACK_REGISTRY[f"{name}_cuda"] = callback_func_cuda + _FFI_CALLBACK_REGISTRY[f"{name}_host"] = callback_func_host + ffi_ccall_address_cuda = ctypes.cast(callback_func_cuda, ctypes.c_void_p) + ffi_capsule_cuda = jax.ffi.pycapsule(ffi_ccall_address_cuda.value) + jax.ffi.register_ffi_target(name, ffi_capsule_cuda, platform="CUDA") + ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p) + ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value) + jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host") ############################################################################### From dea9a19d7e0cfc5be954821a91b1722a2dbf3aa6 Mon Sep 17 00:00:00 2001 From: Ankit Jain Date: Sat, 9 May 2026 00:02:43 +0100 Subject: [PATCH 02/10] Update CHANGELOG for JAX FFI Host support Signed-off-by: Ankit Jain --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a9b9b37c1..97001a5249 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ `enable_mathdx_solver` config flag and module option (parity with `enable_mathdx_gemm`) to route these ops through the fallback when libmathdx is available ([GH-1402](https://github.com/NVIDIA/warp/issues/1402)). +- Add JAX FFI Host support, allowing JAX FFI callbacks to run on CPU (Host) in addition to CUDA. ### Removed From 64289139663e1c9ecaed5282043e1d5a2a9bd285 Mon Sep 17 00:00:00 2001 From: Ankit Jain Date: Sat, 9 May 2026 00:34:08 +0100 Subject: [PATCH 03/10] Apply LLVM 21 patch for clang.cpp Applied the patch from the original CL as requested. Signed-off-by: Ankit Jain --- warp/native/clang/clang.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/warp/native/clang/clang.cpp b/warp/native/clang/clang.cpp index aa7c92d724..7a84ab33c9 100644 --- a/warp/native/clang/clang.cpp +++ b/warp/native/clang/clang.cpp @@ -220,13 +220,12 @@ static std::unique_ptr create_compiler( } #if LLVM_VERSION_MAJOR >= 21 - clang::DiagnosticOptions diagnostic_options; - std::unique_ptr text_diagnostic_printer - = std::make_unique(llvm::errs(), diagnostic_options); - clang::IntrusiveRefCntPtr diagnostic_ids; - std::unique_ptr diagnostic_engine = std::make_unique( - diagnostic_ids, diagnostic_options, text_diagnostic_printer.release() + clang::DiagnosticOptions temp_diag_opts; + auto temp_diag_engine = clang::CompilerInstance::createDiagnostics( + *llvm::vfs::getRealFileSystem(), temp_diag_opts, + new clang::IgnoringDiagConsumer() ); + compiler_instance.setDiagnostics(real_diag_engine); #else clang::IntrusiveRefCntPtr diagnostic_options = new clang::DiagnosticOptions(); std::unique_ptr text_diagnostic_printer @@ -676,7 +675,7 @@ static llvm::orc::LLJIT* get_or_create_jit(bool use_legacy_linker) } else { builder.setObjectLinkingLayerCreator( #if LLVM_VERSION_MAJOR >= 21 - [](llvm::orc::ExecutionSession& session) + [&](llvm::orc::ExecutionSession& session, llvm::jitlink::JITLinkMemoryManager&) #else [](llvm::orc::ExecutionSession& session, const llvm::Triple& triple) #endif From e294d47730856bc56c0478aa95d4cd1f712647cb Mon Sep 17 00:00:00 2001 From: Ankit Jain Date: Sat, 9 May 2026 00:47:36 +0100 Subject: [PATCH 04/10] Fix JAX FFI host registration typo Fix a typo where ffi_capsule_host used its own unassigned value instead of ffi_ccall_address_host.value. Signed-off-by: Ankit Jain --- warp/_src/jax_experimental/ffi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/warp/_src/jax_experimental/ffi.py b/warp/_src/jax_experimental/ffi.py index d68c792dbd..e0dd621e66 100644 --- a/warp/_src/jax_experimental/ffi.py +++ b/warp/_src/jax_experimental/ffi.py @@ -225,7 +225,7 @@ def __init__( self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host")) ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p) - ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value) + ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value) jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): @@ -1773,7 +1773,7 @@ def ffi_callback(call_frame, platform="CUDA"): ffi_capsule_cuda = jax.ffi.pycapsule(ffi_ccall_address_cuda.value) jax.ffi.register_ffi_target(name, ffi_capsule_cuda, platform="CUDA") ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p) - ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value) + ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value) jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host") From 12e1cf779fa7437783b4f1bb31c134f31a67dcde Mon Sep 17 00:00:00 2001 From: Ankit Jain Date: Sat, 9 May 2026 00:51:25 +0100 Subject: [PATCH 05/10] Revert changes to clang.cpp and CHANGELOG.md Restore files to their state in main branch to keep this PR focused on JAX FFI changes. Signed-off-by: Ankit Jain --- CHANGELOG.md | 1 - warp/native/clang/clang.cpp | 13 +++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 97001a5249..3a9b9b37c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,6 @@ `enable_mathdx_solver` config flag and module option (parity with `enable_mathdx_gemm`) to route these ops through the fallback when libmathdx is available ([GH-1402](https://github.com/NVIDIA/warp/issues/1402)). -- Add JAX FFI Host support, allowing JAX FFI callbacks to run on CPU (Host) in addition to CUDA. ### Removed diff --git a/warp/native/clang/clang.cpp b/warp/native/clang/clang.cpp index 7a84ab33c9..aa7c92d724 100644 --- a/warp/native/clang/clang.cpp +++ b/warp/native/clang/clang.cpp @@ -220,12 +220,13 @@ static std::unique_ptr create_compiler( } #if LLVM_VERSION_MAJOR >= 21 - clang::DiagnosticOptions temp_diag_opts; - auto temp_diag_engine = clang::CompilerInstance::createDiagnostics( - *llvm::vfs::getRealFileSystem(), temp_diag_opts, - new clang::IgnoringDiagConsumer() + clang::DiagnosticOptions diagnostic_options; + std::unique_ptr text_diagnostic_printer + = std::make_unique(llvm::errs(), diagnostic_options); + clang::IntrusiveRefCntPtr diagnostic_ids; + std::unique_ptr diagnostic_engine = std::make_unique( + diagnostic_ids, diagnostic_options, text_diagnostic_printer.release() ); - compiler_instance.setDiagnostics(real_diag_engine); #else clang::IntrusiveRefCntPtr diagnostic_options = new clang::DiagnosticOptions(); std::unique_ptr text_diagnostic_printer @@ -675,7 +676,7 @@ static llvm::orc::LLJIT* get_or_create_jit(bool use_legacy_linker) } else { builder.setObjectLinkingLayerCreator( #if LLVM_VERSION_MAJOR >= 21 - [&](llvm::orc::ExecutionSession& session, llvm::jitlink::JITLinkMemoryManager&) + [](llvm::orc::ExecutionSession& session) #else [](llvm::orc::ExecutionSession& session, const llvm::Triple& triple) #endif From d39169c92ad5bb53b54556d74ef6f62be2effed0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 May 2026 23:56:51 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto code formatting --- warp/_src/jax_experimental/ffi.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/warp/_src/jax_experimental/ffi.py b/warp/_src/jax_experimental/ffi.py index e0dd621e66..ae12fd600c 100644 --- a/warp/_src/jax_experimental/ffi.py +++ b/warp/_src/jax_experimental/ffi.py @@ -214,9 +214,7 @@ def __init__( self.input_output_aliases = input_output_aliases # register the callback - FFI_CCALLFUNC = ctypes.CFUNCTYPE( - ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame) - ) + FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame)) self.callback_func_cuda = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="CUDA")) ffi_ccall_address_cuda = ctypes.cast(self.callback_func_cuda, ctypes.c_void_p) From 9e18ede97b35883773ca758fe7aaccd32bd19531 Mon Sep 17 00:00:00 2001 From: Ankit Jain Date: Sat, 9 May 2026 01:00:58 +0100 Subject: [PATCH 07/10] Add tests for JAX FFI Host (CPU) support Add four tests that exercise the Host platform path for FFI: - test_ffi_jax_kernel_add_host: basic two-input one-output kernel - test_ffi_jax_kernel_sincos_host: one-input two-output kernel - test_ffi_jax_kernel_in_out_host: in-out argument handling - test_ffi_jax_callable_scale_constant_host: jax_callable with scalar constant Signed-off-by: Ankit Jain --- warp/tests/interop/test_jax.py | 149 +++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) diff --git a/warp/tests/interop/test_jax.py b/warp/tests/interop/test_jax.py index 12d9757484..f08aca1620 100644 --- a/warp/tests/interop/test_jax.py +++ b/warp/tests/interop/test_jax.py @@ -1122,6 +1122,127 @@ def warp_func(inputs, outputs, attrs, ctx): assert_np_equal(d, 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))) +# ========================================================================================================= +# JAX FFI Host (CPU) tests +# ========================================================================================================= + + +@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old") +def test_ffi_jax_kernel_add_host(test, device): + # two inputs and one output on CPU + import jax.numpy as jp + + from warp.jax_experimental.ffi import jax_kernel + + jax_add = jax_kernel(add_kernel) + + @jax.jit + def f(): + n = ARRAY_SIZE + a = jp.arange(n, dtype=jp.float32) + b = jp.ones(n, dtype=jp.float32) + return jax_add(a, b) + + with jax.default_device(wp.device_to_jax(device)): + (y,) = f() + + jax.block_until_ready(y) + + result = np.asarray(y) + expected = np.arange(1, ARRAY_SIZE + 1, dtype=np.float32) + + assert_np_equal(result, expected) + + +@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old") +def test_ffi_jax_kernel_sincos_host(test, device): + # one input and two outputs on CPU + import jax.numpy as jp + + from warp.jax_experimental.ffi import jax_kernel + + jax_sincos = jax_kernel(sincos_kernel, num_outputs=2) + + n = ARRAY_SIZE + + @jax.jit + def f(): + a = jp.linspace(0, 2 * jp.pi, n, dtype=jp.float32) + return jax_sincos(a) + + with jax.default_device(wp.device_to_jax(device)): + s, c = f() + + jax.block_until_ready([s, c]) + + result_s = np.asarray(s) + result_c = np.asarray(c) + + a = np.linspace(0, 2 * np.pi, n, dtype=np.float32) + expected_s = np.sin(a) + expected_c = np.cos(a) + + assert_np_equal(result_s, expected_s, tol=1e-4) + assert_np_equal(result_c, expected_c, tol=1e-4) + + +@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old") +def test_ffi_jax_kernel_in_out_host(test, device): + # in-out args on CPU + import jax.numpy as jp + + from warp.jax_experimental.ffi import jax_kernel + + jax_func = jax_kernel(in_out_kernel, num_outputs=2, in_out_argnames=["b"]) + + f = jax.jit(jax_func) + + with jax.default_device(wp.device_to_jax(device)): + a = jp.ones(ARRAY_SIZE, dtype=jp.float32) + b = jp.arange(ARRAY_SIZE, dtype=jp.float32) + b, c = f(a, b) + + jax.block_until_ready([b, c]) + + assert_np_equal(b, np.arange(1, ARRAY_SIZE + 1, dtype=np.float32)) + assert_np_equal(c, np.full(ARRAY_SIZE, 2, dtype=np.float32)) + + +@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old") +def test_ffi_jax_callable_scale_constant_host(test, device): + # scale two arrays using a constant on CPU + import jax.numpy as jp + + from warp.jax_experimental.ffi import jax_callable + + jax_func = jax_callable(scale_func, num_outputs=2) + + @jax.jit + def f(): + # inputs + a = jp.arange(ARRAY_SIZE, dtype=jp.float32) + b = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # wp.vec2 + s = 2.0 + + # output shapes + output_dims = {"c": a.shape, "d": b.shape} + + c, d = jax_func(a, b, s, output_dims=output_dims) + + return c, d + + with jax.default_device(wp.device_to_jax(device)): + result1, result2 = f() + + jax.block_until_ready([result1, result2]) + + expected1 = 2 * np.arange(ARRAY_SIZE, dtype=np.float32) + expected2 = 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2)) + + assert_np_equal(result1, expected1) + assert_np_equal(result2, expected2) + + @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old") def test_ffi_jax_kernel_autodiff_simple(test, device): if device.ordinal > 0: @@ -2082,6 +2203,7 @@ class TestJax(unittest.TestCase): test_devices = get_test_devices() jax_compatible_devices = [] jax_compatible_cuda_devices = [] + jax_compatible_cpu_devices = [] for d in test_devices: try: with jax.default_device(wp.device_to_jax(d)): @@ -2090,6 +2212,8 @@ class TestJax(unittest.TestCase): jax_compatible_devices.append(d) if d.is_cuda: jax_compatible_cuda_devices.append(d) + if d.is_cpu: + jax_compatible_cpu_devices.append(d) except Exception as e: print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}") @@ -2256,6 +2380,31 @@ class TestJax(unittest.TestCase): # ffi callback tests add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices) + # ffi Host (CPU) tests + if jax_compatible_cpu_devices: + add_function_test( + TestJax, "test_ffi_jax_kernel_add_host", test_ffi_jax_kernel_add_host, devices=jax_compatible_cpu_devices + ) + add_function_test( + TestJax, + "test_ffi_jax_kernel_sincos_host", + test_ffi_jax_kernel_sincos_host, + devices=jax_compatible_cpu_devices, + ) + add_function_test( + TestJax, + "test_ffi_jax_kernel_in_out_host", + test_ffi_jax_kernel_in_out_host, + devices=jax_compatible_cpu_devices, + ) + add_function_test( + TestJax, + "test_ffi_jax_callable_scale_constant_host", + test_ffi_jax_callable_scale_constant_host, + devices=jax_compatible_cpu_devices, + ) + + if jax_compatible_cuda_devices: # autodiff tests add_function_test( TestJax, From f6048f47eb4991c09cf66298d7f0ac7452288eb9 Mon Sep 17 00:00:00 2001 From: Ankit Jain Date: Sat, 9 May 2026 01:07:50 +0100 Subject: [PATCH 08/10] Add tests for JAX FFI Host (CPU) support Add test coverage for jax_kernel and jax_callable running on the CPU Host platform. Tests cover basic add, sincos (multi-output), in-out args, scalar constant args, and callable variants. Signed-off-by: Ankit Jain --- warp/tests/interop/test_jax.py | 101 +++++++++++++++++++++++---------- 1 file changed, 72 insertions(+), 29 deletions(-) diff --git a/warp/tests/interop/test_jax.py b/warp/tests/interop/test_jax.py index f08aca1620..279fb1ca6b 100644 --- a/warp/tests/interop/test_jax.py +++ b/warp/tests/interop/test_jax.py @@ -1122,13 +1122,8 @@ def warp_func(inputs, outputs, attrs, ctx): assert_np_equal(d, 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))) -# ========================================================================================================= -# JAX FFI Host (CPU) tests -# ========================================================================================================= - - @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old") -def test_ffi_jax_kernel_add_host(test, device): +def test_ffi_jax_kernel_host_add(test, device): # two inputs and one output on CPU import jax.numpy as jp @@ -1143,7 +1138,7 @@ def f(): b = jp.ones(n, dtype=jp.float32) return jax_add(a, b) - with jax.default_device(wp.device_to_jax(device)): + with jax.default_device(jax.devices("cpu")[0]): (y,) = f() jax.block_until_ready(y) @@ -1155,7 +1150,7 @@ def f(): @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old") -def test_ffi_jax_kernel_sincos_host(test, device): +def test_ffi_jax_kernel_host_sincos(test, device): # one input and two outputs on CPU import jax.numpy as jp @@ -1170,7 +1165,7 @@ def f(): a = jp.linspace(0, 2 * jp.pi, n, dtype=jp.float32) return jax_sincos(a) - with jax.default_device(wp.device_to_jax(device)): + with jax.default_device(jax.devices("cpu")[0]): s, c = f() jax.block_until_ready([s, c]) @@ -1187,7 +1182,7 @@ def f(): @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old") -def test_ffi_jax_kernel_in_out_host(test, device): +def test_ffi_jax_kernel_host_in_out(test, device): # in-out args on CPU import jax.numpy as jp @@ -1197,7 +1192,7 @@ def test_ffi_jax_kernel_in_out_host(test, device): f = jax.jit(jax_func) - with jax.default_device(wp.device_to_jax(device)): + with jax.default_device(jax.devices("cpu")[0]): a = jp.ones(ARRAY_SIZE, dtype=jp.float32) b = jp.arange(ARRAY_SIZE, dtype=jp.float32) b, c = f(a, b) @@ -1209,7 +1204,32 @@ def test_ffi_jax_kernel_in_out_host(test, device): @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old") -def test_ffi_jax_callable_scale_constant_host(test, device): +def test_ffi_jax_kernel_host_scale_vec_constant(test, device): + # multiply vectors by scalar (constant) on CPU + import jax.numpy as jp + + from warp.jax_experimental.ffi import jax_kernel + + jax_scale_vec = jax_kernel(scale_vec_kernel) + + @jax.jit + def f(): + a = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # array of vec2 + s = 2.0 + return jax_scale_vec(a, s) + + with jax.default_device(jax.devices("cpu")[0]): + (b,) = f() + + jax.block_until_ready(b) + + expected = 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2)) + + assert_np_equal(b, expected) + + +@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old") +def test_ffi_jax_callable_host_scale_constant(test, device): # scale two arrays using a constant on CPU import jax.numpy as jp @@ -1231,7 +1251,7 @@ def f(): return c, d - with jax.default_device(wp.device_to_jax(device)): + with jax.default_device(jax.devices("cpu")[0]): result1, result2 = f() jax.block_until_ready([result1, result2]) @@ -1243,6 +1263,28 @@ def f(): assert_np_equal(result2, expected2) +@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old") +def test_ffi_jax_callable_host_in_out(test, device): + # in-out arguments on CPU + import jax.numpy as jp + + from warp.jax_experimental.ffi import jax_callable + + jax_func = jax_callable(in_out_func, num_outputs=2, in_out_argnames=["b"]) + + f = jax.jit(jax_func) + + with jax.default_device(jax.devices("cpu")[0]): + a = jp.ones(ARRAY_SIZE, dtype=jp.float32) + b = jp.arange(ARRAY_SIZE, dtype=jp.float32) + b, c = f(a, b) + + jax.block_until_ready([b, c]) + + assert_np_equal(b, np.arange(1, ARRAY_SIZE + 1, dtype=np.float32)) + assert_np_equal(c, np.full(ARRAY_SIZE, 2, dtype=np.float32)) + + @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old") def test_ffi_jax_kernel_autodiff_simple(test, device): if device.ordinal > 0: @@ -2203,7 +2245,6 @@ class TestJax(unittest.TestCase): test_devices = get_test_devices() jax_compatible_devices = [] jax_compatible_cuda_devices = [] - jax_compatible_cpu_devices = [] for d in test_devices: try: with jax.default_device(wp.device_to_jax(d)): @@ -2212,8 +2253,6 @@ class TestJax(unittest.TestCase): jax_compatible_devices.append(d) if d.is_cuda: jax_compatible_cuda_devices.append(d) - if d.is_cpu: - jax_compatible_cpu_devices.append(d) except Exception as e: print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}") @@ -2380,31 +2419,35 @@ class TestJax(unittest.TestCase): # ffi callback tests add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices) - # ffi Host (CPU) tests - if jax_compatible_cpu_devices: + # ffi Host (CPU) tests + add_function_test( + TestJax, "test_ffi_jax_kernel_host_add", test_ffi_jax_kernel_host_add, devices=None + ) + add_function_test( + TestJax, "test_ffi_jax_kernel_host_sincos", test_ffi_jax_kernel_host_sincos, devices=None + ) add_function_test( - TestJax, "test_ffi_jax_kernel_add_host", test_ffi_jax_kernel_add_host, devices=jax_compatible_cpu_devices + TestJax, "test_ffi_jax_kernel_host_in_out", test_ffi_jax_kernel_host_in_out, devices=None ) add_function_test( TestJax, - "test_ffi_jax_kernel_sincos_host", - test_ffi_jax_kernel_sincos_host, - devices=jax_compatible_cpu_devices, + "test_ffi_jax_kernel_host_scale_vec_constant", + test_ffi_jax_kernel_host_scale_vec_constant, + devices=None, ) add_function_test( TestJax, - "test_ffi_jax_kernel_in_out_host", - test_ffi_jax_kernel_in_out_host, - devices=jax_compatible_cpu_devices, + "test_ffi_jax_callable_host_scale_constant", + test_ffi_jax_callable_host_scale_constant, + devices=None, ) add_function_test( TestJax, - "test_ffi_jax_callable_scale_constant_host", - test_ffi_jax_callable_scale_constant_host, - devices=jax_compatible_cpu_devices, + "test_ffi_jax_callable_host_in_out", + test_ffi_jax_callable_host_in_out, + devices=None, ) - if jax_compatible_cuda_devices: # autodiff tests add_function_test( TestJax, From afe74c8b4edf36fe16337744c106a601b16957c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 9 May 2026 00:17:22 +0000 Subject: [PATCH 09/10] [pre-commit.ci] auto code formatting --- warp/tests/interop/test_jax.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/warp/tests/interop/test_jax.py b/warp/tests/interop/test_jax.py index 279fb1ca6b..52ff98f52d 100644 --- a/warp/tests/interop/test_jax.py +++ b/warp/tests/interop/test_jax.py @@ -2420,15 +2420,9 @@ class TestJax(unittest.TestCase): add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices) # ffi Host (CPU) tests - add_function_test( - TestJax, "test_ffi_jax_kernel_host_add", test_ffi_jax_kernel_host_add, devices=None - ) - add_function_test( - TestJax, "test_ffi_jax_kernel_host_sincos", test_ffi_jax_kernel_host_sincos, devices=None - ) - add_function_test( - TestJax, "test_ffi_jax_kernel_host_in_out", test_ffi_jax_kernel_host_in_out, devices=None - ) + add_function_test(TestJax, "test_ffi_jax_kernel_host_add", test_ffi_jax_kernel_host_add, devices=None) + add_function_test(TestJax, "test_ffi_jax_kernel_host_sincos", test_ffi_jax_kernel_host_sincos, devices=None) + add_function_test(TestJax, "test_ffi_jax_kernel_host_in_out", test_ffi_jax_kernel_host_in_out, devices=None) add_function_test( TestJax, "test_ffi_jax_kernel_host_scale_vec_constant", From 7b3cf542ebd063d642c42f5f11e7c5dc7f7c08a7 Mon Sep 17 00:00:00 2001 From: Ankit Jain Date: Tue, 26 May 2026 14:58:01 +0100 Subject: [PATCH 10/10] Address PR review comments for JAX FFI Host support - Fix wp_cpu_launch_kernel call to use correct 5-arg signature (func, bounds, args, adj_args, apic_info) and build a proper CPU args struct instead of passing GPU-style kernel_params array - Extract _reconstruct_args helper in FfiCallable to deduplicate argument reconstruction between Host and CUDA callback paths - Remove unnecessary stream/device.is_cuda guards in CUDA path since the Host path returns early, making stream always non-None - Move Host CPU test registrations outside jax_compatible_cuda_devices gate so they run on CPU-only CI setups Signed-off-by: Ankit Jain --- warp/_src/jax_experimental/ffi.py | 98 +++++++++++++++---------------- warp/tests/interop/test_jax.py | 53 +++++++++-------- 2 files changed, 77 insertions(+), 74 deletions(-) diff --git a/warp/_src/jax_experimental/ffi.py b/warp/_src/jax_experimental/ffi.py index ae12fd600c..d46aaabb6f 100644 --- a/warp/_src/jax_experimental/ffi.py +++ b/warp/_src/jax_experimental/ffi.py @@ -459,11 +459,24 @@ def ffi_callback(self, call_frame, platform="CUDA"): None, # apic_info ) else: + # Build a CPU args struct from the kernel arguments. + # CPU kernels expect func(bounds*, args_struct*) where args_struct + # is a packed ctypes.Structure, not an array of void pointers. + fields = [] + for i in range(self.num_kernel_args): + arg_name = self.kernel.adj.args[i].label + fields.append((arg_name, type(arg_refs[i]))) + ArgsStruct = type("ArgsStruct", (ctypes.Structure,), {"_fields_": fields}) + args_struct = ArgsStruct() + for i, field in enumerate(fields): + setattr(args_struct, field[0], arg_refs[i]) + wp._src.context.runtime.core.wp_cpu_launch_kernel( - device.context, hooks.forward, - launch_bounds.size, - kernel_params, + ctypes.byref(launch_bounds), + ctypes.byref(args_struct), + None, # adj_args + None, # apic_info ) except Exception as e: @@ -770,27 +783,7 @@ def ffi_callback(self, call_frame, platform="CUDA"): if platform == "Host": device = wp.get_device("cpu") - # reconstruct the argument list - arg_list = [] - - # input and in-out args - for i, arg in enumerate(self.input_args): - if arg.is_array: - buffer = inputs[i].contents - shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim) - arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device) - arg_list.append(arr) - else: - # scalar argument, get stashed value - value = call_desc.static_inputs[arg.name] - arg_list.append(value) - - # pure output args (skip in-out FFI buffers) - for i, arg in enumerate(self.output_args): - buffer = outputs[i + self.num_in_out].contents - shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim) - arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device) - arg_list.append(arr) + arg_list = self._reconstruct_args(inputs, outputs, call_desc, device) # call the Python function with reconstructed arguments with wp.ScopedDevice(device): @@ -906,31 +899,11 @@ def ffi_callback(self, call_frame, platform="CUDA"): device = wp.get_cuda_device(device_ordinal) stream = wp.Stream(device, cuda_stream=cuda_stream) - # reconstruct the argument list - arg_list = [] - - # input and in-out args - for i, arg in enumerate(self.input_args): - if arg.is_array: - buffer = inputs[i].contents - shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim) - arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device) - arg_list.append(arr) - else: - # scalar argument, get stashed value - value = call_desc.static_inputs[arg.name] - arg_list.append(value) - - # pure output args (skip in-out FFI buffers) - for i, arg in enumerate(self.output_args): - buffer = outputs[i + self.num_in_out].contents - shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim) - arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device) - arg_list.append(arr) + arg_list = self._reconstruct_args(inputs, outputs, call_desc, device) # call the Python function with reconstructed arguments - with wp.ScopedStream(stream, sync_enter=False) if stream else wp.ScopedDevice(device): - if stream and stream.is_capturing: + with wp.ScopedStream(stream, sync_enter=False): + if stream.is_capturing: # capturing with JAX with wp.ScopedCapture(external=True) as capture: self.func(*arg_list) @@ -938,7 +911,7 @@ def ffi_callback(self, call_frame, platform="CUDA"): # keep a reference to the capture object to prevent required modules getting unloaded call_desc.capture = capture - elif self.graph_mode == GraphMode.WARP and device.is_cuda: + elif self.graph_mode == GraphMode.WARP: # capturing with WARP with wp.ScopedCapture() as capture: self.func(*arg_list) @@ -951,7 +924,7 @@ def ffi_callback(self, call_frame, platform="CUDA"): if self._graph_cache_max is not None and len(self.captures) > self._graph_cache_max: self.captures.popitem(last=False) - elif self.graph_mode == GraphMode.WARP_STAGED_EX and device.is_cuda: + elif self.graph_mode == GraphMode.WARP_STAGED_EX: # capturing with WARP using staging buffers and memcopies done outside of the graph wp_memcpy_batch = wp._src.context.runtime.core.wp_memcpy_batch @@ -994,7 +967,7 @@ def ffi_callback(self, call_frame, platform="CUDA"): # TODO: we should have a way of freeing this call_desc.capture = capture - elif self.graph_mode == GraphMode.WARP_STAGED and device.is_cuda: + elif self.graph_mode == GraphMode.WARP_STAGED: # capturing with WARP using staging buffers and memcopies done inside of the graph wp_cuda_graph_insert_memcpy_batch = ( wp._src.context.runtime.core.wp_cuda_graph_insert_memcpy_batch @@ -1083,6 +1056,31 @@ def ffi_callback(self, call_frame, platform="CUDA"): return None + def _reconstruct_args(self, inputs, outputs, call_desc, device): + """Reconstruct the argument list from FFI input/output buffers.""" + arg_list = [] + + # input and in-out args + for i, arg in enumerate(self.input_args): + if arg.is_array: + buffer = inputs[i].contents + shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim) + arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device) + arg_list.append(arr) + else: + # scalar argument, get stashed value + value = call_desc.static_inputs[arg.name] + arg_list.append(value) + + # pure output args (skip in-out FFI buffers) + for i, arg in enumerate(self.output_args): + buffer = outputs[i + self.num_in_out].contents + shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim) + arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device) + arg_list.append(arr) + + return arg_list + def _prepare_staging(self, arg_list, call_desc): # create staging arrays input_callback_arrays = [] diff --git a/warp/tests/interop/test_jax.py b/warp/tests/interop/test_jax.py index 52ff98f52d..b247818b26 100644 --- a/warp/tests/interop/test_jax.py +++ b/warp/tests/interop/test_jax.py @@ -2418,30 +2418,6 @@ class TestJax(unittest.TestCase): # ffi callback tests add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices) - - # ffi Host (CPU) tests - add_function_test(TestJax, "test_ffi_jax_kernel_host_add", test_ffi_jax_kernel_host_add, devices=None) - add_function_test(TestJax, "test_ffi_jax_kernel_host_sincos", test_ffi_jax_kernel_host_sincos, devices=None) - add_function_test(TestJax, "test_ffi_jax_kernel_host_in_out", test_ffi_jax_kernel_host_in_out, devices=None) - add_function_test( - TestJax, - "test_ffi_jax_kernel_host_scale_vec_constant", - test_ffi_jax_kernel_host_scale_vec_constant, - devices=None, - ) - add_function_test( - TestJax, - "test_ffi_jax_callable_host_scale_constant", - test_ffi_jax_callable_host_scale_constant, - devices=None, - ) - add_function_test( - TestJax, - "test_ffi_jax_callable_host_in_out", - test_ffi_jax_callable_host_in_out, - devices=None, - ) - # autodiff tests add_function_test( TestJax, @@ -2565,6 +2541,35 @@ class TestJax(unittest.TestCase): devices=jax_compatible_cuda_devices, ) + # ffi Host (CPU) tests — always registered (not gated on CUDA availability) + add_function_test( + TestJax, "test_ffi_jax_kernel_host_add", test_ffi_jax_kernel_host_add, devices=None + ) + add_function_test( + TestJax, "test_ffi_jax_kernel_host_sincos", test_ffi_jax_kernel_host_sincos, devices=None + ) + add_function_test( + TestJax, "test_ffi_jax_kernel_host_in_out", test_ffi_jax_kernel_host_in_out, devices=None + ) + add_function_test( + TestJax, + "test_ffi_jax_kernel_host_scale_vec_constant", + test_ffi_jax_kernel_host_scale_vec_constant, + devices=None, + ) + add_function_test( + TestJax, + "test_ffi_jax_callable_host_scale_constant", + test_ffi_jax_callable_host_scale_constant, + devices=None, + ) + add_function_test( + TestJax, + "test_ffi_jax_callable_host_in_out", + test_ffi_jax_callable_host_in_out, + devices=None, + ) + # bfloat16 tests require arch >= 80 bf16_jax_devices = [d for d in jax_compatible_devices if d.is_cpu or (d.is_cuda and d.arch >= 80)] if bf16_jax_devices: