mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 09:03:35 -04:00
mlx: add compiled closure support
Wraps MLX's mlx_compile API so Go functions can be traced into fused kernels. Contiguous elementwise chains collapse into a single Metal/CUDA kernel instead of launching one per op. Exposes Compile plus arity helpers (Compile1/2/3) that mirror Python's @mx.compile decorator shape, lazily building the closure on first call so package-level declarations work before the MLX dylib loads.
This commit is contained in:
@@ -27,7 +27,11 @@ var arrays []*Array
|
|||||||
|
|
||||||
func New(name string) *Array {
|
func New(name string) *Array {
|
||||||
t := &Array{name: name}
|
t := &Array{name: name}
|
||||||
arrays = append(arrays, t)
|
if tracing {
|
||||||
|
traceScratch = append(traceScratch, t)
|
||||||
|
} else {
|
||||||
|
arrays = append(arrays, t)
|
||||||
|
}
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
192
x/mlxrunner/mlx/compile.go
Normal file
192
x/mlxrunner/mlx/compile.go
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
package mlx
|
||||||
|
|
||||||
|
// #include <stdlib.h>
|
||||||
|
// #include "generated.h"
|
||||||
|
//
|
||||||
|
// extern int closureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
|
||||||
|
// extern void closureDestructor(void* payload);
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"runtime/cgo"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CompileFunc is the signature of a function that can be compiled.
|
||||||
|
type CompileFunc func(inputs ...*Array) []*Array
|
||||||
|
|
||||||
|
// CompileOption configures Compile behavior.
|
||||||
|
type CompileOption func(*compileConfig)
|
||||||
|
|
||||||
|
type compileConfig struct {
|
||||||
|
shapeless bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shapeless traces the function once against symbolic shapes so the compiled
|
||||||
|
// graph accepts any input shape afterwards. Without this option, MLX re-traces
|
||||||
|
// on each new (shape, dtype) combination and caches each specialization.
|
||||||
|
func Shapeless() CompileOption {
|
||||||
|
return func(c *compileConfig) { c.shapeless = true }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile returns a compiled version of fn. When called during another
|
||||||
|
// compile's trace, fn is inlined directly so outer compiles can fuse through
|
||||||
|
// inner ones.
|
||||||
|
//
|
||||||
|
// Compiled functions must not have side effects outside of the function. Do
|
||||||
|
// not access data other than the arguments passed in (either Go data or MLX
|
||||||
|
// arrays) unless it is a constant.
|
||||||
|
func Compile(name string, fn CompileFunc, opts ...CompileOption) CompileFunc {
|
||||||
|
var cfg compileConfig
|
||||||
|
for _, o := range opts {
|
||||||
|
o(&cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
var closure C.mlx_closure
|
||||||
|
var once sync.Once
|
||||||
|
|
||||||
|
return func(inputs ...*Array) []*Array {
|
||||||
|
if tracing {
|
||||||
|
return fn(inputs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
once.Do(func() {
|
||||||
|
payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
|
||||||
|
*payload = cgo.NewHandle(fn)
|
||||||
|
src := C.mlx_closure_new_func_payload(
|
||||||
|
(*[0]byte)(C.closureCallback),
|
||||||
|
unsafe.Pointer(payload),
|
||||||
|
(*[0]byte)(C.closureDestructor),
|
||||||
|
)
|
||||||
|
defer C.mlx_closure_free(src)
|
||||||
|
|
||||||
|
closure = C.mlx_closure_new()
|
||||||
|
mlxCheck(name+": compile failed", func() C.int {
|
||||||
|
return C.mlx_compile(&closure, src, C.bool(cfg.shapeless))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
inVec := C.mlx_vector_array_new()
|
||||||
|
defer C.mlx_vector_array_free(inVec)
|
||||||
|
for _, in := range inputs {
|
||||||
|
C.mlx_vector_array_append_value(inVec, in.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
outVec := C.mlx_vector_array_new()
|
||||||
|
defer C.mlx_vector_array_free(outVec)
|
||||||
|
mlxCheck(name+": closure apply failed", func() C.int {
|
||||||
|
return C.mlx_closure_apply(&outVec, closure, inVec)
|
||||||
|
})
|
||||||
|
|
||||||
|
n := int(C.mlx_vector_array_size(outVec))
|
||||||
|
outputs := make([]*Array, n)
|
||||||
|
for i := range n {
|
||||||
|
outputs[i] = New(name)
|
||||||
|
C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile1 compiles a unary function. See Compile.
|
||||||
|
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
|
||||||
|
cf := Compile(name, func(in ...*Array) []*Array {
|
||||||
|
return []*Array{fn(in[0])}
|
||||||
|
}, opts...)
|
||||||
|
return func(a *Array) *Array {
|
||||||
|
return cf(a)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile2 compiles a binary function. See Compile.
|
||||||
|
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
|
||||||
|
cf := Compile(name, func(in ...*Array) []*Array {
|
||||||
|
return []*Array{fn(in[0], in[1])}
|
||||||
|
}, opts...)
|
||||||
|
return func(a, b *Array) *Array {
|
||||||
|
return cf(a, b)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile3 compiles a ternary function. See Compile.
|
||||||
|
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
|
||||||
|
cf := Compile(name, func(in ...*Array) []*Array {
|
||||||
|
return []*Array{fn(in[0], in[1], in[2])}
|
||||||
|
}, opts...)
|
||||||
|
return func(a, b, c *Array) *Array {
|
||||||
|
return cf(a, b, c)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tracing is true while a compile callback is running. Since MLX is
|
||||||
|
// single-threaded at this level a plain Go bool suffices.
|
||||||
|
var tracing bool
|
||||||
|
|
||||||
|
// traceScratch collects arrays created during a compile trace so they can be
|
||||||
|
// freed as a group when the callback returns.
|
||||||
|
var traceScratch []*Array
|
||||||
|
|
||||||
|
//export closureCallback
|
||||||
|
func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
slog.Error("mlx closure callback panicked", "panic", r)
|
||||||
|
rc = 1
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
handle := *(*cgo.Handle)(payload)
|
||||||
|
fn := handle.Value().(CompileFunc)
|
||||||
|
|
||||||
|
// When tracing, we track all of the intermediates that are created and free them separately at the end of
|
||||||
|
// the process. This will give the effect of a single op - inputs are owned by the original caller (via
|
||||||
|
// the MLX layer) and outputs are transferred back to MLX to create a new Go side tensor.
|
||||||
|
if tracing {
|
||||||
|
panic("mlx: nested compile trace")
|
||||||
|
}
|
||||||
|
tracing = true
|
||||||
|
traceScratch = nil
|
||||||
|
defer func() {
|
||||||
|
for _, a := range traceScratch {
|
||||||
|
if a.pinned > 0 {
|
||||||
|
panic("mlx: traced array was pinned during compilation")
|
||||||
|
}
|
||||||
|
if a.Valid() {
|
||||||
|
C.mlx_array_free(a.ctx)
|
||||||
|
a.ctx.ctx = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing = false
|
||||||
|
traceScratch = nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
n := int(C.mlx_vector_array_size(input))
|
||||||
|
inputs := make([]*Array, n)
|
||||||
|
for i := range n {
|
||||||
|
a := New("")
|
||||||
|
C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
|
||||||
|
inputs[i] = a
|
||||||
|
}
|
||||||
|
|
||||||
|
outputs := fn(inputs...)
|
||||||
|
|
||||||
|
var arrPtr *C.mlx_array
|
||||||
|
if len(outputs) > 0 {
|
||||||
|
handles := make([]C.mlx_array, len(outputs))
|
||||||
|
for i, out := range outputs {
|
||||||
|
handles[i] = out.ctx
|
||||||
|
}
|
||||||
|
arrPtr = &handles[0]
|
||||||
|
}
|
||||||
|
C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
//export closureDestructor
|
||||||
|
func closureDestructor(payload unsafe.Pointer) {
|
||||||
|
handle := *(*cgo.Handle)(payload)
|
||||||
|
handle.Delete()
|
||||||
|
C.free(payload)
|
||||||
|
}
|
||||||
147
x/mlxrunner/mlx/compile_test.go
Normal file
147
x/mlxrunner/mlx/compile_test.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package mlx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCompileFusion(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
// Compile fuses the ops inside a function body into a single kernel,
|
||||||
|
// eliminating intermediate buffers. Use a diamond-shaped graph where
|
||||||
|
// two branches must be materialized simultaneously without fusion,
|
||||||
|
// then compare peak memory against the compiled version which fuses
|
||||||
|
// everything into one kernel with no intermediates.
|
||||||
|
const n = 1024 * 1024 // 4MB per float32 array
|
||||||
|
data := make([]float32, n)
|
||||||
|
for i := range data {
|
||||||
|
data[i] = float32(i + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Diamond: both a*b and a+b must be live for the final multiply.
|
||||||
|
// Without fusion: peak includes both intermediates (~8MB extra).
|
||||||
|
// With fusion: single kernel, no intermediates.
|
||||||
|
body := func(a, b *Array) *Array {
|
||||||
|
return a.Multiply(b).Multiply(a.Add(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
a := FromValues(data, n)
|
||||||
|
b := FromValues(data, n)
|
||||||
|
Pin(a, b)
|
||||||
|
defer Unpin(a, b)
|
||||||
|
|
||||||
|
// Compiled: ops fused into a single kernel.
|
||||||
|
EnableCompile()
|
||||||
|
fn := Compile2("diamond", body, Shapeless())
|
||||||
|
warm := fn(a, b)
|
||||||
|
Eval(warm)
|
||||||
|
Sweep()
|
||||||
|
ClearCache()
|
||||||
|
ResetPeakMemory()
|
||||||
|
y := fn(a, b)
|
||||||
|
Eval(y)
|
||||||
|
compiledPeak := PeakMemory()
|
||||||
|
Sweep()
|
||||||
|
|
||||||
|
// Uncompiled: ops evaluated individually, intermediates materialized.
|
||||||
|
ClearCache()
|
||||||
|
ResetPeakMemory()
|
||||||
|
z := body(a, b)
|
||||||
|
Eval(z)
|
||||||
|
uncompiledPeak := PeakMemory()
|
||||||
|
Sweep()
|
||||||
|
|
||||||
|
if compiledPeak == 0 && uncompiledPeak == 0 {
|
||||||
|
t.Skip("peak memory tracking not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||||
|
|
||||||
|
if compiledPeak >= uncompiledPeak {
|
||||||
|
t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileNested(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
// A compiled function that calls another compiled function should
|
||||||
|
// produce correct results. The inner function inlines via isTracing()
|
||||||
|
// during the outer's trace.
|
||||||
|
inner := Compile1("silu", func(a *Array) *Array {
|
||||||
|
return a.Multiply(a.Sigmoid())
|
||||||
|
}, Shapeless())
|
||||||
|
|
||||||
|
outer := Compile2("swiglu", func(gate, up *Array) *Array {
|
||||||
|
return inner(gate).Multiply(up)
|
||||||
|
}, Shapeless())
|
||||||
|
|
||||||
|
gate := FromValues([]float32{0, 1, 2}, 3)
|
||||||
|
up := FromValues([]float32{1, 1, 1}, 3)
|
||||||
|
Pin(gate, up)
|
||||||
|
defer Unpin(gate, up)
|
||||||
|
|
||||||
|
y := outer(gate, up)
|
||||||
|
Eval(y)
|
||||||
|
|
||||||
|
// silu(x) = x * sigmoid(x); for x=0 → 0, x=1 → ~0.7311, x=2 → ~1.7616
|
||||||
|
got := y.Floats()
|
||||||
|
want := []float32{0, 0.7310586, 1.7615942}
|
||||||
|
for i, v := range got {
|
||||||
|
if v-want[i] > 1e-4 || want[i]-v > 1e-4 {
|
||||||
|
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileCallbackPanicRecovers(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
boom := Compile1("boom", func(a *Array) *Array {
|
||||||
|
panic("intentional test panic")
|
||||||
|
})
|
||||||
|
|
||||||
|
x := FromValues([]float32{1}, 1)
|
||||||
|
Pin(x)
|
||||||
|
defer Unpin(x)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if r == nil {
|
||||||
|
t.Fatal("expected panic from Call, got none")
|
||||||
|
}
|
||||||
|
if _, ok := r.(string); !ok {
|
||||||
|
t.Fatalf("expected string panic, got %T: %v", r, r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
boom(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileNoTrackingGrowth(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
// Repeated invocations of a compiled kernel should not grow the
|
||||||
|
// tracked-arrays list — the callback's traceScratch collects
|
||||||
|
// intermediates during tracing and frees them when the callback returns.
|
||||||
|
fn := Compile2("mul_add", func(a, b *Array) *Array {
|
||||||
|
return a.Multiply(b).Add(b)
|
||||||
|
})
|
||||||
|
|
||||||
|
a := FromValues([]float32{1, 2}, 2)
|
||||||
|
b := FromValues([]float32{3, 4}, 2)
|
||||||
|
Pin(a, b)
|
||||||
|
defer Unpin(a, b)
|
||||||
|
|
||||||
|
Sweep()
|
||||||
|
before := len(arrays)
|
||||||
|
|
||||||
|
for range 100 {
|
||||||
|
_ = fn(a, b)
|
||||||
|
Sweep()
|
||||||
|
}
|
||||||
|
|
||||||
|
after := len(arrays)
|
||||||
|
if after > before+2 {
|
||||||
|
t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", before, after)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,8 +9,8 @@ package mlx
|
|||||||
// #include "generated.h"
|
// #include "generated.h"
|
||||||
// #include <string.h>
|
// #include <string.h>
|
||||||
//
|
//
|
||||||
// static char _mlx_last_error_msg[1024] = {0};
|
// static __thread char _mlx_last_error_msg[1024] = {0};
|
||||||
// static int _mlx_last_error_flag = 0;
|
// static __thread int _mlx_last_error_flag = 0;
|
||||||
//
|
//
|
||||||
// static void _mlx_capture_error_handler(const char* msg, void* data) {
|
// static void _mlx_capture_error_handler(const char* msg, void* data) {
|
||||||
// (void)data;
|
// (void)data;
|
||||||
@@ -30,15 +30,13 @@ package mlx
|
|||||||
// _mlx_last_error_msg[0] = '\0';
|
// _mlx_last_error_msg[0] = '\0';
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
// static int mlx_had_last_error(void) {
|
|
||||||
// return _mlx_last_error_flag;
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// static const char* mlx_get_last_error(void) {
|
// static const char* mlx_get_last_error(void) {
|
||||||
// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL;
|
// return _mlx_last_error_flag ? _mlx_last_error_msg : "";
|
||||||
// }
|
// }
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
|
import "runtime"
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// Replace the default exit(-1) error handler with one that captures
|
// Replace the default exit(-1) error handler with one that captures
|
||||||
// the error message so we can surface it in Go.
|
// the error message so we can surface it in Go.
|
||||||
@@ -53,6 +51,24 @@ func Version() string {
|
|||||||
return C.GoString(C.mlx_string_data(str))
|
return C.GoString(C.mlx_string_data(str))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mlxCheck locks the goroutine to its OS thread, clears the captured error
|
||||||
|
// state, calls fn, and panics with the captured message if fn returns non-zero.
|
||||||
|
// The thread lock ensures the thread-local error state is read from the same
|
||||||
|
// thread that executed the call.
|
||||||
|
func mlxCheck(fallback string, fn func() C.int) {
|
||||||
|
runtime.LockOSThread()
|
||||||
|
defer runtime.UnlockOSThread()
|
||||||
|
|
||||||
|
C.mlx_clear_last_error()
|
||||||
|
if fn() != 0 {
|
||||||
|
msg := C.GoString(C.mlx_get_last_error())
|
||||||
|
if msg == "" {
|
||||||
|
msg = fallback
|
||||||
|
}
|
||||||
|
panic("mlx: " + msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func doEval(outputs []*Array, async bool) {
|
func doEval(outputs []*Array, async bool) {
|
||||||
if len(outputs) == 0 {
|
if len(outputs) == 0 {
|
||||||
return
|
return
|
||||||
@@ -67,20 +83,12 @@ func doEval(outputs []*Array, async bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
C.mlx_clear_last_error()
|
mlxCheck("eval failed", func() C.int {
|
||||||
var rc C.int
|
if async {
|
||||||
if async {
|
return C.mlx_async_eval(vector)
|
||||||
rc = C.mlx_async_eval(vector)
|
|
||||||
} else {
|
|
||||||
rc = C.mlx_eval(vector)
|
|
||||||
}
|
|
||||||
if rc != 0 {
|
|
||||||
msg := "mlx eval failed"
|
|
||||||
if C.mlx_had_last_error() != 0 {
|
|
||||||
msg = C.GoString(C.mlx_get_last_error())
|
|
||||||
}
|
}
|
||||||
panic("mlx: " + msg)
|
return C.mlx_eval(vector)
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func AsyncEval(outputs ...*Array) {
|
func AsyncEval(outputs ...*Array) {
|
||||||
|
|||||||
@@ -23,15 +23,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
return errors.New("model not loaded")
|
return errors.New("model not loaded")
|
||||||
}
|
}
|
||||||
|
|
||||||
enableCompile := true
|
|
||||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
|
||||||
enableCompile = modelCompile.EnableCompile()
|
|
||||||
}
|
|
||||||
if enableCompile {
|
|
||||||
mlx.EnableCompile()
|
|
||||||
} else {
|
|
||||||
mlx.DisableCompile()
|
|
||||||
}
|
|
||||||
mlx.ResetPeakMemory()
|
mlx.ResetPeakMemory()
|
||||||
ctx := request.Ctx
|
ctx := request.Ctx
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -79,6 +79,8 @@ func (r *Runner) Load(modelName string) error {
|
|||||||
r.Model = m
|
r.Model = m
|
||||||
r.Tokenizer = m.Tokenizer()
|
r.Tokenizer = m.Tokenizer()
|
||||||
r.contextLength = m.MaxContextLength()
|
r.contextLength = m.MaxContextLength()
|
||||||
|
|
||||||
|
mlx.EnableCompile()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user