-
Notifications
You must be signed in to change notification settings - Fork 519
Add JAX FFI Host support #1446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add JAX FFI Host support #1446
Changes from 9 commits
a9b447a
09569ba
dea9a19
6428913
e294d47
12e1cf7
d39169c
9e18ede
f6048f4
afe74c8
7b3cf54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -215,10 +215,16 @@ def __init__( | |||||||||||||
|
|
||||||||||||||
| # register the callback | ||||||||||||||
| FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame)) | ||||||||||||||
| self.callback_func = FFI_CCALLFUNC(self.ffi_callback) | ||||||||||||||
| ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p) | ||||||||||||||
| ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) | ||||||||||||||
| jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA") | ||||||||||||||
|
|
||||||||||||||
| self.callback_func_cuda = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="CUDA")) | ||||||||||||||
| ffi_ccall_address_cuda = ctypes.cast(self.callback_func_cuda, ctypes.c_void_p) | ||||||||||||||
| ffi_capsule_cuda = jax.ffi.pycapsule(ffi_ccall_address_cuda.value) | ||||||||||||||
| jax.ffi.register_ffi_target(self.name, ffi_capsule_cuda, platform="CUDA") | ||||||||||||||
|
|
||||||||||||||
| self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host")) | ||||||||||||||
| ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p) | ||||||||||||||
| ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value) | ||||||||||||||
| jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") | ||||||||||||||
|
coderabbitai[bot] marked this conversation as resolved.
|
||||||||||||||
|
|
||||||||||||||
| def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): | ||||||||||||||
| num_inputs = len(args) | ||||||||||||||
|
|
@@ -332,7 +338,7 @@ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): | |||||||||||||
|
|
||||||||||||||
| return call(*args, launch_id=launch_id) | ||||||||||||||
|
|
||||||||||||||
| def ffi_callback(self, call_frame): | ||||||||||||||
| def ffi_callback(self, call_frame, platform="CUDA"): | ||||||||||||||
| try: | ||||||||||||||
| # On the first call, XLA runtime will query the API version and traits | ||||||||||||||
| # metadata using the |extension| field. Let us respond to that query | ||||||||||||||
|
|
@@ -344,10 +350,11 @@ def ffi_callback(self, call_frame): | |||||||||||||
| metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) | ||||||||||||||
| metadata_ext.contents.metadata.contents.api_version.major_version = 0 | ||||||||||||||
| metadata_ext.contents.metadata.contents.api_version.minor_version = 1 | ||||||||||||||
| # Turn on CUDA graphs for this handler. | ||||||||||||||
| metadata_ext.contents.metadata.contents.traits = ( | ||||||||||||||
| XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE | ||||||||||||||
| ) | ||||||||||||||
| # Turn on CUDA graphs for this handler if on CUDA platform. | ||||||||||||||
| if platform == "CUDA": | ||||||||||||||
| metadata_ext.contents.metadata.contents.traits = ( | ||||||||||||||
| XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE | ||||||||||||||
| ) | ||||||||||||||
| return None | ||||||||||||||
|
|
||||||||||||||
| # Lock is required to prevent race conditions when callback is invoked | ||||||||||||||
|
|
@@ -427,25 +434,37 @@ def ffi_callback(self, call_frame): | |||||||||||||
| kernel_params[0] = ctypes.addressof(launch_bounds) | ||||||||||||||
|
|
||||||||||||||
| # get device and stream | ||||||||||||||
| device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents)) | ||||||||||||||
| stream = get_stream_from_callframe(call_frame.contents) | ||||||||||||||
| if platform == "CUDA": | ||||||||||||||
| device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents)) | ||||||||||||||
| stream = get_stream_from_callframe(call_frame.contents) | ||||||||||||||
| else: | ||||||||||||||
| device = wp.get_device("cpu") | ||||||||||||||
| stream = None | ||||||||||||||
|
|
||||||||||||||
| # get kernel hooks | ||||||||||||||
| hooks = self.kernel.module.get_kernel_hooks(self.kernel, device) | ||||||||||||||
| assert hooks.forward, "Failed to find kernel entry point" | ||||||||||||||
|
|
||||||||||||||
| # launch the kernel | ||||||||||||||
| wp._src.context.runtime.core.wp_cuda_launch_kernel( | ||||||||||||||
| device.context, | ||||||||||||||
| hooks.forward, | ||||||||||||||
| launch_bounds.size, | ||||||||||||||
| 0, | ||||||||||||||
| 256, | ||||||||||||||
| hooks.forward_smem_bytes, | ||||||||||||||
| kernel_params, | ||||||||||||||
| stream, | ||||||||||||||
| None, # apic_info | ||||||||||||||
| ) | ||||||||||||||
| if device.is_cuda: | ||||||||||||||
| wp._src.context.runtime.core.wp_cuda_launch_kernel( | ||||||||||||||
| device.context, | ||||||||||||||
| hooks.forward, | ||||||||||||||
| launch_bounds.size, | ||||||||||||||
| 0, | ||||||||||||||
| 256, | ||||||||||||||
| hooks.forward_smem_bytes, | ||||||||||||||
| kernel_params, | ||||||||||||||
| stream, | ||||||||||||||
| None, # apic_info | ||||||||||||||
| ) | ||||||||||||||
| else: | ||||||||||||||
| wp._src.context.runtime.core.wp_cpu_launch_kernel( | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This Host path is using the CUDA launch ABI for |
||||||||||||||
| device.context, | ||||||||||||||
| hooks.forward, | ||||||||||||||
| launch_bounds.size, | ||||||||||||||
| kernel_params, | ||||||||||||||
| ) | ||||||||||||||
|
Comment on lines
+474
to
+480
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The call is missing one argument and the arguments are in the wrong positions. The registered ctypes signature (in |
||||||||||||||
|
|
||||||||||||||
| except Exception as e: | ||||||||||||||
| print(traceback.format_exc()) | ||||||||||||||
|
|
@@ -599,10 +618,16 @@ def __init__( | |||||||||||||
|
|
||||||||||||||
| # register the callback | ||||||||||||||
| FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame)) | ||||||||||||||
| self.callback_func = FFI_CCALLFUNC(self.ffi_callback) | ||||||||||||||
| ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p) | ||||||||||||||
| ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) | ||||||||||||||
| jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA") | ||||||||||||||
|
|
||||||||||||||
| self.callback_func_cuda = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="CUDA")) | ||||||||||||||
| ffi_ccall_address_cuda = ctypes.cast(self.callback_func_cuda, ctypes.c_void_p) | ||||||||||||||
| ffi_capsule_cuda = jax.ffi.pycapsule(ffi_ccall_address_cuda.value) | ||||||||||||||
| jax.ffi.register_ffi_target(self.name, ffi_capsule_cuda, platform="CUDA") | ||||||||||||||
|
|
||||||||||||||
| self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host")) | ||||||||||||||
| ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p) | ||||||||||||||
| ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value) | ||||||||||||||
| jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") | ||||||||||||||
|
coderabbitai[bot] marked this conversation as resolved.
|
||||||||||||||
|
|
||||||||||||||
| def __call__(self, *args, output_dims=None, vmap_method=None): | ||||||||||||||
| num_inputs = len(args) | ||||||||||||||
|
|
@@ -703,7 +728,7 @@ def __call__(self, *args, output_dims=None, vmap_method=None): | |||||||||||||
| self.call_id += 1 | ||||||||||||||
| return call(*args, call_id=call_id) | ||||||||||||||
|
|
||||||||||||||
| def ffi_callback(self, call_frame): | ||||||||||||||
| def ffi_callback(self, call_frame, platform="CUDA"): | ||||||||||||||
| try: | ||||||||||||||
| # On the first call, XLA runtime will query the API version and traits | ||||||||||||||
| # metadata using the |extension| field. Let us respond to that query | ||||||||||||||
|
|
@@ -715,8 +740,8 @@ def ffi_callback(self, call_frame): | |||||||||||||
| metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) | ||||||||||||||
| metadata_ext.contents.metadata.contents.api_version.major_version = 0 | ||||||||||||||
| metadata_ext.contents.metadata.contents.api_version.minor_version = 1 | ||||||||||||||
| # Turn on CUDA graphs for this handler. | ||||||||||||||
| if self.graph_mode is GraphMode.JAX: | ||||||||||||||
| # Turn on CUDA graphs for this handler if on CUDA platform. | ||||||||||||||
| if self.graph_mode is GraphMode.JAX and platform == "CUDA": | ||||||||||||||
| metadata_ext.contents.metadata.contents.traits = ( | ||||||||||||||
| XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE | ||||||||||||||
| ) | ||||||||||||||
|
|
@@ -743,6 +768,35 @@ def ffi_callback(self, call_frame): | |||||||||||||
| assert num_inputs == self.num_inputs | ||||||||||||||
| assert num_outputs == self.num_outputs | ||||||||||||||
|
|
||||||||||||||
| if platform == "Host": | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we deduplicate this argument reconstruction code? Perhaps extract to a helper function. |
||||||||||||||
| device = wp.get_device("cpu") | ||||||||||||||
| # reconstruct the argument list | ||||||||||||||
| arg_list = [] | ||||||||||||||
|
|
||||||||||||||
| # input and in-out args | ||||||||||||||
| for i, arg in enumerate(self.input_args): | ||||||||||||||
| if arg.is_array: | ||||||||||||||
| buffer = inputs[i].contents | ||||||||||||||
| shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim) | ||||||||||||||
| arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device) | ||||||||||||||
| arg_list.append(arr) | ||||||||||||||
| else: | ||||||||||||||
| # scalar argument, get stashed value | ||||||||||||||
| value = call_desc.static_inputs[arg.name] | ||||||||||||||
| arg_list.append(value) | ||||||||||||||
|
|
||||||||||||||
| # pure output args (skip in-out FFI buffers) | ||||||||||||||
| for i, arg in enumerate(self.output_args): | ||||||||||||||
| buffer = outputs[i + self.num_in_out].contents | ||||||||||||||
| shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim) | ||||||||||||||
| arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device) | ||||||||||||||
| arg_list.append(arr) | ||||||||||||||
|
|
||||||||||||||
| # call the Python function with reconstructed arguments | ||||||||||||||
| with wp.ScopedDevice(device): | ||||||||||||||
| self.func(*arg_list) | ||||||||||||||
| return | ||||||||||||||
|
|
||||||||||||||
| cuda_stream = get_stream_from_callframe(call_frame.contents) | ||||||||||||||
| device_ordinal = get_device_ordinal_from_callframe(call_frame.contents) | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -875,16 +929,16 @@ def ffi_callback(self, call_frame): | |||||||||||||
| arg_list.append(arr) | ||||||||||||||
|
|
||||||||||||||
| # call the Python function with reconstructed arguments | ||||||||||||||
| with wp.ScopedStream(stream, sync_enter=False): | ||||||||||||||
| if stream.is_capturing: | ||||||||||||||
| with wp.ScopedStream(stream, sync_enter=False) if stream else wp.ScopedDevice(device): | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to check |
||||||||||||||
| if stream and stream.is_capturing: | ||||||||||||||
| # capturing with JAX | ||||||||||||||
| with wp.ScopedCapture(external=True) as capture: | ||||||||||||||
| self.func(*arg_list) | ||||||||||||||
|
|
||||||||||||||
| # keep a reference to the capture object to prevent required modules getting unloaded | ||||||||||||||
| call_desc.capture = capture | ||||||||||||||
|
|
||||||||||||||
| elif self.graph_mode == GraphMode.WARP: | ||||||||||||||
| elif self.graph_mode == GraphMode.WARP and device.is_cuda: | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to check |
||||||||||||||
| # capturing with WARP | ||||||||||||||
| with wp.ScopedCapture() as capture: | ||||||||||||||
| self.func(*arg_list) | ||||||||||||||
|
|
@@ -897,7 +951,7 @@ def ffi_callback(self, call_frame): | |||||||||||||
| if self._graph_cache_max is not None and len(self.captures) > self._graph_cache_max: | ||||||||||||||
| self.captures.popitem(last=False) | ||||||||||||||
|
|
||||||||||||||
| elif self.graph_mode == GraphMode.WARP_STAGED_EX: | ||||||||||||||
| elif self.graph_mode == GraphMode.WARP_STAGED_EX and device.is_cuda: | ||||||||||||||
| # capturing with WARP using staging buffers and memcopies done outside of the graph | ||||||||||||||
| wp_memcpy_batch = wp._src.context.runtime.core.wp_memcpy_batch | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -940,7 +994,7 @@ def ffi_callback(self, call_frame): | |||||||||||||
| # TODO: we should have a way of freeing this | ||||||||||||||
| call_desc.capture = capture | ||||||||||||||
|
|
||||||||||||||
| elif self.graph_mode == GraphMode.WARP_STAGED: | ||||||||||||||
| elif self.graph_mode == GraphMode.WARP_STAGED and device.is_cuda: | ||||||||||||||
| # capturing with WARP using staging buffers and memcopies done inside of the graph | ||||||||||||||
| wp_cuda_graph_insert_memcpy_batch = ( | ||||||||||||||
| wp._src.context.runtime.core.wp_cuda_graph_insert_memcpy_batch | ||||||||||||||
|
|
@@ -1018,7 +1072,7 @@ def ffi_callback(self, call_frame): | |||||||||||||
| call_desc.capture = capture | ||||||||||||||
|
|
||||||||||||||
| else: | ||||||||||||||
| # not capturing | ||||||||||||||
| # not capturing or on CPU | ||||||||||||||
| self.func(*arg_list) | ||||||||||||||
|
|
||||||||||||||
| except Exception as e: | ||||||||||||||
|
|
@@ -1663,7 +1717,7 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = Tr | |||||||||||||
|
|
||||||||||||||
| # TODO check that the name is not already registered | ||||||||||||||
|
|
||||||||||||||
| def ffi_callback(call_frame): | ||||||||||||||
| def ffi_callback(call_frame, platform="CUDA"): | ||||||||||||||
| try: | ||||||||||||||
| extension = call_frame.contents.extension_start | ||||||||||||||
| # On the first call, XLA runtime will query the API version and traits | ||||||||||||||
|
|
@@ -1675,7 +1729,7 @@ def ffi_callback(call_frame): | |||||||||||||
| metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) | ||||||||||||||
| metadata_ext.contents.metadata.contents.api_version.major_version = 0 | ||||||||||||||
| metadata_ext.contents.metadata.contents.api_version.minor_version = 1 | ||||||||||||||
| if graph_compatible: | ||||||||||||||
| if graph_compatible and platform == "CUDA": | ||||||||||||||
| # Turn on CUDA graphs for this handler. | ||||||||||||||
| metadata_ext.contents.metadata.contents.traits = ( | ||||||||||||||
| XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE | ||||||||||||||
|
|
@@ -1708,12 +1762,17 @@ def ffi_callback(call_frame): | |||||||||||||
| return None | ||||||||||||||
|
|
||||||||||||||
| FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame)) | ||||||||||||||
| callback_func = FFI_CCALLFUNC(ffi_callback) | ||||||||||||||
| callback_func_cuda = FFI_CCALLFUNC(lambda call_frame: ffi_callback(call_frame, platform="CUDA")) | ||||||||||||||
| callback_func_host = FFI_CCALLFUNC(lambda call_frame: ffi_callback(call_frame, platform="Host")) | ||||||||||||||
| with _FFI_REGISTRY_LOCK: | ||||||||||||||
| _FFI_CALLBACK_REGISTRY[name] = callback_func | ||||||||||||||
| ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p) | ||||||||||||||
| ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) | ||||||||||||||
| jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA") | ||||||||||||||
| _FFI_CALLBACK_REGISTRY[f"{name}_cuda"] = callback_func_cuda | ||||||||||||||
| _FFI_CALLBACK_REGISTRY[f"{name}_host"] = callback_func_host | ||||||||||||||
| ffi_ccall_address_cuda = ctypes.cast(callback_func_cuda, ctypes.c_void_p) | ||||||||||||||
| ffi_capsule_cuda = jax.ffi.pycapsule(ffi_ccall_address_cuda.value) | ||||||||||||||
| jax.ffi.register_ffi_target(name, ffi_capsule_cuda, platform="CUDA") | ||||||||||||||
| ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p) | ||||||||||||||
| ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value) | ||||||||||||||
| jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host") | ||||||||||||||
|
Comment on lines
+1771
to
+1773
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
coderabbitai[bot] marked this conversation as resolved.
|
||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| ############################################################################### | ||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NameErrorinFfiKernel.__init__:ffi_capsule_hostis passed as its own argument topycapsulebefore the name is ever assigned. This will raiseNameError: name 'ffi_capsule_host' is not definedevery time anFfiKernelis instantiated, making the entire Host platform registration dead code. The address variableffi_ccall_address_hostshould be used here instead — consistent with how the CUDA path is written above.