mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 09:03:35 -04:00
mlx: update as of 3/23 (#14789)
* mlx: update to HEAD on 3/23. Also fixes a few miscellaneous vendoring bugs uncovered by this first update, and renames the version files to make them clearer.
* CUDA Fast Gated Delta kernel.
* mlx: detect eval errors and panic. On model errors or missing kernels, don't mask the error; bubble it up.
This commit is contained in:
@@ -157,7 +157,7 @@ COPY CMakeLists.txt CMakePresets.json .
|
|||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
COPY x/imagegen/mlx x/imagegen/mlx
|
COPY x/imagegen/mlx x/imagegen/mlx
|
||||||
COPY go.mod go.sum .
|
COPY go.mod go.sum .
|
||||||
COPY MLX_VERSION MLX_CORE_VERSION .
|
COPY MLX_VERSION MLX_C_VERSION .
|
||||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
v0.30.6
|
|
||||||
1
MLX_C_VERSION
Normal file
1
MLX_C_VERSION
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0726ca922fc902c4c61ef9c27d94132be418e945
|
||||||
@@ -1 +1 @@
|
|||||||
v0.5.0
|
38ad257088fb2193ad47e527cf6534a689f30943
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
# Read MLX version from top-level file (shared with Dockerfile)
|
# Read MLX-C version from top-level file (shared with Dockerfile)
|
||||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
|
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
|
||||||
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||||
|
|
||||||
# Read MLX core version from top-level file
|
# Read MLX version from top-level file
|
||||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_CORE_VERSION" MLX_GIT_TAG)
|
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_GIT_TAG)
|
||||||
string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG)
|
string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG)
|
||||||
|
|
||||||
set(MLX_C_BUILD_EXAMPLES OFF)
|
set(MLX_C_BUILD_EXAMPLES OFF)
|
||||||
@@ -98,6 +98,15 @@ FetchContent_MakeAvailable(mlx-c)
|
|||||||
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
|
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
|
||||||
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/")
|
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/")
|
||||||
|
|
||||||
|
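# Regenerate the Go/C shim wrappers from the (possibly updated) headers by running `go generate ./x/...` from the source root at configure time; a missing `go` executable or a failing generate step aborts configuration.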
|
||||||
|
find_program(GO_EXECUTABLE go REQUIRED)
|
||||||
|
message(STATUS "Regenerating MLX Go wrappers")
|
||||||
|
execute_process(
|
||||||
|
COMMAND ${GO_EXECUTABLE} generate ./x/...
|
||||||
|
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
|
||||||
|
COMMAND_ERROR_IS_FATAL ANY
|
||||||
|
)
|
||||||
|
|
||||||
# For local dev builds, override MLX_VERSION with git describe output
|
# For local dev builds, override MLX_VERSION with git describe output
|
||||||
if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX)
|
if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX)
|
||||||
execute_process(
|
execute_process(
|
||||||
|
|||||||
@@ -165,8 +165,8 @@ int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const
|
|||||||
int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL;
|
int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL;
|
||||||
int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL;
|
int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL;
|
||||||
mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL;
|
mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL;
|
||||||
bool (*mlx_distributed_is_available_ptr)(void) = NULL;
|
bool (*mlx_distributed_is_available_ptr)(const char* bk) = NULL;
|
||||||
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict) = NULL;
|
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk) = NULL;
|
||||||
void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL;
|
void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL;
|
||||||
void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL;
|
void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL;
|
||||||
int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL;
|
int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL;
|
||||||
@@ -319,10 +319,12 @@ int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const
|
|||||||
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
|
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
|
||||||
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
|
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
|
||||||
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
|
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
|
||||||
@@ -348,7 +350,7 @@ int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse
|
|||||||
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
|
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
|
||||||
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
|
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
|
||||||
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) = NULL;
|
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) = NULL;
|
||||||
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
||||||
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
|
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
|
||||||
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
@@ -375,6 +377,8 @@ int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w,
|
|||||||
int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL;
|
int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
||||||
int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
@@ -434,8 +438,8 @@ int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, siz
|
|||||||
int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
||||||
int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
||||||
int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
|
int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
|
||||||
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) = NULL;
|
||||||
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) = NULL;
|
||||||
int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
||||||
int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
@@ -2101,6 +2105,11 @@ int mlx_load_functions(void* handle) {
|
|||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
mlx_bartlett_ptr = GET_SYM(handle, "mlx_bartlett");
|
||||||
|
if (mlx_bartlett_ptr == NULL) {
|
||||||
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_bartlett\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and");
|
mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and");
|
||||||
if (mlx_bitwise_and_ptr == NULL) {
|
if (mlx_bitwise_and_ptr == NULL) {
|
||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
|
||||||
@@ -2121,6 +2130,11 @@ int mlx_load_functions(void* handle) {
|
|||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
mlx_blackman_ptr = GET_SYM(handle, "mlx_blackman");
|
||||||
|
if (mlx_blackman_ptr == NULL) {
|
||||||
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_blackman\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm");
|
mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm");
|
||||||
if (mlx_block_masked_mm_ptr == NULL) {
|
if (mlx_block_masked_mm_ptr == NULL) {
|
||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
|
||||||
@@ -2381,6 +2395,16 @@ int mlx_load_functions(void* handle) {
|
|||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
mlx_hamming_ptr = GET_SYM(handle, "mlx_hamming");
|
||||||
|
if (mlx_hamming_ptr == NULL) {
|
||||||
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_hamming\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
mlx_hanning_ptr = GET_SYM(handle, "mlx_hanning");
|
||||||
|
if (mlx_hanning_ptr == NULL) {
|
||||||
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_hanning\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
mlx_identity_ptr = GET_SYM(handle, "mlx_identity");
|
mlx_identity_ptr = GET_SYM(handle, "mlx_identity");
|
||||||
if (mlx_identity_ptr == NULL) {
|
if (mlx_identity_ptr == NULL) {
|
||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
|
||||||
@@ -4132,12 +4156,12 @@ mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, i
|
|||||||
return mlx_distributed_group_split_ptr(group, color, key);
|
return mlx_distributed_group_split_ptr(group, color, key);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mlx_distributed_is_available(void) {
|
bool mlx_distributed_is_available(const char* bk) {
|
||||||
return mlx_distributed_is_available_ptr();
|
return mlx_distributed_is_available_ptr(bk);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlx_distributed_group mlx_distributed_init(bool strict) {
|
mlx_distributed_group mlx_distributed_init(bool strict, const char* bk) {
|
||||||
return mlx_distributed_init_ptr(strict);
|
return mlx_distributed_init_ptr(strict, bk);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) {
|
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) {
|
||||||
@@ -4748,6 +4772,10 @@ int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|||||||
return mlx_atleast_3d_ptr(res, a, s);
|
return mlx_atleast_3d_ptr(res, a, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_bartlett_ptr(res, M, s);
|
||||||
|
}
|
||||||
|
|
||||||
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
||||||
return mlx_bitwise_and_ptr(res, a, b, s);
|
return mlx_bitwise_and_ptr(res, a, b, s);
|
||||||
}
|
}
|
||||||
@@ -4764,6 +4792,10 @@ int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const
|
|||||||
return mlx_bitwise_xor_ptr(res, a, b, s);
|
return mlx_bitwise_xor_ptr(res, a, b, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_blackman_ptr(res, M, s);
|
||||||
|
}
|
||||||
|
|
||||||
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) {
|
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) {
|
||||||
return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
|
return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
|
||||||
}
|
}
|
||||||
@@ -4864,8 +4896,8 @@ int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_
|
|||||||
return mlx_depends_ptr(res, inputs, dependencies);
|
return mlx_depends_ptr(res, inputs, dependencies);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) {
|
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) {
|
||||||
return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, dtype, s);
|
return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
||||||
@@ -4972,6 +5004,14 @@ int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float
|
|||||||
return mlx_hadamard_transform_ptr(res, a, scale, s);
|
return mlx_hadamard_transform_ptr(res, a, scale, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_hamming_ptr(res, M, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_hanning_ptr(res, M, s);
|
||||||
|
}
|
||||||
|
|
||||||
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
||||||
return mlx_identity_ptr(res, n, dtype, s);
|
return mlx_identity_ptr(res, n, dtype, s);
|
||||||
}
|
}
|
||||||
@@ -5208,12 +5248,12 @@ int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indice
|
|||||||
return mlx_put_along_axis_ptr(res, a, indices, values, axis, s);
|
return mlx_put_along_axis_ptr(res, a, indices, values, axis, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) {
|
||||||
return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, s);
|
return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) {
|
||||||
return mlx_quantize_ptr(res, w, group_size, bits, mode, s);
|
return mlx_quantize_ptr(res, w, group_size, bits, mode, global_scale, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
||||||
|
|||||||
@@ -2125,7 +2125,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
|
|||||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||||
res := C.mlx_vector_array_new()
|
res := C.mlx_vector_array_new()
|
||||||
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
|
var globalScale C.mlx_array
|
||||||
|
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, globalScale, C.default_stream())
|
||||||
|
|
||||||
// Result is a vector of arrays: [weights, scales, biases?]
|
// Result is a vector of arrays: [weights, scales, biases?]
|
||||||
// mxfp8 mode returns only 2 elements (no biases)
|
// mxfp8 mode returns only 2 elements (no biases)
|
||||||
@@ -2161,7 +2162,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
|
|||||||
}
|
}
|
||||||
|
|
||||||
res := C.mlx_array_new()
|
res := C.mlx_array_new()
|
||||||
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, optDtype, C.default_stream())
|
var globalScale C.mlx_array
|
||||||
|
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, globalScale, optDtype, C.default_stream())
|
||||||
return newArray(res)
|
return newArray(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -309,10 +309,12 @@
|
|||||||
#undef mlx_atleast_1d
|
#undef mlx_atleast_1d
|
||||||
#undef mlx_atleast_2d
|
#undef mlx_atleast_2d
|
||||||
#undef mlx_atleast_3d
|
#undef mlx_atleast_3d
|
||||||
|
#undef mlx_bartlett
|
||||||
#undef mlx_bitwise_and
|
#undef mlx_bitwise_and
|
||||||
#undef mlx_bitwise_invert
|
#undef mlx_bitwise_invert
|
||||||
#undef mlx_bitwise_or
|
#undef mlx_bitwise_or
|
||||||
#undef mlx_bitwise_xor
|
#undef mlx_bitwise_xor
|
||||||
|
#undef mlx_blackman
|
||||||
#undef mlx_block_masked_mm
|
#undef mlx_block_masked_mm
|
||||||
#undef mlx_broadcast_arrays
|
#undef mlx_broadcast_arrays
|
||||||
#undef mlx_broadcast_to
|
#undef mlx_broadcast_to
|
||||||
@@ -365,6 +367,8 @@
|
|||||||
#undef mlx_greater
|
#undef mlx_greater
|
||||||
#undef mlx_greater_equal
|
#undef mlx_greater_equal
|
||||||
#undef mlx_hadamard_transform
|
#undef mlx_hadamard_transform
|
||||||
|
#undef mlx_hamming
|
||||||
|
#undef mlx_hanning
|
||||||
#undef mlx_identity
|
#undef mlx_identity
|
||||||
#undef mlx_imag
|
#undef mlx_imag
|
||||||
#undef mlx_inner
|
#undef mlx_inner
|
||||||
@@ -751,8 +755,8 @@ extern int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x,
|
|||||||
extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group);
|
extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group);
|
||||||
extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group);
|
extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group);
|
||||||
extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key);
|
extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key);
|
||||||
extern bool (*mlx_distributed_is_available_ptr)(void);
|
extern bool (*mlx_distributed_is_available_ptr)(const char* bk);
|
||||||
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict);
|
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk);
|
||||||
extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
||||||
extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...);
|
extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...);
|
||||||
extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless);
|
extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless);
|
||||||
@@ -905,10 +909,12 @@ extern int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype,
|
|||||||
extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
|
extern int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
|
extern int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
||||||
extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
||||||
extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s);
|
extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s);
|
||||||
@@ -934,7 +940,7 @@ extern int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool
|
|||||||
extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s);
|
extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s);
|
||||||
extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
||||||
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s);
|
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
|
||||||
extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||||
extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s);
|
extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s);
|
||||||
extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
@@ -961,6 +967,8 @@ extern int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_ar
|
|||||||
extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
||||||
|
extern int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
extern int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||||
extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
@@ -1020,8 +1028,8 @@ extern int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* ax
|
|||||||
extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s);
|
extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s);
|
||||||
extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s);
|
extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s);
|
||||||
extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
||||||
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
|
||||||
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
|
||||||
extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
||||||
extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
@@ -1492,9 +1500,9 @@ int mlx_distributed_group_size(mlx_distributed_group group);
|
|||||||
|
|
||||||
mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
|
mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
|
||||||
|
|
||||||
bool mlx_distributed_is_available(void);
|
bool mlx_distributed_is_available(const char* bk);
|
||||||
|
|
||||||
mlx_distributed_group mlx_distributed_init(bool strict);
|
mlx_distributed_group mlx_distributed_init(bool strict, const char* bk);
|
||||||
|
|
||||||
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
||||||
|
|
||||||
@@ -1800,6 +1808,8 @@ int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
|||||||
|
|
||||||
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
|
|
||||||
|
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
@@ -1808,6 +1818,8 @@ int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const m
|
|||||||
|
|
||||||
int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
|
|
||||||
|
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
||||||
|
|
||||||
int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
||||||
@@ -1858,7 +1870,7 @@ int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s);
|
|||||||
|
|
||||||
int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
||||||
|
|
||||||
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s);
|
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||||
|
|
||||||
@@ -1912,6 +1924,10 @@ int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, cons
|
|||||||
|
|
||||||
int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
||||||
|
|
||||||
|
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
|
||||||
|
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
@@ -2030,9 +2046,9 @@ int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream
|
|||||||
|
|
||||||
int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
|
||||||
|
|
||||||
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
|
||||||
|
|
||||||
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
||||||
|
|
||||||
|
|||||||
@@ -230,6 +230,9 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
|
|
||||||
resp, err := c.client.Do(httpReq)
|
resp, err := c.client.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errMsg := c.status.getLastErr(); errMsg != "" {
|
||||||
|
return fmt.Errorf("mlx runner failed: %s", errMsg)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
@@ -267,7 +270,13 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return scanner.Err()
|
if err := scanner.Err(); err != nil {
|
||||||
|
if errMsg := c.status.getLastErr(); errMsg != "" {
|
||||||
|
return fmt.Errorf("mlx runner failed: %s", errMsg)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ContextLength() int {
|
func (c *Client) ContextLength() int {
|
||||||
|
|||||||
@@ -15,7 +15,9 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
|
|||||||
|
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
|
# Read MLX-C version from top-level file (shared with imagegen CMakeLists)
|
||||||
|
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
|
||||||
|
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||||
|
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
mlx-c
|
mlx-c
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ var (
|
|||||||
gatedDeltaMetalKernelOnce sync.Once
|
gatedDeltaMetalKernelOnce sync.Once
|
||||||
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
||||||
gatedDeltaMetalDisabled bool
|
gatedDeltaMetalDisabled bool
|
||||||
|
|
||||||
|
gatedDeltaCUDAKernelOnce sync.Once
|
||||||
|
gatedDeltaCUDAKernel C.mlx_fast_cuda_kernel
|
||||||
|
gatedDeltaCUDADisabled bool
|
||||||
)
|
)
|
||||||
|
|
||||||
const gatedDeltaMetalKernelSource = `
|
const gatedDeltaMetalKernelSource = `
|
||||||
@@ -83,6 +87,86 @@ for (int i = 0; i < n_per_t; ++i) {
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
|
|
||||||
|
// gatedDeltaCUDAKernelSource is the body of the fused gated-delta CUDA kernel.
//
// Inputs (by name): q, k, v, g, beta, state_in, T. Outputs: y, state_out.
// Template arguments supplied by the host: InT (element type) and the
// integers Dk, Dv, Hk, Hv.
//
// Thread mapping used by the body:
//   - threadIdx.x (32 lanes) splits the Dk axis; each lane owns Dk/32
//     consecutive state elements held in registers as float.
//   - blockIdx.y*blockDim.y + threadIdx.y selects the Dv row.
//   - blockIdx.z selects the flattened (batch, value-head) pair.
//
// For every time step the kernel decays the state by g, forms the
// k-weighted state sum across the 32 lanes, applies the beta-scaled delta
// update, and writes the q-weighted state sum to y. The final state is
// written to state_out.
const gatedDeltaCUDAKernelSource = `
auto tid_x = threadIdx.x;
auto tid_y = threadIdx.y;
auto grid_y = blockIdx.y * blockDim.y + tid_y;
auto grid_z = blockIdx.z;

// Rows past Dv can exist when the launch rounds the y-dimension up to a
// whole block. Skip them before any pointer is dereferenced. tid_y is the
// same for all 32 x-lanes, so the lanes taking part in the full-mask
// shuffles below either all continue or all return here.
if (static_cast<int>(grid_y) >= Dv) {
  return;
}

int T_val = static_cast<int>(*T);

auto n = grid_z;
auto b_idx = n / Hv;
auto hv_idx = n % Hv;
auto hk_idx = hv_idx / (Hv / Hk);
constexpr int n_per_t = Dk / 32;

// q, k: [B, T, Hk, Dk]
auto q_ = q + b_idx * T_val * Hk * Dk + hk_idx * Dk;
auto k_ = k + b_idx * T_val * Hk * Dk + hk_idx * Dk;

// v, y: [B, T, Hv, Dv]
auto dv_idx = grid_y;
auto v_ = v + b_idx * T_val * Hv * Dv + hv_idx * Dv;
y += b_idx * T_val * Hv * Dv + hv_idx * Dv;

auto dk_idx = tid_x;

// state_in, state_out: [B, Hv, Dv, Dk]
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
auto o_state = state_out + (n * Dv + dv_idx) * Dk;

float state[n_per_t];
for (int i = 0; i < n_per_t; ++i) {
  auto s_idx = n_per_t * dk_idx + i;
  state[i] = static_cast<float>(i_state[s_idx]);
}

// g: [B, T, Hv]
auto g_ = g + b_idx * T_val * Hv;
auto beta_ = beta + b_idx * T_val * Hv;

for (int t = 0; t < T_val; ++t) {
  float kv_mem = 0.0f;
  for (int i = 0; i < n_per_t; ++i) {
    auto s_idx = n_per_t * dk_idx + i;
    state[i] = state[i] * static_cast<float>(g_[hv_idx]);
    kv_mem += state[i] * static_cast<float>(k_[s_idx]);
  }
  // Warp reduction (full warp, 32 threads in x)
  for (int offset = 16; offset > 0; offset >>= 1)
    kv_mem += __shfl_down_sync(0xffffffff, kv_mem, offset);
  kv_mem = __shfl_sync(0xffffffff, kv_mem, 0);

  auto delta = (static_cast<float>(v_[dv_idx]) - kv_mem) * static_cast<float>(beta_[hv_idx]);

  float out = 0.0f;
  for (int i = 0; i < n_per_t; ++i) {
    auto s_idx = n_per_t * dk_idx + i;
    state[i] = state[i] + static_cast<float>(k_[s_idx]) * delta;
    out += state[i] * static_cast<float>(q_[s_idx]);
  }
  // Warp reduction
  for (int offset = 16; offset > 0; offset >>= 1)
    out += __shfl_down_sync(0xffffffff, out, offset);
  if (tid_x == 0) {
    y[dv_idx] = static_cast<InT>(out);
  }

  q_ += Hk * Dk;
  k_ += Hk * Dk;
  v_ += Hv * Dv;
  y += Hv * Dv;
  g_ += Hv;
  beta_ += Hv;
}

for (int i = 0; i < n_per_t; ++i) {
  auto s_idx = n_per_t * dk_idx + i;
  o_state[s_idx] = static_cast<InT>(state[i]);
}
`
|
||||||
|
|
||||||
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
||||||
vec := C.mlx_vector_string_new()
|
vec := C.mlx_vector_string_new()
|
||||||
ok := true
|
ok := true
|
||||||
@@ -352,11 +436,184 @@ func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
|||||||
return Concatenate(outs, 1), nextState
|
return Concatenate(outs, 1), nextState
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initGatedDeltaCUDAKernel() {
|
||||||
|
var cudaAvail C.bool
|
||||||
|
if C.mlx_cuda_is_available(&cudaAvail) != 0 || !bool(cudaAvail) {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
|
||||||
|
if !ok {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
freeInputs()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer freeInputs()
|
||||||
|
|
||||||
|
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
|
||||||
|
if !ok {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
freeOutputs()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer freeOutputs()
|
||||||
|
|
||||||
|
cName := C.CString("gated_delta_step")
|
||||||
|
defer C.free(unsafe.Pointer(cName))
|
||||||
|
cSource := C.CString(gatedDeltaCUDAKernelSource)
|
||||||
|
defer C.free(unsafe.Pointer(cSource))
|
||||||
|
cHeader := C.CString("")
|
||||||
|
defer C.free(unsafe.Pointer(cHeader))
|
||||||
|
|
||||||
|
gatedDeltaCUDAKernel = C.mlx_fast_cuda_kernel_new(
|
||||||
|
cName,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
cSource,
|
||||||
|
cHeader,
|
||||||
|
C.bool(true),
|
||||||
|
C.int(0),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
|
||||||
|
if gatedDeltaCUDADisabled {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
qd := q.Dims()
|
||||||
|
kd := k.Dims()
|
||||||
|
vd := v.Dims()
|
||||||
|
gd := g.Dims()
|
||||||
|
bd := beta.Dims()
|
||||||
|
sd := state.Dims()
|
||||||
|
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
|
||||||
|
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
Hv, Dv := vd[2], vd[3]
|
||||||
|
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if gd[0] != B || gd[1] != T || gd[2] != Hv {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if bd[0] != B || bd[1] != T || bd[2] != Hv {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
dtype := q.DType()
|
||||||
|
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
gatedDeltaCUDAKernelOnce.Do(initGatedDeltaCUDAKernel)
|
||||||
|
if gatedDeltaCUDADisabled {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := C.mlx_fast_cuda_kernel_config_new()
|
||||||
|
defer C.mlx_fast_cuda_kernel_config_free(cfg)
|
||||||
|
|
||||||
|
cInT := C.CString("InT")
|
||||||
|
defer C.free(unsafe.Pointer(cInT))
|
||||||
|
if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
for _, tpl := range []struct {
|
||||||
|
name string
|
||||||
|
value int
|
||||||
|
}{
|
||||||
|
{name: "Dk", value: Dk},
|
||||||
|
{name: "Dv", value: Dv},
|
||||||
|
{name: "Hk", value: Hk},
|
||||||
|
{name: "Hv", value: Hv},
|
||||||
|
} {
|
||||||
|
cn := C.CString(tpl.name)
|
||||||
|
rc := C.mlx_fast_cuda_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
|
||||||
|
C.free(unsafe.Pointer(cn))
|
||||||
|
if rc != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
|
||||||
|
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
|
||||||
|
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if C.mlx_fast_cuda_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
threadY := Dv
|
||||||
|
if threadY > 4 {
|
||||||
|
threadY = 4
|
||||||
|
}
|
||||||
|
if C.mlx_fast_cuda_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
tScalar := FromValue(T)
|
||||||
|
inputs := []C.mlx_array{
|
||||||
|
q.ctx,
|
||||||
|
k.ctx,
|
||||||
|
v.ctx,
|
||||||
|
g.ctx,
|
||||||
|
beta.ctx,
|
||||||
|
state.ctx,
|
||||||
|
tScalar.ctx,
|
||||||
|
}
|
||||||
|
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
|
||||||
|
defer C.mlx_vector_array_free(inVec)
|
||||||
|
|
||||||
|
outVec := C.mlx_vector_array_new()
|
||||||
|
defer C.mlx_vector_array_free(outVec)
|
||||||
|
if C.mlx_fast_cuda_kernel_apply(&outVec, gatedDeltaCUDAKernel, inVec, cfg, DefaultStream().ctx) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if int(C.mlx_vector_array_size(outVec)) < 2 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
y = New("GATED_DELTA_CUDA_Y")
|
||||||
|
nextState = New("GATED_DELTA_CUDA_STATE")
|
||||||
|
C.mlx_vector_array_get(&y.ctx, outVec, 0)
|
||||||
|
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
|
||||||
|
return y, nextState, true
|
||||||
|
}
|
||||||
|
|
||||||
// GatedDelta runs the recurrent update operation.
|
// GatedDelta runs the recurrent update operation.
|
||||||
//
|
//
|
||||||
// It uses the fused Metal kernel when available and otherwise falls back to a
|
// It tries the fused CUDA kernel first, then Metal, then falls back to a
|
||||||
// backend-agnostic MLX implementation with identical inputs/outputs.
|
// backend-agnostic MLX implementation with identical inputs/outputs.
|
||||||
func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
||||||
|
if y, nextState, ok := gatedDeltaCUDAKernelApply(q, k, v, g, beta, state); ok {
|
||||||
|
return y, nextState
|
||||||
|
}
|
||||||
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
|
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
|
||||||
return y, nextState
|
return y, nextState
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -326,8 +326,10 @@ int (*mlx_distributed_sum_scatter_)(
|
|||||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
bool (*mlx_distributed_is_available_)(const char* bk /* may be null */) = NULL;
|
||||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
mlx_distributed_group (*mlx_distributed_init_)(
|
||||||
|
bool strict,
|
||||||
|
const char* bk /* may be null */) = NULL;
|
||||||
void (*mlx_set_error_handler_)(
|
void (*mlx_set_error_handler_)(
|
||||||
mlx_error_handler_func handler,
|
mlx_error_handler_func handler,
|
||||||
void* data,
|
void* data,
|
||||||
@@ -924,6 +926,7 @@ int (*mlx_astype_)(
|
|||||||
int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_and_)(
|
int (*mlx_bitwise_and_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -940,6 +943,7 @@ int (*mlx_bitwise_xor_)(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
const mlx_array b,
|
const mlx_array b,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_block_masked_mm_)(
|
int (*mlx_block_masked_mm_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -1120,6 +1124,7 @@ int (*mlx_dequantize_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
mlx_optional_dtype dtype,
|
mlx_optional_dtype dtype,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
||||||
@@ -1256,6 +1261,8 @@ int (*mlx_hadamard_transform_)(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
mlx_optional_float scale,
|
mlx_optional_float scale,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
||||||
int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_inner_)(
|
int (*mlx_inner_)(
|
||||||
@@ -1548,6 +1555,8 @@ int (*mlx_qqmm_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale_x /* may be null */,
|
||||||
|
const mlx_array global_scale_w /* may be null */,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
int (*mlx_quantize_)(
|
int (*mlx_quantize_)(
|
||||||
mlx_vector_array* res,
|
mlx_vector_array* res,
|
||||||
@@ -1555,6 +1564,7 @@ int (*mlx_quantize_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
int (*mlx_quantized_matmul_)(
|
int (*mlx_quantized_matmul_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
@@ -2550,10 +2560,12 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
|||||||
CHECK_LOAD(handle, mlx_atleast_1d);
|
CHECK_LOAD(handle, mlx_atleast_1d);
|
||||||
CHECK_LOAD(handle, mlx_atleast_2d);
|
CHECK_LOAD(handle, mlx_atleast_2d);
|
||||||
CHECK_LOAD(handle, mlx_atleast_3d);
|
CHECK_LOAD(handle, mlx_atleast_3d);
|
||||||
|
CHECK_LOAD(handle, mlx_bartlett);
|
||||||
CHECK_LOAD(handle, mlx_bitwise_and);
|
CHECK_LOAD(handle, mlx_bitwise_and);
|
||||||
CHECK_LOAD(handle, mlx_bitwise_invert);
|
CHECK_LOAD(handle, mlx_bitwise_invert);
|
||||||
CHECK_LOAD(handle, mlx_bitwise_or);
|
CHECK_LOAD(handle, mlx_bitwise_or);
|
||||||
CHECK_LOAD(handle, mlx_bitwise_xor);
|
CHECK_LOAD(handle, mlx_bitwise_xor);
|
||||||
|
CHECK_LOAD(handle, mlx_blackman);
|
||||||
CHECK_LOAD(handle, mlx_block_masked_mm);
|
CHECK_LOAD(handle, mlx_block_masked_mm);
|
||||||
CHECK_LOAD(handle, mlx_broadcast_arrays);
|
CHECK_LOAD(handle, mlx_broadcast_arrays);
|
||||||
CHECK_LOAD(handle, mlx_broadcast_to);
|
CHECK_LOAD(handle, mlx_broadcast_to);
|
||||||
@@ -2606,6 +2618,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
|||||||
CHECK_LOAD(handle, mlx_greater);
|
CHECK_LOAD(handle, mlx_greater);
|
||||||
CHECK_LOAD(handle, mlx_greater_equal);
|
CHECK_LOAD(handle, mlx_greater_equal);
|
||||||
CHECK_LOAD(handle, mlx_hadamard_transform);
|
CHECK_LOAD(handle, mlx_hadamard_transform);
|
||||||
|
CHECK_LOAD(handle, mlx_hamming);
|
||||||
|
CHECK_LOAD(handle, mlx_hanning);
|
||||||
CHECK_LOAD(handle, mlx_identity);
|
CHECK_LOAD(handle, mlx_identity);
|
||||||
CHECK_LOAD(handle, mlx_imag);
|
CHECK_LOAD(handle, mlx_imag);
|
||||||
CHECK_LOAD(handle, mlx_inner);
|
CHECK_LOAD(handle, mlx_inner);
|
||||||
|
|||||||
@@ -300,10 +300,12 @@
|
|||||||
#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
|
#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
|
||||||
#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
|
#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
|
||||||
#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
|
#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
|
||||||
|
#define mlx_bartlett mlx_bartlett_mlx_gen_orig_
|
||||||
#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
|
#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
|
||||||
#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
|
#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
|
||||||
#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
|
#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
|
||||||
#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
|
#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
|
||||||
|
#define mlx_blackman mlx_blackman_mlx_gen_orig_
|
||||||
#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
|
#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
|
||||||
#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
|
#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
|
||||||
#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
|
#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
|
||||||
@@ -356,6 +358,8 @@
|
|||||||
#define mlx_greater mlx_greater_mlx_gen_orig_
|
#define mlx_greater mlx_greater_mlx_gen_orig_
|
||||||
#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
|
#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
|
||||||
#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
|
#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
|
||||||
|
#define mlx_hamming mlx_hamming_mlx_gen_orig_
|
||||||
|
#define mlx_hanning mlx_hanning_mlx_gen_orig_
|
||||||
#define mlx_identity mlx_identity_mlx_gen_orig_
|
#define mlx_identity mlx_identity_mlx_gen_orig_
|
||||||
#define mlx_imag mlx_imag_mlx_gen_orig_
|
#define mlx_imag mlx_imag_mlx_gen_orig_
|
||||||
#define mlx_inner mlx_inner_mlx_gen_orig_
|
#define mlx_inner mlx_inner_mlx_gen_orig_
|
||||||
@@ -889,10 +893,12 @@
|
|||||||
#undef mlx_atleast_1d
|
#undef mlx_atleast_1d
|
||||||
#undef mlx_atleast_2d
|
#undef mlx_atleast_2d
|
||||||
#undef mlx_atleast_3d
|
#undef mlx_atleast_3d
|
||||||
|
#undef mlx_bartlett
|
||||||
#undef mlx_bitwise_and
|
#undef mlx_bitwise_and
|
||||||
#undef mlx_bitwise_invert
|
#undef mlx_bitwise_invert
|
||||||
#undef mlx_bitwise_or
|
#undef mlx_bitwise_or
|
||||||
#undef mlx_bitwise_xor
|
#undef mlx_bitwise_xor
|
||||||
|
#undef mlx_blackman
|
||||||
#undef mlx_block_masked_mm
|
#undef mlx_block_masked_mm
|
||||||
#undef mlx_broadcast_arrays
|
#undef mlx_broadcast_arrays
|
||||||
#undef mlx_broadcast_to
|
#undef mlx_broadcast_to
|
||||||
@@ -945,6 +951,8 @@
|
|||||||
#undef mlx_greater
|
#undef mlx_greater
|
||||||
#undef mlx_greater_equal
|
#undef mlx_greater_equal
|
||||||
#undef mlx_hadamard_transform
|
#undef mlx_hadamard_transform
|
||||||
|
#undef mlx_hamming
|
||||||
|
#undef mlx_hanning
|
||||||
#undef mlx_identity
|
#undef mlx_identity
|
||||||
#undef mlx_imag
|
#undef mlx_imag
|
||||||
#undef mlx_inner
|
#undef mlx_inner
|
||||||
@@ -1501,8 +1509,10 @@ extern int (*mlx_distributed_sum_scatter_)(
|
|||||||
extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group);
|
extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group);
|
||||||
extern int (*mlx_distributed_group_size_)(mlx_distributed_group group);
|
extern int (*mlx_distributed_group_size_)(mlx_distributed_group group);
|
||||||
extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key);
|
extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key);
|
||||||
extern bool (*mlx_distributed_is_available_)(void);
|
extern bool (*mlx_distributed_is_available_)(const char* bk /* may be null */);
|
||||||
extern mlx_distributed_group (*mlx_distributed_init_)(bool strict);
|
extern mlx_distributed_group (*mlx_distributed_init_)(
|
||||||
|
bool strict,
|
||||||
|
const char* bk /* may be null */);
|
||||||
extern void (*mlx_set_error_handler_)(
|
extern void (*mlx_set_error_handler_)(
|
||||||
mlx_error_handler_func handler,
|
mlx_error_handler_func handler,
|
||||||
void* data,
|
void* data,
|
||||||
@@ -2099,6 +2109,7 @@ extern int (*mlx_astype_)(
|
|||||||
extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
|
extern int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_and_)(
|
extern int (*mlx_bitwise_and_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -2115,6 +2126,7 @@ extern int (*mlx_bitwise_xor_)(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
const mlx_array b,
|
const mlx_array b,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
|
extern int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_block_masked_mm_)(
|
extern int (*mlx_block_masked_mm_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -2295,6 +2307,7 @@ extern int (*mlx_dequantize_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
mlx_optional_dtype dtype,
|
mlx_optional_dtype dtype,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||||
@@ -2431,6 +2444,8 @@ extern int (*mlx_hadamard_transform_)(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
mlx_optional_float scale,
|
mlx_optional_float scale,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
|
extern int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
extern int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||||
extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_inner_)(
|
extern int (*mlx_inner_)(
|
||||||
@@ -2723,6 +2738,8 @@ extern int (*mlx_qqmm_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale_x /* may be null */,
|
||||||
|
const mlx_array global_scale_w /* may be null */,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
extern int (*mlx_quantize_)(
|
extern int (*mlx_quantize_)(
|
||||||
mlx_vector_array* res,
|
mlx_vector_array* res,
|
||||||
@@ -2730,6 +2747,7 @@ extern int (*mlx_quantize_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
extern int (*mlx_quantized_matmul_)(
|
extern int (*mlx_quantized_matmul_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
@@ -4033,11 +4051,13 @@ static inline int mlx_distributed_group_size(mlx_distributed_group group) {
|
|||||||
static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
|
static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
|
||||||
return mlx_distributed_group_split_(group, color, key);
|
return mlx_distributed_group_split_(group, color, key);
|
||||||
}
|
}
|
||||||
static inline bool mlx_distributed_is_available(void) {
|
static inline bool mlx_distributed_is_available(const char* bk /* may be null */) {
|
||||||
return mlx_distributed_is_available_();
|
return mlx_distributed_is_available_(bk);
|
||||||
}
|
}
|
||||||
static inline mlx_distributed_group mlx_distributed_init(bool strict) {
|
static inline mlx_distributed_group mlx_distributed_init(
|
||||||
return mlx_distributed_init_(strict);
|
bool strict,
|
||||||
|
const char* bk /* may be null */) {
|
||||||
|
return mlx_distributed_init_(strict, bk);
|
||||||
}
|
}
|
||||||
static inline void mlx_set_error_handler(
|
static inline void mlx_set_error_handler(
|
||||||
mlx_error_handler_func handler,
|
mlx_error_handler_func handler,
|
||||||
@@ -4939,6 +4959,9 @@ static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_st
|
|||||||
static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
||||||
return mlx_atleast_3d_(res, a, s);
|
return mlx_atleast_3d_(res, a, s);
|
||||||
}
|
}
|
||||||
|
static inline int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_bartlett_(res, M, s);
|
||||||
|
}
|
||||||
static inline int mlx_bitwise_and(
|
static inline int mlx_bitwise_and(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -4963,6 +4986,9 @@ static inline int mlx_bitwise_xor(
|
|||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_bitwise_xor_(res, a, b, s);
|
return mlx_bitwise_xor_(res, a, b, s);
|
||||||
}
|
}
|
||||||
|
static inline int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_blackman_(res, M, s);
|
||||||
|
}
|
||||||
static inline int mlx_block_masked_mm(
|
static inline int mlx_block_masked_mm(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -5193,9 +5219,10 @@ static inline int mlx_dequantize(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
mlx_optional_dtype dtype,
|
mlx_optional_dtype dtype,
|
||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, dtype, s);
|
return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
|
||||||
}
|
}
|
||||||
static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
||||||
return mlx_diag_(res, a, k, s);
|
return mlx_diag_(res, a, k, s);
|
||||||
@@ -5383,6 +5410,12 @@ static inline int mlx_hadamard_transform(
|
|||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_hadamard_transform_(res, a, scale, s);
|
return mlx_hadamard_transform_(res, a, scale, s);
|
||||||
}
|
}
|
||||||
|
static inline int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_hamming_(res, M, s);
|
||||||
|
}
|
||||||
|
static inline int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_hanning_(res, M, s);
|
||||||
|
}
|
||||||
static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
||||||
return mlx_identity_(res, n, dtype, s);
|
return mlx_identity_(res, n, dtype, s);
|
||||||
}
|
}
|
||||||
@@ -5793,8 +5826,10 @@ static inline int mlx_qqmm(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale_x /* may be null */,
|
||||||
|
const mlx_array global_scale_w /* may be null */,
|
||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, s);
|
return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
|
||||||
}
|
}
|
||||||
static inline int mlx_quantize(
|
static inline int mlx_quantize(
|
||||||
mlx_vector_array* res,
|
mlx_vector_array* res,
|
||||||
@@ -5802,8 +5837,9 @@ static inline int mlx_quantize(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_quantize_(res, w, group_size, bits, mode, s);
|
return mlx_quantize_(res, w, group_size, bits, mode, global_scale, s);
|
||||||
}
|
}
|
||||||
static inline int mlx_quantized_matmul(
|
static inline int mlx_quantized_matmul(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Vendored MLX-C Headers
|
# Vendored MLX-C Headers
|
||||||
|
|
||||||
These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
|
These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
|
||||||
The pinned version is in `MLX_VERSION` at the repo root.
|
The pinned version is in `MLX_C_VERSION` at the repo root.
|
||||||
|
|
||||||
Headers are automatically refreshed when you run a CMake build:
|
Headers are automatically refreshed when you run a CMake build:
|
||||||
|
|
||||||
|
|||||||
@@ -42,12 +42,14 @@ mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
|
|||||||
/**
|
/**
|
||||||
* Check if distributed is available.
|
* Check if distributed is available.
|
||||||
*/
|
*/
|
||||||
bool mlx_distributed_is_available(void);
|
bool mlx_distributed_is_available(const char* bk /* may be null */);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize distributed.
|
* Initialize distributed.
|
||||||
*/
|
*/
|
||||||
mlx_distributed_group mlx_distributed_init(bool strict);
|
mlx_distributed_group mlx_distributed_init(
|
||||||
|
bool strict,
|
||||||
|
const char* bk /* may be null */);
|
||||||
|
|
||||||
/**@}*/
|
/**@}*/
|
||||||
|
|
||||||
|
|||||||
@@ -166,6 +166,7 @@ int mlx_astype(
|
|||||||
int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
|
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
|
||||||
int mlx_bitwise_and(
|
int mlx_bitwise_and(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -182,6 +183,7 @@ int mlx_bitwise_xor(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
const mlx_array b,
|
const mlx_array b,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
|
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
|
||||||
int mlx_block_masked_mm(
|
int mlx_block_masked_mm(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -362,6 +364,7 @@ int mlx_dequantize(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
mlx_optional_dtype dtype,
|
mlx_optional_dtype dtype,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||||
@@ -498,6 +501,8 @@ int mlx_hadamard_transform(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
mlx_optional_float scale,
|
mlx_optional_float scale,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
|
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
|
||||||
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||||
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
int mlx_inner(
|
int mlx_inner(
|
||||||
@@ -790,6 +795,8 @@ int mlx_qqmm(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale_x /* may be null */,
|
||||||
|
const mlx_array global_scale_w /* may be null */,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
int mlx_quantize(
|
int mlx_quantize(
|
||||||
mlx_vector_array* res,
|
mlx_vector_array* res,
|
||||||
@@ -797,6 +804,7 @@ int mlx_quantize(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
int mlx_quantized_matmul(
|
int mlx_quantized_matmul(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
|
|||||||
@@ -7,8 +7,44 @@ package mlx
|
|||||||
// #cgo LDFLAGS: -lstdc++
|
// #cgo LDFLAGS: -lstdc++
|
||||||
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
|
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
|
||||||
// #include "generated.h"
|
// #include "generated.h"
|
||||||
|
// #include <string.h>
|
||||||
|
//
|
||||||
|
// static char _mlx_last_error_msg[1024] = {0};
|
||||||
|
// static int _mlx_last_error_flag = 0;
|
||||||
|
//
|
||||||
|
// static void _mlx_capture_error_handler(const char* msg, void* data) {
|
||||||
|
// (void)data;
|
||||||
|
// strncpy(_mlx_last_error_msg, msg, sizeof(_mlx_last_error_msg) - 1);
|
||||||
|
// _mlx_last_error_msg[sizeof(_mlx_last_error_msg) - 1] = '\0';
|
||||||
|
// _mlx_last_error_flag = 1;
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// static void mlx_install_capture_handler(void) {
|
||||||
|
// if (mlx_set_error_handler_) {
|
||||||
|
// mlx_set_error_handler_(_mlx_capture_error_handler, NULL, NULL);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// static void mlx_clear_last_error(void) {
|
||||||
|
// _mlx_last_error_flag = 0;
|
||||||
|
// _mlx_last_error_msg[0] = '\0';
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// static int mlx_had_last_error(void) {
|
||||||
|
// return _mlx_last_error_flag;
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// static const char* mlx_get_last_error(void) {
|
||||||
|
// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL;
|
||||||
|
// }
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Replace the default exit(-1) error handler with one that captures
|
||||||
|
// the error message so we can surface it in Go.
|
||||||
|
C.mlx_install_capture_handler()
|
||||||
|
}
|
||||||
|
|
||||||
// Version returns the MLX core library version string.
|
// Version returns the MLX core library version string.
|
||||||
func Version() string {
|
func Version() string {
|
||||||
str := C.mlx_string_new()
|
str := C.mlx_string_new()
|
||||||
@@ -31,10 +67,19 @@ func doEval(outputs []*Array, async bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
C.mlx_clear_last_error()
|
||||||
|
var rc C.int
|
||||||
if async {
|
if async {
|
||||||
C.mlx_async_eval(vector)
|
rc = C.mlx_async_eval(vector)
|
||||||
} else {
|
} else {
|
||||||
C.mlx_eval(vector)
|
rc = C.mlx_eval(vector)
|
||||||
|
}
|
||||||
|
if rc != 0 {
|
||||||
|
msg := "mlx eval failed"
|
||||||
|
if C.mlx_had_last_error() != 0 {
|
||||||
|
msg = C.GoString(C.mlx_get_last_error())
|
||||||
|
}
|
||||||
|
panic("mlx: " + msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
|
|||||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||||
res := C.mlx_vector_array_new()
|
res := C.mlx_vector_array_new()
|
||||||
defer C.mlx_vector_array_free(res)
|
defer C.mlx_vector_array_free(res)
|
||||||
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx)
|
var globalScale C.mlx_array
|
||||||
|
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx)
|
||||||
|
|
||||||
vecSize := int(C.mlx_vector_array_size(res))
|
vecSize := int(C.mlx_vector_array_size(res))
|
||||||
w0 := New("QUANTIZE_W")
|
w0 := New("QUANTIZE_W")
|
||||||
@@ -45,7 +46,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := New("DEQUANTIZE")
|
out := New("DEQUANTIZE")
|
||||||
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
|
var globalScale C.mlx_array
|
||||||
|
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx)
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user