mlxrunner: tokenize prompts in request handler goroutines

Move tokenization out of the single GPU processing goroutine and
into each request's HTTP handler goroutine. This allows the next
request's prompt to be tokenized on the CPU while the current
request is executing on the GPU.
This commit is contained in:
Jesse Gross
2026-04-03 16:25:33 -07:00
parent 845b7b29c4
commit 22e8abe666
3 changed files with 40 additions and 28 deletions

View File

@@ -6,10 +6,8 @@ import (
"errors"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
@@ -18,13 +16,37 @@ func prefillChunkSize() int {
return 2 << 10
}
func (r *Runner) TextGenerationPipeline(request Request) error {
// Prepare tokenizes the prompt and validates it against the model's
// context length. It is safe to call from any goroutine. On success it
// populates request.Tokens and adjusts request.Options.MaxTokens.
func (r *Runner) Prepare(request *Request) error {
if r.Model == nil {
return errors.New("model not loaded")
}
tokens := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
if len(tokens) == 0 {
return errors.New("empty prompt")
}
if len(tokens) >= r.contextLength {
return fmt.Errorf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(tokens), r.contextLength)
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(tokens)
if request.Options.MaxTokens <= 0 {
request.Options.MaxTokens = maxGenerate
} else {
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
}
request.Tokens = tokens
return nil
}
func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) error {
mlx.ResetPeakMemory()
ctx := request.Ctx
var (
sample, logprobs *mlx.Array
nextSample, nextLogprobs *mlx.Array
@@ -46,26 +68,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
}()
inputs := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
if len(inputs) == 0 {
return errors.New("empty prompt")
}
if len(inputs) >= r.contextLength {
return api.StatusError{
StatusCode: http.StatusBadRequest,
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
}
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(inputs)
if request.Options.MaxTokens <= 0 {
request.Options.MaxTokens = maxGenerate
} else {
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
}
inputs := request.Tokens
request.Sampler.ResetHistory(inputs)
session := r.cache.begin(r.Model, inputs)

View File

@@ -18,13 +18,17 @@ import (
"github.com/ollama/ollama/x/tokenizer"
)
// Request is a short-lived struct that carries a completion request through
// a channel from the HTTP handler to the runner goroutine. The ctx field
// must travel with the request so that cancellation propagates across the
// channel boundary.
type Request struct {
TextCompletionsRequest
Responses chan CompletionResponse
Pipeline func(Request) error
Ctx context.Context
Pipeline func(context.Context, Request) error
Ctx context.Context //nolint:containedctx
Tokens []int32
Sampler *sample.Sampler
}
@@ -147,7 +151,7 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
case <-ctx.Done():
return nil
case request := <-r.Requests:
if err := request.Pipeline(request); err != nil {
if err := request.Pipeline(request.Ctx, request); err != nil {
slog.Info("Request terminated", "error", err)
var statusErr api.StatusError
if !errors.As(err, &statusErr) {

View File

@@ -105,6 +105,11 @@ func Execute(args []string) error {
request.Options.PresencePenalty,
)
if err := runner.Prepare(&request); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var cancel context.CancelFunc
request.Ctx, cancel = context.WithCancel(r.Context())
defer cancel()