Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ else()
FetchContent_Declare(
mlx
GIT_REPOSITORY "https://github.com/ml-explore/mlx.git"
GIT_TAG v0.30.4)
GIT_TAG v0.30.6)
FetchContent_MakeAvailable(mlx)
endif()

Expand All @@ -45,6 +45,7 @@ set(mlxc-src
${CMAKE_CURRENT_LIST_DIR}/mlx/c/array.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/closure.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/compile.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/cuda.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/device.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/distributed.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/distributed_group.cpp
Expand Down
19 changes: 19 additions & 0 deletions mlx/c/cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */

#include "mlx/c/cuda.h"
#include "mlx/backend/cuda/cuda.h"
#include "mlx/c/error.h"
#include "mlx/c/private/mlx.h"

extern "C" int mlx_cuda_is_available(bool* res) {
try {
*res = mlx::core::cu::is_available();
} catch (std::exception& e) {
mlx_error(e.what());
return 1;
}
return 0;
}
39 changes: 39 additions & 0 deletions mlx/c/cuda.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */

#ifndef MLX_CUDA_H
#define MLX_CUDA_H

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"

#ifdef __cplusplus
extern "C" {
#endif

/**
* \defgroup cuda Cuda specific operations
*/
/**@{*/

int mlx_cuda_is_available(bool* res);

/**@}*/

#ifdef __cplusplus
}
#endif

#endif
10 changes: 10 additions & 0 deletions mlx/c/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ extern "C" int mlx_device_free(mlx_device dev) {
return 0;
}

extern "C" int mlx_device_is_available(bool* avail, mlx_device dev) {
try {
*avail = mlx::core::is_available(mlx_device_get_(dev));
return 0;
} catch (std::exception& e) {
mlx_error(e.what());
return 1;
}
}

extern "C" int mlx_device_count(int* count, mlx_device_type type) {
try {
auto cpp_type = mlx_device_type_to_cpp(type);
Expand Down
5 changes: 4 additions & 1 deletion mlx/c/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ int mlx_get_default_device(mlx_device* dev);
* Set the default MLX device.
*/
int mlx_set_default_device(mlx_device dev);

/**
* Check if device is available.
*/
int mlx_device_is_available(bool* avail, mlx_device dev);
/**
* Get the number of available devices for a device type.
*/
Expand Down
15 changes: 0 additions & 15 deletions mlx/c/metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,6 @@
#include "mlx/c/error.h"
#include "mlx/c/private/mlx.h"

extern "C" mlx_metal_device_info_t mlx_metal_device_info(void) {
auto info = mlx::core::metal::device_info();

mlx_metal_device_info_t c_info;
std::strncpy(
c_info.architecture,
std::get<std::string>(info["architecture"]).c_str(),
256);
c_info.max_buffer_length = std::get<size_t>(info["max_buffer_length"]);
c_info.max_recommended_working_set_size =
std::get<size_t>(info["max_recommended_working_set_size"]);
c_info.memory_size = std::get<size_t>(info["memory_size"]);
return c_info;
}

extern "C" int mlx_metal_is_available(bool* res) {
try {
*res = mlx::core::metal::is_available();
Expand Down
22 changes: 0 additions & 22 deletions mlx/c/metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,6 @@ extern "C" {
*/
/**@{*/

/**
* @deprecated Use mlx_device_info instead.
* This struct is kept for backwards compatibility.
*/
typedef struct mlx_metal_device_info_t_ {
char architecture[256];
size_t max_buffer_length;
size_t max_recommended_working_set_size;
size_t memory_size;
} mlx_metal_device_info_t;

/**
* @deprecated Use mlx_device_info_get() instead.
* Get Metal device information (deprecated).
*/
#if defined(__GNUC__) || defined(__clang__)
__attribute__((deprecated("Use mlx_device_info_get() instead")))
#elif defined(_MSC_VER)
__declspec(deprecated("Use mlx_device_info_get() instead"))
#endif
mlx_metal_device_info_t mlx_metal_device_info(void);

int mlx_metal_is_available(bool* res);
int mlx_metal_start_capture(const char* path);
int mlx_metal_stop_capture(void);
Expand Down
1 change: 1 addition & 0 deletions mlx/c/mlx.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/compile.h"
#include "mlx/c/cuda.h"
#include "mlx/c/device.h"
#include "mlx/c/distributed.h"
#include "mlx/c/distributed_group.h"
Expand Down
2 changes: 2 additions & 0 deletions python/c.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def c_namespace(namespace):
c_prefix = namespace.split("::")
if c_prefix[0] == "mlx" and c_prefix[1] == "core":
c_prefix.pop(1) # we pop core
if len(c_prefix) == 2 and c_prefix[1] == "cu":
c_prefix[1] = "cuda"
return "_".join(c_prefix)


Expand Down
46 changes: 0 additions & 46 deletions python/mlxhooks.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,3 @@
def mlx_metal_device_info(f, implementation):
if implementation:
print(
"""\
extern "C" mlx_metal_device_info_t mlx_metal_device_info(void) {
auto info = mlx::core::metal::device_info();

mlx_metal_device_info_t c_info;
std::strncpy(
c_info.architecture,
std::get<std::string>(info["architecture"]).c_str(),
256);
c_info.max_buffer_length = std::get<size_t>(info["max_buffer_length"]);
c_info.max_recommended_working_set_size =
std::get<size_t>(info["max_recommended_working_set_size"]);
c_info.memory_size = std::get<size_t>(info["memory_size"]);
return c_info;
}"""
)
else:
print(
"""\
/**
* @deprecated Use mlx_device_info instead.
* This struct is kept for backwards compatibility.
*/
typedef struct mlx_metal_device_info_t_ {
char architecture[256];
size_t max_buffer_length;
size_t max_recommended_working_set_size;
size_t memory_size;
} mlx_metal_device_info_t;

/**
* @deprecated Use mlx_device_info_get() instead.
* Get Metal device information (deprecated).
*/
#if defined(__GNUC__) || defined(__clang__)
__attribute__((deprecated("Use mlx_device_info_get() instead")))
#elif defined(_MSC_VER)
__declspec(deprecated("Use mlx_device_info_get() instead"))
#endif
mlx_metal_device_info_t mlx_metal_device_info(void);"""
)


def __implement_mlx_fast_custom_kernel(backend, backend_specific_code, implementation):
if implementation:
code_config = """
Expand Down