diff --git a/CMakeLists.txt b/CMakeLists.txt index 1b539e8..e832f40 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,7 +35,7 @@ else() FetchContent_Declare( mlx GIT_REPOSITORY "https://github.com/ml-explore/mlx.git" - GIT_TAG v0.30.4) + GIT_TAG v0.30.6) FetchContent_MakeAvailable(mlx) endif() @@ -45,6 +45,7 @@ set(mlxc-src ${CMAKE_CURRENT_LIST_DIR}/mlx/c/array.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/c/closure.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/c/compile.cpp + ${CMAKE_CURRENT_LIST_DIR}/mlx/c/cuda.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/c/device.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/c/distributed.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/c/distributed_group.cpp diff --git a/mlx/c/cuda.cpp b/mlx/c/cuda.cpp new file mode 100644 index 0000000..71c07f5 --- /dev/null +++ b/mlx/c/cuda.cpp @@ -0,0 +1,19 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#include "mlx/c/cuda.h" +#include "mlx/backend/cuda/cuda.h" +#include "mlx/c/error.h" +#include "mlx/c/private/mlx.h" + +extern "C" int mlx_cuda_is_available(bool* res) { + try { + *res = mlx::core::cu::is_available(); + } catch (std::exception& e) { + mlx_error(e.what()); + return 1; + } + return 0; +} diff --git a/mlx/c/cuda.h b/mlx/c/cuda.h new file mode 100644 index 0000000..4734f8c --- /dev/null +++ b/mlx/c/cuda.h @@ -0,0 +1,39 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_CUDA_H +#define MLX_CUDA_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup cuda Cuda specific operations + */ +/**@{*/ + +int mlx_cuda_is_available(bool* res); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/mlx/c/device.cpp b/mlx/c/device.cpp index 2e6d5a4..632ef9a 100644 --- a/mlx/c/device.cpp +++ b/mlx/c/device.cpp @@ -97,6 +97,16 @@ extern "C" int mlx_device_free(mlx_device dev) { return 0; } +extern "C" int mlx_device_is_available(bool* avail, mlx_device dev) { + try { + *avail = mlx::core::is_available(mlx_device_get_(dev)); + return 0; + } catch (std::exception& e) { + mlx_error(e.what()); + return 1; + } +} + extern "C" int mlx_device_count(int* count, mlx_device_type type) { try { auto cpp_type = mlx_device_type_to_cpp(type); diff --git a/mlx/c/device.h b/mlx/c/device.h index bdcbb25..4b74e39 100644 --- a/mlx/c/device.h +++ b/mlx/c/device.h @@ -72,7 +72,10 @@ int mlx_get_default_device(mlx_device* dev); * Set the default MLX device. */ int mlx_set_default_device(mlx_device dev); - +/** + * Check if device is available. + */ +int mlx_device_is_available(bool* avail, mlx_device dev); /** * Get the number of available devices for a device type. */ diff --git a/mlx/c/metal.cpp b/mlx/c/metal.cpp index f409cbe..46d8034 100644 --- a/mlx/c/metal.cpp +++ b/mlx/c/metal.cpp @@ -8,21 +8,6 @@ #include "mlx/c/error.h" #include "mlx/c/private/mlx.h" -extern "C" mlx_metal_device_info_t mlx_metal_device_info(void) { - auto info = mlx::core::metal::device_info(); - - mlx_metal_device_info_t c_info; - std::strncpy( - c_info.architecture, - std::get(info["architecture"]).c_str(), - 256); - c_info.max_buffer_length = std::get(info["max_buffer_length"]); - c_info.max_recommended_working_set_size = - std::get(info["max_recommended_working_set_size"]); - c_info.memory_size = std::get(info["memory_size"]); - return c_info; -} - extern "C" int mlx_metal_is_available(bool* res) { try { *res = mlx::core::metal::is_available(); diff --git a/mlx/c/metal.h b/mlx/c/metal.h index 885eb0c..5877b22 100644 --- a/mlx/c/metal.h +++ b/mlx/c/metal.h @@ -28,28 +28,6 @@ extern "C" { */ /**@{*/ -/** - * @deprecated Use mlx_device_info instead. - * This struct is kept for backwards compatibility. - */ -typedef struct mlx_metal_device_info_t_ { - char architecture[256]; - size_t max_buffer_length; - size_t max_recommended_working_set_size; - size_t memory_size; -} mlx_metal_device_info_t; - -/** - * @deprecated Use mlx_device_info_get() instead. - * Get Metal device information (deprecated). - */ -#if defined(__GNUC__) || defined(__clang__) -__attribute__((deprecated("Use mlx_device_info_get() instead"))) -#elif defined(_MSC_VER) -__declspec(deprecated("Use mlx_device_info_get() instead")) -#endif -mlx_metal_device_info_t mlx_metal_device_info(void); - int mlx_metal_is_available(bool* res); int mlx_metal_start_capture(const char* path); int mlx_metal_stop_capture(void); diff --git a/mlx/c/mlx.h b/mlx/c/mlx.h index b62ea3b..ffadac8 100644 --- a/mlx/c/mlx.h +++ b/mlx/c/mlx.h @@ -6,6 +6,7 @@ #include "mlx/c/array.h" #include "mlx/c/closure.h" #include "mlx/c/compile.h" +#include "mlx/c/cuda.h" #include "mlx/c/device.h" #include "mlx/c/distributed.h" #include "mlx/c/distributed_group.h" diff --git a/python/c.py b/python/c.py index 179c31c..079dfa3 100644 --- a/python/c.py +++ b/python/c.py @@ -17,6 +17,8 @@ def c_namespace(namespace): c_prefix = namespace.split("::") if c_prefix[0] == "mlx" and c_prefix[1] == "core": c_prefix.pop(1) # we pop core + if len(c_prefix) == 2 and c_prefix[1] == "cu": + c_prefix[1] = "cuda" return "_".join(c_prefix) diff --git a/python/mlxhooks.py b/python/mlxhooks.py index 076934e..701ca5b 100644 --- a/python/mlxhooks.py +++ b/python/mlxhooks.py @@ -1,49 +1,3 @@ -def mlx_metal_device_info(f, implementation): - if implementation: - print( - """\ -extern "C" mlx_metal_device_info_t mlx_metal_device_info(void) { - auto info = mlx::core::metal::device_info(); - - mlx_metal_device_info_t c_info; - std::strncpy( - c_info.architecture, - std::get(info["architecture"]).c_str(), - 256); - c_info.max_buffer_length = std::get(info["max_buffer_length"]); - c_info.max_recommended_working_set_size = - std::get(info["max_recommended_working_set_size"]); - c_info.memory_size = std::get(info["memory_size"]); - return c_info; -}""" - ) - else: - print( - """\ -/** - * @deprecated Use mlx_device_info instead. - * This struct is kept for backwards compatibility. - */ -typedef struct mlx_metal_device_info_t_ { - char architecture[256]; - size_t max_buffer_length; - size_t max_recommended_working_set_size; - size_t memory_size; -} mlx_metal_device_info_t; - -/** - * @deprecated Use mlx_device_info_get() instead. - * Get Metal device information (deprecated). - */ -#if defined(__GNUC__) || defined(__clang__) -__attribute__((deprecated("Use mlx_device_info_get() instead"))) -#elif defined(_MSC_VER) -__declspec(deprecated("Use mlx_device_info_get() instead")) -#endif -mlx_metal_device_info_t mlx_metal_device_info(void);""" - ) - - def __implement_mlx_fast_custom_kernel(backend, backend_specific_code, implementation): if implementation: code_config = """