Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ else()
FetchContent_Declare(
mlx
GIT_REPOSITORY "https://github.com/ml-explore/mlx.git"
GIT_TAG v0.30.4)
GIT_TAG v0.30.6)
FetchContent_MakeAvailable(mlx)
endif()

Expand All @@ -45,6 +45,7 @@ set(mlxc-src
${CMAKE_CURRENT_LIST_DIR}/mlx/c/array.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/closure.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/compile.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/cuda.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/device.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/distributed.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/distributed_group.cpp
Expand Down
19 changes: 19 additions & 0 deletions mlx/c/cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */

#include "mlx/c/cuda.h"
#include "mlx/backend/cuda/cuda.h"
#include "mlx/c/error.h"
#include "mlx/c/private/mlx.h"

extern "C" int mlx_cuda_is_available(bool* res) {
try {
*res = mlx::core::cu::is_available();
} catch (std::exception& e) {
mlx_error(e.what());
return 1;
}
return 0;
}
39 changes: 39 additions & 0 deletions mlx/c/cuda.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */

#ifndef MLX_CUDA_H
#define MLX_CUDA_H

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"

#ifdef __cplusplus
extern "C" {
#endif

/**
* \defgroup cuda Cuda specific operations
*/
/**@{*/

int mlx_cuda_is_available(bool* res);

/**@}*/

#ifdef __cplusplus
}
#endif

#endif
15 changes: 0 additions & 15 deletions mlx/c/metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,6 @@
#include "mlx/c/error.h"
#include "mlx/c/private/mlx.h"

extern "C" mlx_metal_device_info_t mlx_metal_device_info(void) {
auto info = mlx::core::metal::device_info();

mlx_metal_device_info_t c_info;
std::strncpy(
c_info.architecture,
std::get<std::string>(info["architecture"]).c_str(),
256);
c_info.max_buffer_length = std::get<size_t>(info["max_buffer_length"]);
c_info.max_recommended_working_set_size =
std::get<size_t>(info["max_recommended_working_set_size"]);
c_info.memory_size = std::get<size_t>(info["memory_size"]);
return c_info;
}

extern "C" int mlx_metal_is_available(bool* res) {
try {
*res = mlx::core::metal::is_available();
Expand Down
22 changes: 0 additions & 22 deletions mlx/c/metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,6 @@ extern "C" {
*/
/**@{*/

/**
* @deprecated Use mlx_device_info instead.
* This struct is kept for backwards compatibility.
*/
typedef struct mlx_metal_device_info_t_ {
char architecture[256];
size_t max_buffer_length;
size_t max_recommended_working_set_size;
size_t memory_size;
} mlx_metal_device_info_t;

/**
* @deprecated Use mlx_device_info_get() instead.
* Get Metal device information (deprecated).
*/
#if defined(__GNUC__) || defined(__clang__)
__attribute__((deprecated("Use mlx_device_info_get() instead")))
#elif defined(_MSC_VER)
__declspec(deprecated("Use mlx_device_info_get() instead"))
#endif
mlx_metal_device_info_t mlx_metal_device_info(void);

int mlx_metal_is_available(bool* res);
int mlx_metal_start_capture(const char* path);
int mlx_metal_stop_capture(void);
Expand Down
1 change: 1 addition & 0 deletions mlx/c/mlx.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/compile.h"
#include "mlx/c/cuda.h"
#include "mlx/c/device.h"
#include "mlx/c/distributed.h"
#include "mlx/c/distributed_group.h"
Expand Down
2 changes: 2 additions & 0 deletions python/c.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def c_namespace(namespace):
c_prefix = namespace.split("::")
if c_prefix[0] == "mlx" and c_prefix[1] == "core":
c_prefix.pop(1) # we pop core
if len(c_prefix) == 2 and c_prefix[1] == "cu":
c_prefix[1] = "cuda"
return "_".join(c_prefix)


Expand Down
46 changes: 0 additions & 46 deletions python/mlxhooks.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,3 @@
def mlx_metal_device_info(f, implementation):
if implementation:
print(
"""\
extern "C" mlx_metal_device_info_t mlx_metal_device_info(void) {
auto info = mlx::core::metal::device_info();

mlx_metal_device_info_t c_info;
std::strncpy(
c_info.architecture,
std::get<std::string>(info["architecture"]).c_str(),
256);
c_info.max_buffer_length = std::get<size_t>(info["max_buffer_length"]);
c_info.max_recommended_working_set_size =
std::get<size_t>(info["max_recommended_working_set_size"]);
c_info.memory_size = std::get<size_t>(info["memory_size"]);
return c_info;
}"""
)
else:
print(
"""\
/**
* @deprecated Use mlx_device_info instead.
* This struct is kept for backwards compatibility.
*/
typedef struct mlx_metal_device_info_t_ {
char architecture[256];
size_t max_buffer_length;
size_t max_recommended_working_set_size;
size_t memory_size;
} mlx_metal_device_info_t;

/**
* @deprecated Use mlx_device_info_get() instead.
* Get Metal device information (deprecated).
*/
#if defined(__GNUC__) || defined(__clang__)
__attribute__((deprecated("Use mlx_device_info_get() instead")))
#elif defined(_MSC_VER)
__declspec(deprecated("Use mlx_device_info_get() instead"))
#endif
mlx_metal_device_info_t mlx_metal_device_info(void);"""
)


def __implement_mlx_fast_custom_kernel(backend, backend_specific_code, implementation):
if implementation:
code_config = """
Expand Down