Compare commits

...

35 Commits

Author SHA1 Message Date
Parth Sareen
123b300af6 docs: update hermes (#15655) 2026-04-17 14:20:59 -07:00
Parth Sareen
57653b8e42 cmd/launch: show WSL guidance on Windows instead of handing off (#15637) 2026-04-16 17:18:04 -07:00
Parth Sareen
a50ce61c54 launch: skip unchanged managed-single rewrite (#15633) 2026-04-16 16:20:42 -07:00
Daniel Hiltgen
2bb7ea00d2 create: avoid gc race with create (#15628)
If you have a long running create, and start another ollama server with the
same model dir, the GC algorithm deletes the pending blobs and breaks the
create.  This adds a 1h grace period to avoid deleting in-flight creation
operations.
2026-04-16 13:29:16 -07:00
Daniel Hiltgen
55fa80d07a mlx: additional gemma4 cache fixes (#15607)
Harden additional corner cases
2026-04-16 13:07:19 -07:00
Daniel Hiltgen
b9cb535407 mlx: fix gemma4 cache to use logical view (#15617) 2026-04-16 11:54:30 -07:00
Daniel Hiltgen
031baef094 mlx: fix imagegen lookup (#15588)
* mlx: fix imagegen lookup

Fixes #15533 - imagegen had fallen out of sync with the new layout
for multiple mlx libraries on Metal.

* review comments
2026-04-16 10:39:00 -07:00
Mike Wallio
7d271e6dc9 cmd/launch: add Copilot CLI integration (#15583)
---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-04-15 17:22:53 -07:00
Devon Rifkin
c88dae2d6b Merge pull request #15612 from ollama/drifkin/gemma4-split-templates
gemma4: render differently based on model size
2026-04-15 17:15:35 -07:00
Devon Rifkin
9e3618d663 make empty block conditional 2026-04-15 15:35:25 -07:00
Daniel Hiltgen
5d920cc6bc Keep Gemma4 router projection in source precision (#15613) 2026-04-15 15:04:23 -07:00
Devon Rifkin
e585ecd11f gemma4: render differently based on model size
Following up on #15560, this change now has e2b/e4b render differently
from 26b/31b.

For backwards compatibility, we take the existing renderer name `gemma4`
and make it do dynamic resolution based on the model name/size, but the
intended use is for the models to be republished with the renderer
variant specified explicitly: `gemma4-small` or `gemma4-large`.
2026-04-15 14:37:16 -07:00
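The dynamic resolution described above can be sketched roughly like this. The function name and matching strategy are illustrative assumptions, not Ollama's renderer registry; only the renderer names (`gemma4`, `gemma4-small`, `gemma4-large`) and size markers (e2b/e4b vs 26b/31b) come from the commit message.

```go
package main

import (
	"fmt"
	"strings"
)

// resolveGemma4Renderer picks a renderer variant. Explicit variants
// (as in republished models) pass through unchanged; the legacy
// "gemma4" name falls back to matching on the model name/size.
func resolveGemma4Renderer(renderer, modelName string) string {
	if renderer != "gemma4" {
		return renderer // variant specified explicitly
	}
	name := strings.ToLower(modelName)
	for _, small := range []string{"e2b", "e4b"} {
		if strings.Contains(name, small) {
			return "gemma4-small"
		}
	}
	return "gemma4-large"
}

func main() {
	fmt.Println(resolveGemma4Renderer("gemma4", "gemma4:e4b"))
	fmt.Println(resolveGemma4Renderer("gemma4", "gemma4:26b"))
	fmt.Println(resolveGemma4Renderer("gemma4-small", "gemma4:26b"))
}
```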
Eva H
cdddea0592 launch: always list cloud recommendations first (#15593) 2026-04-15 13:17:35 -07:00
Parth Sareen
43f90def04 launch: add hermes (#15569) 2026-04-15 12:00:23 -07:00
Daniel Hiltgen
06ae6367bd mlx: fix RotatingKVCache.concat() dropping context on mid-rotation (#15591)
After the rotating buffer has wrapped (c.offset > c.maxSize) a subsequent
L>1 Update() went through a slice-to-[0, c.idx) path that discarded all
slots in [c.idx, Dim), losing the older-but-still-in-window tokens the
first Q of the new batch needs for its sliding-window attention.

Linearize the circular buffer to logical order in that wrapped case so
the existing trim + concat preserves the last (maxSize - 1) old tokens.
When the buffer has not yet wrapped (c.offset <= c.maxSize), slots
[c.idx, Dim) are grow padding or stale post-rewind data, so keep
dropping them.
2026-04-14 18:29:06 -07:00
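The linearization step described above can be sketched on a plain slice. This is a toy model of the fix, not the MLX cache code: `linearize`, its token strings, and the `wrapped` flag are illustrative stand-ins for `c.offset > c.maxSize` and the tensor views involved.

```go
package main

import "fmt"

// linearize returns the logical-order view of a ring buffer. idx is the
// next write slot. When the buffer has wrapped, the oldest live entries
// sit at [idx, len) and the newest at [0, idx), so both halves are kept
// in logical order. When it has not wrapped, slots at [idx, len) are
// grow padding or stale post-rewind data and are still dropped.
func linearize(buf []string, idx int, wrapped bool) []string {
	if !wrapped {
		return append([]string(nil), buf[:idx]...)
	}
	out := append([]string(nil), buf[idx:]...)
	return append(out, buf[:idx]...)
}

func main() {
	// Physical order after wrapping: t4 t5 | t1 t2 t3
	fmt.Println(linearize([]string{"t4", "t5", "t1", "t2", "t3"}, 2, true))
	// Not yet wrapped: trailing slots are padding
	fmt.Println(linearize([]string{"t1", "t2", "", ""}, 2, false))
}
```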
Daniel Hiltgen
48ad7085c4 mlx: Improve gemma4 performance with fused operations (#15587)
* mlx: Improve gemma4 performance with fused operations

* review comments
2026-04-14 18:04:04 -07:00
Jesse Gross
e1e3cec8d0 models: fuse MLP activation functions via mlx_compile
Converts SiLU/GELUApprox to compiled kernels and adds SwiGLU,
matching upstream mlx/mlx_lm's activations pattern. Routes llama,
qwen3, qwen3_5 (dense + MoE), and glm4_moe_lite MLP paths through
mlx.SwiGLU so each MLP invocation runs as one fused Metal/CUDA
kernel rather than a chain of per-op launches.
2026-04-14 16:38:32 -07:00
Jesse Gross
d3e67e305c mlx: add compiled closure support
Wraps MLX's mlx_compile API so Go functions can be traced into fused
kernels. Contiguous elementwise chains collapse into a single
Metal/CUDA kernel instead of launching one per op.

Exposes Compile plus arity helpers (Compile1/2/3) that mirror Python's
@mx.compile decorator shape, lazily building the closure on first call
so package-level declarations work before the MLX dylib loads.
2026-04-14 16:38:32 -07:00
Eva H
698e04a14b launch: OpenCode inline config (#15586) 2026-04-14 15:08:42 -07:00
Eva H
1d9537bc33 launch/openclaw: fix --yes flag behaviour to skip channels configuration (#15589) 2026-04-14 13:57:35 -07:00
Eva H
120424d832 Revert "launch/opencode: use inline config (#15462)" (#15568) 2026-04-13 18:40:17 -07:00
Eva H
5818001610 launch: skip unchanged integration rewrite configuration (#15491) 2026-04-13 17:18:56 -07:00
Daniel Hiltgen
2cba7756c5 Gemma4 on MLX (#15244)
* gemma4: implement Gemma 4 model for MLX (text-only runtime)

* gemma4: two MoE + SWA prefill perf fixes

Two performance optimizations in the gemma4 forward pass

1. Memoize the sliding-window prefill mask across layers.
2. Softmax only over the selected experts in Router.Forward.

* review comments
2026-04-13 16:36:51 -07:00
Devon Rifkin
bf2a421727 gemma4: restore e2b-style nothink prompt (#15560)
Gemma 4 prompts differ when thinking is disabled for different sized
models: 26b/31b emit an empty thought block, while e2b/e4b do not.

Before #15490, our shared Gemma 4 renderer effectively matched the
e2b behavior. #15490 changed it to always emit the empty thought block,
which regressed e2b/e4b nothink behavior and led to #15536 (and possibly

This change restores the previous shared behavior by removing the empty
trailing thought block. It also renames the checked-in upstream chat
templates so the e2b and 31b fixtures are tracked separately.

A follow-up will split Gemma 4 rendering by model size.

Fixes: #15536
2026-04-13 14:26:15 -07:00
Eva H
f3cf6b75fb launch/opencode: use inline config (#15462) 2026-04-13 13:41:31 -07:00
Devon Rifkin
5dfac387a6 Revert "gemma4: fix nothink case renderer (#15553)" (#15556)
This reverts commit 4d75f5da03.
2026-04-13 13:12:18 -07:00
Daniel Hiltgen
a99e5d9c22 mac: prevent generate on cross-compiles (#15120)
For some versions of Xcode, cmake builds are failing due to header problems in
cross-compiling during the generate phase.  Since generate is producing arch
independent generated output, we can skip this during cross-compiling.
2026-04-13 13:04:58 -07:00
Daniel Hiltgen
0abf3aca36 cgo: suppress deprecated warning to quiet down go build (#15438) 2026-04-13 13:04:11 -07:00
Devon Rifkin
ee0266462a Revert "gemma4: add nothink renderer tests (#15554)" (#15555)
This reverts commit 1b70bb8a10.
2026-04-13 13:00:59 -07:00
Daniel Hiltgen
c88fb286ec mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA (#14913)
* mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA

Add Conv2d, flexible Pad (with axes/mode), PadConstant, Maximum,
Minimum, Softplus, ReLU, GLU, Clamp, Sin, Cos, Clip,
ScaledDotProductAttentionMasked, and RoPEWithFreqs. Refactor
RoPEWithBase to delegate to RoPEWithFreqs.

* review comments

* mlx: fix ScaledDotProductAttentionMasked to consult the mask argument
2026-04-13 11:43:24 -07:00
Daniel Hiltgen
d3da29cbfc mlx: mixed-precision quant and capability detection improvements (#15409)
Improve the MLX model creation pipeline with several model-agnostic changes:

- Rewrite supportsVision to use vision_config instead of architecture name
- Add supportsAudio for audio encoder detection
- Add alignment checking (isAligned) for quantization group sizes
- Support per-projection mixed quantization in MoE expert packing
- Record per-tensor quant metadata in safetensors blobs
- Parse per-tensor quant metadata at model load time
- Validate quantize output is non-empty before storing
- Fix pin/unpin cleanup in expert group quantization
- Promote v_proj/k_proj/down_proj to INT8 for INT4 base quant
- Add MetalIsAvailable() utility
- Skip audio encoder tensors from quantization
2026-04-13 11:43:07 -07:00
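Two of the rules in the list above can be sketched concretely. This is a hedged illustration, assuming the promotion applies by substring match on tensor names and that alignment means the last dimension divides the quant group size; the real pipeline's naming and checks may differ.

```go
package main

import (
	"fmt"
	"strings"
)

// bitsFor returns the quant width for a projection: under an INT4 base
// quant, the sensitive projections named in the commit message are
// promoted to INT8; everything else keeps the base width.
func bitsFor(tensorName string, baseBits int) int {
	if baseBits == 4 {
		for _, sensitive := range []string{"v_proj", "k_proj", "down_proj"} {
			if strings.Contains(tensorName, sensitive) {
				return 8
			}
		}
	}
	return baseBits
}

// isAligned reports whether a tensor's last dimension divides evenly
// into the quantization group size.
func isAligned(lastDim, groupSize int) bool {
	return groupSize > 0 && lastDim%groupSize == 0
}

func main() {
	fmt.Println(bitsFor("model.layers.0.self_attn.v_proj.weight", 4))
	fmt.Println(bitsFor("model.layers.0.self_attn.q_proj.weight", 4))
	fmt.Println(isAligned(4096, 64), isAligned(100, 64))
}
```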
Devon Rifkin
1b70bb8a10 gemma4: add nothink renderer tests (#15554)
Meant to include in #15553
2026-04-13 11:38:19 -07:00
Daniel Hiltgen
ec29ce4ce3 gemma4: fix compiler error on metal (#15550)
On some systems, the metal runtime compiler is failing due to an
uninitialized variable from #15378.

Fixes #15548
2026-04-13 11:32:00 -07:00
Devon Rifkin
4d75f5da03 gemma4: fix nothink case renderer (#15553)
Regressed in #15490

Fixes: #15536
2026-04-13 11:23:19 -07:00
saman-amd
798fd09bfe Update to ROCm 7.2.1 (#15483)
Co-authored-by: Samiii777 <58442200+Samiii777@users.noreply.github.com>
2026-04-12 12:11:58 -07:00
76 changed files with 8922 additions and 1340 deletions

View File

@@ -51,7 +51,7 @@ jobs:
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm
-container: rocm/dev-ubuntu-22.04:7.2
+container: rocm/dev-ubuntu-22.04:7.2.1
extra-packages: rocm-libs
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
- preset: Vulkan

View File

@@ -2,7 +2,7 @@
ARG FLAVOR=${TARGETARCH}
-ARG ROCMVERSION=7.2
+ARG ROCMVERSION=7.2.1
ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2

View File

@@ -55,7 +55,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
ollama
```
-You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `claude`, `codex`, `openclaw` and more.
+You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `Claude Code`, `OpenClaw`, `OpenCode`, `Codex`, `Copilot`, and more.
### Coding
@@ -65,7 +65,7 @@ To launch a specific integration:
ollama launch claude
```
-Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
+Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Copilot CLI](https://docs.ollama.com/integrations/copilot-cli), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
### AI assistant

View File

@@ -58,6 +58,9 @@ func TestLaunchCmd(t *testing.T) {
if cmd.Long == "" {
t.Error("Long description should not be empty")
}
+		if !strings.Contains(cmd.Long, "hermes") {
+			t.Error("Long description should mention hermes")
+		}
})
t.Run("flags exist", func(t *testing.T) {

cmd/launch/copilot.go (new file, 76 lines)
View File

@@ -0,0 +1,76 @@
package launch
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/envconfig"
)
// Copilot implements Runner for GitHub Copilot CLI integration.
type Copilot struct{}
func (c *Copilot) String() string { return "Copilot CLI" }
func (c *Copilot) args(model string, extra []string) []string {
var args []string
if model != "" {
args = append(args, "--model", model)
}
args = append(args, extra...)
return args
}
func (c *Copilot) findPath() (string, error) {
if p, err := exec.LookPath("copilot"); err == nil {
return p, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
name := "copilot"
if runtime.GOOS == "windows" {
name = "copilot.exe"
}
fallback := filepath.Join(home, ".local", "bin", name)
if _, err := os.Stat(fallback); err != nil {
return "", err
}
return fallback, nil
}
func (c *Copilot) Run(model string, args []string) error {
copilotPath, err := c.findPath()
if err != nil {
return fmt.Errorf("copilot is not installed, install from https://docs.github.com/en/copilot/how-tos/set-up/install-copilot-cli")
}
cmd := exec.Command(copilotPath, c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(), c.envVars(model)...)
return cmd.Run()
}
// envVars returns the environment variables that configure Copilot CLI
// to use Ollama as its model provider.
func (c *Copilot) envVars(model string) []string {
env := []string{
"COPILOT_PROVIDER_BASE_URL=" + envconfig.Host().String() + "/v1",
"COPILOT_PROVIDER_API_KEY=",
"COPILOT_PROVIDER_WIRE_API=responses",
}
if model != "" {
env = append(env, "COPILOT_MODEL="+model)
}
return env
}

cmd/launch/copilot_test.go (new file, 161 lines)
View File

@@ -0,0 +1,161 @@
package launch
import (
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
func TestCopilotIntegration(t *testing.T) {
c := &Copilot{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Copilot CLI" {
t.Errorf("String() = %q, want %q", got, "Copilot CLI")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
}
func TestCopilotFindPath(t *testing.T) {
c := &Copilot{}
t.Run("finds copilot in PATH", func(t *testing.T) {
tmpDir := t.TempDir()
name := "copilot"
if runtime.GOOS == "windows" {
name = "copilot.exe"
}
fakeBin := filepath.Join(tmpDir, name)
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
t.Setenv("PATH", tmpDir)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fakeBin {
t.Errorf("findPath() = %q, want %q", got, fakeBin)
}
})
t.Run("returns error when not in PATH", func(t *testing.T) {
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
_, err := c.findPath()
if err == nil {
t.Fatal("expected error, got nil")
}
})
t.Run("falls back to ~/.local/bin/copilot", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
name := "copilot"
if runtime.GOOS == "windows" {
name = "copilot.exe"
}
fallback := filepath.Join(tmpDir, ".local", "bin", name)
os.MkdirAll(filepath.Dir(fallback), 0o755)
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fallback {
t.Errorf("findPath() = %q, want %q", got, fallback)
}
})
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
_, err := c.findPath()
if err == nil {
t.Fatal("expected error, got nil")
}
})
}
func TestCopilotArgs(t *testing.T) {
c := &Copilot{}
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
{"empty model", "", nil, nil},
{"with model and extra", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model, tt.args)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
}
})
}
}
func TestCopilotEnvVars(t *testing.T) {
c := &Copilot{}
envMap := func(envs []string) map[string]string {
m := make(map[string]string)
for _, e := range envs {
k, v, _ := strings.Cut(e, "=")
m[k] = v
}
return m
}
t.Run("sets required provider env vars with model", func(t *testing.T) {
got := envMap(c.envVars("llama3.2"))
if got["COPILOT_PROVIDER_BASE_URL"] == "" {
t.Error("COPILOT_PROVIDER_BASE_URL should be set")
}
if !strings.HasSuffix(got["COPILOT_PROVIDER_BASE_URL"], "/v1") {
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want /v1 suffix", got["COPILOT_PROVIDER_BASE_URL"])
}
if _, ok := got["COPILOT_PROVIDER_API_KEY"]; !ok {
t.Error("COPILOT_PROVIDER_API_KEY should be set (empty)")
}
if got["COPILOT_PROVIDER_WIRE_API"] != "responses" {
t.Errorf("COPILOT_PROVIDER_WIRE_API = %q, want %q", got["COPILOT_PROVIDER_WIRE_API"], "responses")
}
if got["COPILOT_MODEL"] != "llama3.2" {
t.Errorf("COPILOT_MODEL = %q, want %q", got["COPILOT_MODEL"], "llama3.2")
}
})
t.Run("omits COPILOT_MODEL when model is empty", func(t *testing.T) {
got := envMap(c.envVars(""))
if _, ok := got["COPILOT_MODEL"]; ok {
t.Errorf("COPILOT_MODEL should not be set for empty model, got %q", got["COPILOT_MODEL"])
}
})
t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://myhost:9999")
got := envMap(c.envVars("test"))
if !strings.Contains(got["COPILOT_PROVIDER_BASE_URL"], "myhost:9999") {
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want custom host", got["COPILOT_PROVIDER_BASE_URL"])
}
})
}

cmd/launch/hermes.go (new file, 679 lines)
View File

@@ -0,0 +1,679 @@
package launch
import (
"bufio"
"bytes"
"context"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"gopkg.in/yaml.v3"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
const (
hermesInstallScript = "curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash -s -- --skip-setup"
hermesProviderName = "Ollama"
hermesProviderKey = "ollama-launch"
hermesLegacyKey = "ollama"
hermesPlaceholderKey = "ollama"
hermesGatewaySetupHint = "hermes gateway setup"
hermesGatewaySetupTitle = "Connect a messaging app now?"
)
var (
hermesGOOS = runtime.GOOS
hermesLookPath = exec.LookPath
hermesCommand = exec.Command
hermesUserHome = os.UserHomeDir
hermesOllamaURL = envconfig.ConnectableHost
)
var hermesMessagingEnvGroups = [][]string{
{"TELEGRAM_BOT_TOKEN"},
{"DISCORD_BOT_TOKEN"},
{"SLACK_BOT_TOKEN"},
{"SIGNAL_ACCOUNT"},
{"EMAIL_ADDRESS"},
{"TWILIO_ACCOUNT_SID"},
{"MATRIX_ACCESS_TOKEN", "MATRIX_PASSWORD"},
{"MATTERMOST_TOKEN"},
{"WHATSAPP_PHONE_NUMBER_ID"},
{"DINGTALK_CLIENT_ID"},
{"FEISHU_APP_ID"},
{"WECOM_BOT_ID"},
{"WEIXIN_ACCOUNT_ID"},
{"BLUEBUBBLES_SERVER_URL"},
{"WEBHOOK_ENABLED"},
}
// Hermes is intentionally not an Editor integration: launch owns one primary
// model and the local Ollama endpoint, while Hermes keeps its own discovery and
// switching UX after startup.
type Hermes struct{}
func (h *Hermes) String() string { return "Hermes Agent" }
func (h *Hermes) Run(_ string, args []string) error {
// Hermes reads its primary model from config.yaml. launch configures that
// default model ahead of time so we can keep runtime invocation simple and
// still let Hermes discover additional models later via its own UX.
bin, err := h.binary()
if err != nil {
return err
}
if err := h.runGatewaySetupPreflight(args, func() error {
return hermesAttachedCommand(bin, "gateway", "setup").Run()
}); err != nil {
return err
}
return hermesAttachedCommand(bin, args...).Run()
}
func (h *Hermes) Paths() []string {
configPath, err := hermesConfigPath()
if err != nil {
return nil
}
return []string{configPath}
}
func (h *Hermes) Configure(model string) error {
configPath, err := hermesConfigPath()
if err != nil {
return err
}
cfg := map[string]any{}
if data, err := os.ReadFile(configPath); err == nil {
if err := yaml.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse hermes config: %w", err)
}
} else if !os.IsNotExist(err) {
return err
}
modelSection, _ := cfg["model"].(map[string]any)
if modelSection == nil {
modelSection = make(map[string]any)
}
models := h.listModels(model)
applyHermesManagedProviders(cfg, hermesBaseURL(), model, models)
// launch writes the minimum provider/default-model settings needed to
// bootstrap Hermes against Ollama. The active provider stays on a
// launch-owned key so /model stays aligned with the launcher-managed entry,
// and the Ollama endpoint lives in providers: so the picker shows one row.
modelSection["provider"] = hermesProviderKey
modelSection["default"] = model
modelSection["base_url"] = hermesBaseURL()
modelSection["api_key"] = hermesPlaceholderKey
cfg["model"] = modelSection
// use Hermes' built-in web toolset for now.
// TODO(parthsareen): move this to using Ollama web search
cfg["toolsets"] = mergeHermesToolsets(cfg["toolsets"])
data, err := yaml.Marshal(cfg)
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
return fileutil.WriteWithBackup(configPath, data)
}
func (h *Hermes) CurrentModel() string {
configPath, err := hermesConfigPath()
if err != nil {
return ""
}
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
cfg := map[string]any{}
if yaml.Unmarshal(data, &cfg) != nil {
return ""
}
return hermesManagedCurrentModel(cfg, hermesBaseURL())
}
func (h *Hermes) Onboard() error {
return config.MarkIntegrationOnboarded("hermes")
}
func (h *Hermes) RequiresInteractiveOnboarding() bool {
return false
}
func (h *Hermes) RefreshRuntimeAfterConfigure() error {
running, err := h.gatewayRunning()
if err != nil {
return fmt.Errorf("check Hermes gateway status: %w", err)
}
if !running {
return nil
}
fmt.Fprintf(os.Stderr, "%sRefreshing Hermes messaging gateway...%s\n", ansiGray, ansiReset)
if err := h.restartGateway(); err != nil {
return fmt.Errorf("restart Hermes gateway: %w", err)
}
fmt.Fprintln(os.Stderr)
return nil
}
func (h *Hermes) installed() bool {
_, err := h.binary()
return err == nil
}
func (h *Hermes) ensureInstalled() error {
if h.installed() {
return nil
}
if hermesGOOS == "windows" {
return hermesWindowsHint()
}
var missing []string
for _, dep := range []string{"bash", "curl", "git"} {
if _, err := hermesLookPath(dep); err != nil {
missing = append(missing, dep)
}
}
if len(missing) > 0 {
return fmt.Errorf("Hermes is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch hermes", strings.Join(missing, "\n "))
}
ok, err := ConfirmPrompt("Hermes is not installed. Install now?")
if err != nil {
return err
}
if !ok {
return fmt.Errorf("hermes installation cancelled")
}
fmt.Fprintf(os.Stderr, "\nInstalling Hermes...\n")
if err := hermesAttachedCommand("bash", "-lc", hermesInstallScript).Run(); err != nil {
return fmt.Errorf("failed to install hermes: %w", err)
}
if !h.installed() {
return fmt.Errorf("hermes was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sHermes installed successfully%s\n\n", ansiGreen, ansiReset)
return nil
}
func (h *Hermes) listModels(defaultModel string) []string {
client := hermesOllamaClient()
resp, err := client.List(context.Background())
if err != nil {
return []string{defaultModel}
}
models := make([]string, 0, len(resp.Models)+1)
seen := make(map[string]struct{}, len(resp.Models)+1)
add := func(name string) {
name = strings.TrimSpace(name)
if name == "" {
return
}
if _, ok := seen[name]; ok {
return
}
seen[name] = struct{}{}
models = append(models, name)
}
add(defaultModel)
for _, entry := range resp.Models {
add(entry.Name)
}
if len(models) == 0 {
return []string{defaultModel}
}
return models
}
func (h *Hermes) binary() (string, error) {
if path, err := hermesLookPath("hermes"); err == nil {
return path, nil
}
if hermesGOOS == "windows" {
return "", hermesWindowsHint()
}
home, err := hermesUserHome()
if err != nil {
return "", err
}
fallback := filepath.Join(home, ".local", "bin", "hermes")
if _, err := os.Stat(fallback); err == nil {
return fallback, nil
}
return "", fmt.Errorf("hermes is not installed")
}
func hermesConfigPath() (string, error) {
home, err := hermesUserHome()
if err != nil {
return "", err
}
return filepath.Join(home, ".hermes", "config.yaml"), nil
}
func hermesBaseURL() string {
return strings.TrimRight(hermesOllamaURL().String(), "/") + "/v1"
}
func hermesEnvPath() (string, error) {
home, err := hermesUserHome()
if err != nil {
return "", err
}
return filepath.Join(home, ".hermes", ".env"), nil
}
func (h *Hermes) runGatewaySetupPreflight(args []string, runSetup func() error) error {
if len(args) > 0 || !isInteractiveSession() || currentLaunchConfirmPolicy.yes || currentLaunchConfirmPolicy.requireYesMessage {
return nil
}
if h.messagingConfigured() {
return nil
}
fmt.Fprintf(os.Stderr, "\nHermes can message you on Telegram, Discord, Slack, and more.\n\n")
ok, err := ConfirmPromptWithOptions(hermesGatewaySetupTitle, ConfirmOptions{
YesLabel: "Yes",
NoLabel: "Set up later",
})
if err != nil {
return err
}
if !ok {
return nil
}
if err := runSetup(); err != nil {
return fmt.Errorf("hermes messaging setup failed: %w\n\nTry running: %s", err, hermesGatewaySetupHint)
}
return nil
}
func (h *Hermes) messagingConfigured() bool {
envVars, err := h.gatewayEnvVars()
if err != nil {
return false
}
for _, group := range hermesMessagingEnvGroups {
for _, key := range group {
if strings.TrimSpace(envVars[key]) != "" {
return true
}
}
}
return false
}
func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
envVars := make(map[string]string)
envFilePath, err := hermesEnvPath()
if err != nil {
return nil, err
}
switch data, err := os.ReadFile(envFilePath); {
case err == nil:
for key, value := range hermesParseEnvFile(data) {
envVars[key] = value
}
case os.IsNotExist(err):
// nothing persisted yet
default:
return nil, err
}
for _, group := range hermesMessagingEnvGroups {
for _, key := range group {
if value, ok := os.LookupEnv(key); ok {
envVars[key] = value
}
}
}
return envVars, nil
}
func (h *Hermes) gatewayRunning() (bool, error) {
status, err := h.gatewayStatusOutput()
if err != nil {
return false, err
}
return hermesGatewayStatusRunning(status), nil
}
func (h *Hermes) gatewayStatusOutput() (string, error) {
bin, err := h.binary()
if err != nil {
return "", err
}
out, err := hermesCommand(bin, "gateway", "status").CombinedOutput()
return string(out), err
}
func (h *Hermes) restartGateway() error {
bin, err := h.binary()
if err != nil {
return err
}
return hermesAttachedCommand(bin, "gateway", "restart").Run()
}
func hermesGatewayStatusRunning(output string) bool {
status := strings.ToLower(output)
switch {
case strings.Contains(status, "gateway is not running"):
return false
case strings.Contains(status, "gateway service is stopped"):
return false
case strings.Contains(status, "gateway service is not loaded"):
return false
case strings.Contains(status, "gateway is running"):
return true
case strings.Contains(status, "gateway service is running"):
return true
case strings.Contains(status, "gateway service is loaded"):
return true
default:
return false
}
}
func hermesParseEnvFile(data []byte) map[string]string {
out := make(map[string]string)
scanner := bufio.NewScanner(bytes.NewReader(data))
for scanner.Scan() {
line := strings.TrimSpace(strings.TrimPrefix(scanner.Text(), "\ufeff"))
if line == "" || strings.HasPrefix(line, "#") {
continue
}
if strings.HasPrefix(line, "export ") {
line = strings.TrimSpace(strings.TrimPrefix(line, "export "))
}
key, value, ok := strings.Cut(line, "=")
if !ok {
continue
}
key = strings.TrimSpace(key)
if key == "" {
continue
}
value = strings.TrimSpace(value)
if len(value) >= 2 {
switch {
case value[0] == '"' && value[len(value)-1] == '"':
if unquoted, err := strconv.Unquote(value); err == nil {
value = unquoted
}
case value[0] == '\'' && value[len(value)-1] == '\'':
value = value[1 : len(value)-1]
}
}
out[key] = value
}
return out
}
func hermesOllamaClient() *api.Client {
// Hermes queries the same launch-resolved Ollama host that launch writes
// into config, so model discovery follows the configured endpoint.
return api.NewClient(hermesOllamaURL(), http.DefaultClient)
}
func applyHermesManagedProviders(cfg map[string]any, baseURL string, model string, models []string) {
providers := hermesUserProviders(cfg["providers"])
entry := hermesManagedProviderEntry(providers)
if entry == nil {
entry = make(map[string]any)
}
entry["name"] = hermesProviderName
entry["api"] = baseURL
entry["default_model"] = model
entry["models"] = hermesStringListAny(models)
providers[hermesProviderKey] = entry
delete(providers, hermesLegacyKey)
cfg["providers"] = providers
customProviders := hermesWithoutManagedCustomProviders(cfg["custom_providers"])
if len(customProviders) == 0 {
delete(cfg, "custom_providers")
return
}
cfg["custom_providers"] = customProviders
}
func hermesManagedCurrentModel(cfg map[string]any, baseURL string) string {
modelCfg, _ := cfg["model"].(map[string]any)
if modelCfg == nil {
return ""
}
provider, _ := modelCfg["provider"].(string)
if strings.TrimSpace(strings.ToLower(provider)) != hermesProviderKey {
return ""
}
configBaseURL, _ := modelCfg["base_url"].(string)
if hermesNormalizeURL(configBaseURL) != hermesNormalizeURL(baseURL) {
return ""
}
current, _ := modelCfg["default"].(string)
current = strings.TrimSpace(current)
if current == "" {
return ""
}
providers := hermesUserProviders(cfg["providers"])
entry, _ := providers[hermesProviderKey].(map[string]any)
if entry == nil {
return ""
}
if hermesHasManagedCustomProvider(cfg["custom_providers"]) {
return ""
}
apiURL, _ := entry["api"].(string)
if hermesNormalizeURL(apiURL) != hermesNormalizeURL(baseURL) {
return ""
}
defaultModel, _ := entry["default_model"].(string)
if strings.TrimSpace(defaultModel) != current {
return ""
}
return current
}
func hermesUserProviders(current any) map[string]any {
switch existing := current.(type) {
case map[string]any:
out := make(map[string]any, len(existing))
for key, value := range existing {
out[key] = value
}
return out
case map[any]any:
out := make(map[string]any, len(existing))
for key, value := range existing {
if s, ok := key.(string); ok {
out[s] = value
}
}
return out
default:
return make(map[string]any)
}
}
func hermesCustomProviders(current any) []any {
switch existing := current.(type) {
case []any:
return append([]any(nil), existing...)
case []map[string]any:
out := make([]any, 0, len(existing))
for _, entry := range existing {
out = append(out, entry)
}
return out
default:
return nil
}
}
func hermesManagedProviderEntry(providers map[string]any) map[string]any {
for _, key := range []string{hermesProviderKey, hermesLegacyKey} {
if entry, _ := providers[key].(map[string]any); entry != nil {
return entry
}
}
return nil
}
func hermesWithoutManagedCustomProviders(current any) []any {
customProviders := hermesCustomProviders(current)
preserved := make([]any, 0, len(customProviders))
for _, item := range customProviders {
entry, _ := item.(map[string]any)
if entry == nil {
preserved = append(preserved, item)
continue
}
if hermesManagedCustomProvider(entry) {
continue
}
preserved = append(preserved, entry)
}
return preserved
}
func hermesHasManagedCustomProvider(current any) bool {
for _, item := range hermesCustomProviders(current) {
entry, _ := item.(map[string]any)
if entry != nil && hermesManagedCustomProvider(entry) {
return true
}
}
return false
}
func hermesManagedCustomProvider(entry map[string]any) bool {
name, _ := entry["name"].(string)
return strings.EqualFold(strings.TrimSpace(name), hermesProviderName)
}
func hermesNormalizeURL(raw string) string {
return strings.TrimRight(strings.TrimSpace(raw), "/")
}
func hermesStringListAny(models []string) []any {
out := make([]any, 0, len(models))
for _, model := range dedupeModelList(models) {
model = strings.TrimSpace(model)
if model == "" {
continue
}
out = append(out, model)
}
return out
}
func mergeHermesToolsets(current any) any {
added := false
switch existing := current.(type) {
case []any:
out := make([]any, 0, len(existing)+1)
for _, item := range existing {
out = append(out, item)
if s, _ := item.(string); s == "web" {
added = true
}
}
if !added {
out = append(out, "web")
}
return out
case []string:
out := append([]string(nil), existing...)
if !slices.Contains(out, "web") {
out = append(out, "web")
}
asAny := make([]any, 0, len(out))
for _, item := range out {
asAny = append(asAny, item)
}
return asAny
case string:
if strings.TrimSpace(existing) == "" {
return []any{"hermes-cli", "web"}
}
parts := strings.Split(existing, ",")
out := make([]any, 0, len(parts)+1)
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
if part == "web" {
added = true
}
out = append(out, part)
}
if !added {
out = append(out, "web")
}
return out
default:
return []any{"hermes-cli", "web"}
}
}
func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
cmd := hermesCommand(name, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd
}
func hermesWindowsHint() error {
return fmt.Errorf("Hermes on Windows requires WSL2. Install WSL with: wsl --install\n" +
"Then run 'ollama launch hermes' from inside your WSL shell.\n" +
"Docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/")
}

cmd/launch/hermes_test.go: new file (1110 lines); diff suppressed because it is too large


@@ -14,6 +14,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
)
type stubEditorRunner struct {
@@ -73,7 +74,7 @@ func TestIntegrationLookup(t *testing.T) {
}
func TestIntegrationRegistry(t *testing.T) {
-expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
+expectedIntegrations := []string{"claude", "codex", "droid", "opencode", "hermes"}
for _, name := range expectedIntegrations {
t.Run(name, func(t *testing.T) {
@@ -328,7 +329,7 @@ func TestBuildModelList_NoExistingModels(t *testing.T) {
}
}
-func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
+func TestBuildModelList_OnlyLocalModels_CloudRecsStillFirst(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "qwen2.5:latest", Remote: false},
@@ -337,10 +338,11 @@ func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
-// Recommended pinned at top (local recs first, then cloud recs when only-local), then installed non-recs
-want := []string{"gemma4", "qwen3.5", "kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "llama3.2", "qwen2.5"}
+// Cloud recs always come first among recommended, regardless of installed inventory.
+// Cloud disablement is handled upstream in loadSelectableModels via filterCloudItems.
+want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "gemma4", "qwen3.5", "llama3.2", "qwen2.5"}
if diff := cmp.Diff(want, got); diff != "" {
-t.Errorf("recs pinned at top, local recs before cloud recs (-want +got):\n%s", diff)
+t.Errorf("cloud recs pinned first even when no cloud models installed (-want +got):\n%s", diff)
}
}
@@ -587,7 +589,7 @@ func TestBuildModelList_MixedCase_CloudRecsFirst(t *testing.T) {
}
}
-func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) {
+func TestBuildModelList_OnlyLocal_CloudRecsStillFirst(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
}
@@ -595,11 +597,11 @@ func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) {
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
-// Local recs should sort before cloud recs in only-local case
+// Cloud recs sort before local recs regardless of installed inventory.
localIdx := slices.Index(got, "gemma4")
cloudIdx := slices.Index(got, "glm-5.1:cloud")
-if localIdx > cloudIdx {
-t.Errorf("local recs should be before cloud recs in only-local case, got %v", got)
+if cloudIdx > localIdx {
+t.Errorf("cloud recs should be before local recs even when only local models installed, got %v", got)
}
}
@@ -722,6 +724,59 @@ func TestLauncherClientFilterDisabledCloudModels_ChecksStatusOncePerInvocation(t
}
}
func TestSavedMatchesModels(t *testing.T) {
tests := []struct {
name string
saved *config.IntegrationConfig
models []string
want bool
}{
{
name: "nil saved",
saved: nil,
models: []string{"llama3.2"},
want: false,
},
{
name: "identical order",
saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}},
models: []string{"llama3.2", "qwen3:8b"},
want: true,
},
{
name: "different order",
saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}},
models: []string{"qwen3:8b", "llama3.2"},
want: false,
},
{
name: "subset",
saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}},
models: []string{"llama3.2"},
want: false,
},
{
name: "nil models in saved with non-nil models",
saved: &config.IntegrationConfig{Models: nil},
models: []string{"llama3.2"},
want: false,
},
{
name: "empty both",
saved: &config.IntegrationConfig{Models: nil},
models: nil,
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := savedMatchesModels(tt.saved, tt.models); got != tt.want {
t.Fatalf("savedMatchesModels = %v, want %v", got, tt.want)
}
})
}
}
func TestPrepareEditorIntegration_SavesOnlyAfterSuccessfulEdit(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
@@ -1455,27 +1510,13 @@ func TestListIntegrationInfos(t *testing.T) {
}
})
-t.Run("sorted with custom order at end", func(t *testing.T) {
-// integrationOrder entries (cline, opencode) should appear last, in that order.
-// All other entries should be sorted alphabetically before them.
-orderRank := make(map[string]int)
-for i, name := range integrationOrder {
-orderRank[name] = i + 1
+t.Run("follows launcher order", func(t *testing.T) {
+got := make([]string, 0, len(infos))
+for _, info := range infos {
+got = append(got, info.Name)
}
-for i := 1; i < len(infos); i++ {
-aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
-switch {
-case aRank == 0 && bRank == 0:
-if infos[i-1].Name >= infos[i].Name {
-t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
-}
-case aRank > 0 && bRank == 0:
-t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
-case aRank > 0 && bRank > 0:
-if aRank >= bRank {
-t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
-}
-}
+if diff := compareStrings(got, integrationOrder); diff != "" {
+t.Fatalf("launcher integration order mismatch: %s", diff)
}
})
@@ -1503,6 +1544,28 @@ func TestListIntegrationInfos(t *testing.T) {
}
}
})
t.Run("includes hermes", func(t *testing.T) {
for _, info := range infos {
if info.Name == "hermes" {
return
}
}
t.Fatal("expected hermes to be included in ListIntegrationInfos")
})
t.Run("hermes still resolves explicitly", func(t *testing.T) {
name, runner, err := LookupIntegration("hermes")
if err != nil {
t.Fatalf("expected explicit hermes integration lookup to work, got %v", err)
}
if name != "hermes" {
t.Fatalf("expected canonical name hermes, got %q", name)
}
if runner.String() == "" {
t.Fatal("expected hermes integration runner to be present")
}
})
}
func TestBuildModelList_Descriptions(t *testing.T) {
@@ -1591,6 +1654,7 @@ func TestIntegration_AutoInstallable(t *testing.T) {
}{
{"openclaw", true},
{"pi", true},
{"hermes", true},
{"claude", false},
{"codex", false},
{"opencode", false},


@@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"os"
"slices"
"strings"
"github.com/ollama/ollama/api"
@@ -140,6 +141,36 @@ type Editor interface {
Models() []string
}
// ManagedSingleModel is the narrow launch-owned config path for integrations
// like Hermes that have one primary model selected by launcher, need launcher
// to persist minimal config, and still keep their own model discovery and
// onboarding UX. This stays separate from Runner-only integrations and the
// multi-model Editor flow so Hermes-specific behavior stays scoped to one path.
type ManagedSingleModel interface {
Paths() []string
Configure(model string) error
CurrentModel() string
Onboard() error
}
// ManagedRuntimeRefresher lets managed integrations refresh any long-lived
// background runtime after launch rewrites their config.
type ManagedRuntimeRefresher interface {
RefreshRuntimeAfterConfigure() error
}
// ManagedOnboardingValidator lets managed integrations re-check saved
// onboarding state when launcher needs a stronger live readiness signal.
type ManagedOnboardingValidator interface {
OnboardingComplete() bool
}
// ManagedInteractiveOnboarding lets a managed integration declare whether its
// onboarding step really requires an interactive terminal. Hermes does not.
type ManagedInteractiveOnboarding interface {
RequiresInteractiveOnboarding() bool
}
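The four interfaces above are small capability contracts that launch discovers via type assertion. A minimal sketch of an integration satisfying ManagedSingleModel plus the headless-onboarding opt-out (the interface copies and the `fileManaged` type are illustrative, not part of this PR):

```go
package main

import "fmt"

// Local copies of the launch-owned contracts above, so this sketch compiles
// on its own.
type ManagedSingleModel interface {
	Paths() []string
	Configure(model string) error
	CurrentModel() string
	Onboard() error
}

type ManagedInteractiveOnboarding interface {
	RequiresInteractiveOnboarding() bool
}

// fileManaged is a hypothetical minimal implementation: one primary model,
// one config path, and onboarding that is pure bookkeeping.
type fileManaged struct {
	path  string
	model string
}

func (m *fileManaged) Paths() []string      { return []string{m.path} }
func (m *fileManaged) CurrentModel() string { return m.model }
func (m *fileManaged) Configure(model string) error {
	m.model = model // a real integration would rewrite its config file here
	return nil
}
func (m *fileManaged) Onboard() error                      { return nil }
func (m *fileManaged) RequiresInteractiveOnboarding() bool { return false }

// Compile-time checks that fileManaged satisfies both contracts.
var (
	_ ManagedSingleModel           = (*fileManaged)(nil)
	_ ManagedInteractiveOnboarding = (*fileManaged)(nil)
)

func main() {
	m := &fileManaged{path: "~/.example/config.json"}
	_ = m.Configure("gemma4")
	fmt.Println(m.CurrentModel()) // gemma4
}
```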
type modelInfo struct {
Name string
Remote bool
@@ -175,7 +206,9 @@ Supported integrations:
claude Claude Code
cline Cline
codex Codex
copilot Copilot CLI (aliases: copilot-cli)
droid Droid
hermes Hermes Agent
opencode OpenCode
openclaw OpenClaw (aliases: clawdbot, moltbot)
pi Pi
@@ -185,6 +218,7 @@ Examples:
ollama launch
ollama launch claude
ollama launch claude --model <model>
ollama launch hermes
ollama launch droid --config (does not auto-launch)
ollama launch codex -- -p myprofile (pass extra args to integration)
ollama launch codex -- --sandbox workspace-write`,
@@ -307,36 +341,54 @@ func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error
if err != nil {
return err
}
policy := launchIntegrationPolicy(req)
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
}
launchClient, saved, err := prepareIntegrationLaunch(name, policy)
if err != nil {
return err
}
if managed, ok := runner.(ManagedSingleModel); ok {
if err := EnsureIntegrationInstalled(name, runner); err != nil {
return err
}
return launchClient.launchManagedSingleIntegration(ctx, name, runner, managed, saved, req)
}
if !req.ConfigureOnly {
if err := EnsureIntegrationInstalled(name, runner); err != nil {
return err
}
}
var policy LaunchPolicy
// TUI does not set a policy, whereas ollama launch <app> does as it can have flags which change the behavior
if req.Policy == nil {
policy = defaultLaunchPolicy(isInteractiveSession(), false)
} else {
policy = *req.Policy
}
launchClient, err := newLauncherClient(policy)
if err != nil {
return err
}
saved, _ := loadStoredIntegrationConfig(name)
// In headless --yes mode we cannot prompt, so require an explicit --model.
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
}
if editor, ok := runner.(Editor); ok {
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
}
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
}
func launchIntegrationPolicy(req IntegrationLaunchRequest) LaunchPolicy {
// TUI does not set a policy, whereas ollama launch <app> does as it can
// have flags which change the behavior.
if req.Policy != nil {
return *req.Policy
}
return defaultLaunchPolicy(isInteractiveSession(), false)
}
func prepareIntegrationLaunch(name string, policy LaunchPolicy) (*launcherClient, *config.IntegrationConfig, error) {
launchClient, err := newLauncherClient(policy)
if err != nil {
return nil, nil, err
}
saved, _ := loadStoredIntegrationConfig(name)
return launchClient, saved, nil
}
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
_ = c.loadModelInventoryOnce(ctx)
@@ -367,9 +419,18 @@ func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info
if err != nil {
return LauncherIntegrationState{}, err
}
-currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor)
-if err != nil {
-return LauncherIntegrationState{}, err
+var currentModel string
+var usable bool
+if managed, ok := integration.spec.Runner.(ManagedSingleModel); ok {
+currentModel, usable, err = c.launcherManagedModelState(ctx, info.Name, managed)
+if err != nil {
+return LauncherIntegrationState{}, err
+}
+} else {
+currentModel, usable, err = c.launcherModelState(ctx, info.Name, integration.editor)
+if err != nil {
+return LauncherIntegrationState{}, err
+}
}
return LauncherIntegrationState{
@@ -407,6 +468,28 @@ func (c *launcherClient) launcherModelState(ctx context.Context, name string, is
return model, usableErr == nil && usable, nil
}
func (c *launcherClient) launcherManagedModelState(ctx context.Context, name string, managed ManagedSingleModel) (string, bool, error) {
current := managed.CurrentModel()
if current == "" {
cfg, loadErr := loadStoredIntegrationConfig(name)
if loadErr == nil {
current = primaryModelFromConfig(cfg)
}
if current != "" {
return current, false, nil
}
}
if current == "" {
return "", false, nil
}
usable, err := c.savedModelUsable(ctx, current)
if err != nil {
return current, false, err
}
return current, usable, nil
}
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
current := config.LastModel()
if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
@@ -443,35 +526,15 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
}
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
current := primaryModelFromConfig(saved)
target := req.ModelOverride
needsConfigure := req.ForceConfigure
if target == "" {
target = current
usable, err := c.savedModelUsable(ctx, target)
if err != nil {
return err
}
if !usable {
needsConfigure = true
}
}
if needsConfigure {
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
if err != nil {
return err
}
target = selected
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
target, _, err := c.resolveSingleIntegrationTarget(ctx, runner, primaryModelFromConfig(saved), req)
if err != nil {
return err
}
if target == "" {
return nil
}
current := primaryModelFromConfig(saved)
if target != current {
if err := config.SaveIntegration(name, []string{target}); err != nil {
return fmt.Errorf("failed to save: %w", err)
@@ -500,7 +563,7 @@ func (c *launcherClient) launchEditorIntegration(ctx context.Context, name strin
return nil
}
-if needsConfigure || req.ModelOverride != "" {
+if (needsConfigure || req.ModelOverride != "") && !savedMatchesModels(saved, models) {
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
return err
}
@@ -509,6 +572,102 @@ func (c *launcherClient) launchEditorIntegration(ctx context.Context, name strin
return launchAfterConfiguration(name, runner, models[0], req)
}
func (c *launcherClient) launchManagedSingleIntegration(ctx context.Context, name string, runner Runner, managed ManagedSingleModel, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
current := managed.CurrentModel()
selectionCurrent := current
if selectionCurrent == "" {
selectionCurrent = primaryModelFromConfig(saved)
}
target, needsConfigure, err := c.resolveSingleIntegrationTarget(ctx, runner, selectionCurrent, req)
if err != nil {
return err
}
if target == "" {
return nil
}
if (current == "" || needsConfigure || req.ModelOverride != "" || target != current) && !savedMatchesModels(saved, []string{target}) {
if err := prepareManagedSingleIntegration(name, runner, managed, target); err != nil {
return err
}
if refresher, ok := managed.(ManagedRuntimeRefresher); ok {
if err := refresher.RefreshRuntimeAfterConfigure(); err != nil {
return err
}
}
}
if !managedIntegrationOnboarded(saved, managed) {
if !isInteractiveSession() && managedRequiresInteractiveOnboarding(managed) {
return fmt.Errorf("%s still needs interactive gateway setup; run 'ollama launch %s' in a terminal to finish onboarding", runner, name)
}
if err := managed.Onboard(); err != nil {
return err
}
}
if req.ConfigureOnly {
return nil
}
return runIntegration(runner, target, req.ExtraArgs)
}
func (c *launcherClient) resolveSingleIntegrationTarget(ctx context.Context, runner Runner, current string, req IntegrationLaunchRequest) (string, bool, error) {
target := req.ModelOverride
needsConfigure := req.ForceConfigure
if target == "" {
target = current
usable, err := c.savedModelUsable(ctx, target)
if err != nil {
return "", false, err
}
if !usable {
needsConfigure = true
}
}
if needsConfigure {
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
if err != nil {
return "", false, err
}
target = selected
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
return "", false, err
}
return target, needsConfigure, nil
}
func savedIntegrationOnboarded(saved *config.IntegrationConfig) bool {
return saved != nil && saved.Onboarded
}
func managedIntegrationOnboarded(saved *config.IntegrationConfig, managed ManagedSingleModel) bool {
if !savedIntegrationOnboarded(saved) {
return false
}
validator, ok := managed.(ManagedOnboardingValidator)
if !ok {
return true
}
return validator.OnboardingComplete()
}
// Most managed integrations treat onboarding as an interactive terminal step.
// Hermes opts out because its launch-owned onboarding is just bookkeeping, so
// headless launches should not be blocked once config is already prepared.
func managedRequiresInteractiveOnboarding(managed ManagedSingleModel) bool {
onboarding, ok := managed.(ManagedInteractiveOnboarding)
if !ok {
return true
}
return onboarding.RequiresInteractiveOnboarding()
}
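managedIntegrationOnboarded and managedRequiresInteractiveOnboarding both rely on Go's optional-capability pattern: type-assert to a narrower interface and fall back to a safe default when it is absent. A self-contained sketch of that pattern with invented stub names:

```go
package main

import "fmt"

type managed interface{ Onboard() error }

// Optional capability: integrations may declare non-interactive onboarding.
type interactiveOnboarding interface{ RequiresInteractiveOnboarding() bool }

// requiresInteractive mirrors the helper above: absent the capability
// interface, assume the safer default (interactive terminal required).
func requiresInteractive(m managed) bool {
	o, ok := m.(interactiveOnboarding)
	if !ok {
		return true
	}
	return o.RequiresInteractiveOnboarding()
}

type interactiveStub struct{}

func (interactiveStub) Onboard() error { return nil }

// headlessStub opts out by also implementing the capability interface.
type headlessStub struct{ interactiveStub }

func (headlessStub) RequiresInteractiveOnboarding() bool { return false }

func main() {
	fmt.Println(requiresInteractive(interactiveStub{})) // true: no opt-out declared
	fmt.Println(requiresInteractive(headlessStub{}))    // false: opted out
}
```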
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
if selector == nil {
return "", fmt.Errorf("no selector configured")
@@ -846,6 +1005,13 @@ func firstModel(models []string) string {
return models[0]
}
func savedMatchesModels(saved *config.IntegrationConfig, models []string) bool {
if saved == nil {
return false
}
return slices.Equal(saved.Models, models)
}
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
if override == "" {
if saved == nil {


@@ -49,6 +49,55 @@ func (r *launcherSingleRunner) Run(model string, args []string) error {
func (r *launcherSingleRunner) String() string { return "StubSingle" }
type launcherManagedRunner struct {
paths []string
currentModel string
configured []string
ranModel string
onboarded bool
onboardCalls int
onboardingComplete bool
refreshCalls int
refreshErr error
}
func (r *launcherManagedRunner) Run(model string, args []string) error {
r.ranModel = model
return nil
}
func (r *launcherManagedRunner) String() string { return "StubManaged" }
func (r *launcherManagedRunner) Paths() []string { return r.paths }
func (r *launcherManagedRunner) Configure(model string) error {
r.configured = append(r.configured, model)
r.currentModel = model
return nil
}
func (r *launcherManagedRunner) CurrentModel() string { return r.currentModel }
func (r *launcherManagedRunner) Onboard() error {
r.onboardCalls++
r.onboarded = true
r.onboardingComplete = true
return nil
}
func (r *launcherManagedRunner) OnboardingComplete() bool { return r.onboardingComplete }
func (r *launcherManagedRunner) RefreshRuntimeAfterConfigure() error {
r.refreshCalls++
return r.refreshErr
}
type launcherHeadlessManagedRunner struct {
launcherManagedRunner
}
func (r *launcherHeadlessManagedRunner) RequiresInteractiveOnboarding() bool { return false }
func setLaunchTestHome(t *testing.T, dir string) {
t.Helper()
t.Setenv("HOME", dir)
@@ -141,6 +190,451 @@ func TestDefaultLaunchPolicy(t *testing.T) {
}
}
func TestBuildLauncherState_ManagedSingleIntegrationUsesCurrentModel(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherManagedRunner{currentModel: "gemma4"}
withIntegrationOverride(t, "pi", runner)
state, err := BuildLauncherState(context.Background())
if err != nil {
t.Fatalf("BuildLauncherState returned error: %v", err)
}
if state.Integrations["pi"].CurrentModel != "gemma4" {
t.Fatalf("expected managed current model from integration config, got %q", state.Integrations["pi"].CurrentModel)
}
if !state.Integrations["pi"].ModelUsable {
t.Fatal("expected managed current model to be usable")
}
}
func TestBuildLauncherState_ManagedSingleIntegrationShowsSavedModelWhenLiveConfigMissing(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
if err := config.SaveIntegration("pi", []string{"gemma4"}); err != nil {
t.Fatalf("failed to save managed integration config: %v", err)
}
runner := &launcherManagedRunner{}
withIntegrationOverride(t, "pi", runner)
state, err := BuildLauncherState(context.Background())
if err != nil {
t.Fatalf("BuildLauncherState returned error: %v", err)
}
if state.Integrations["pi"].CurrentModel != "gemma4" {
t.Fatalf("expected saved model to remain visible, got %q", state.Integrations["pi"].CurrentModel)
}
if state.Integrations["pi"].ModelUsable {
t.Fatal("expected missing live config to mark managed model unusable")
}
}
func TestLaunchIntegration_ManagedSingleIntegrationConfiguresOnboardsAndRuns(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherManagedRunner{
paths: nil,
}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
return "gemma4", nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return true, nil
}
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
t.Fatalf("configured models mismatch: %s", diff)
}
if runner.refreshCalls != 1 {
t.Fatalf("expected runtime refresh once after configure, got %d", runner.refreshCalls)
}
if runner.onboardCalls != 1 {
t.Fatalf("expected onboarding to run once, got %d", runner.onboardCalls)
}
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
}
saved, err := config.LoadIntegration("stubmanaged")
if err != nil {
t.Fatalf("failed to reload managed integration config: %v", err)
}
if diff := compareStrings(saved.Models, []string{"gemma4"}); diff != "" {
t.Fatalf("saved models mismatch: %s", diff)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationReOnboardsWhenSavedFlagIsStale(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherManagedRunner{
currentModel: "gemma4",
onboardingComplete: false,
}
withIntegrationOverride(t, "stubmanaged", runner)
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
t.Fatalf("failed to save managed integration config: %v", err)
}
if err := config.MarkIntegrationOnboarded("stubmanaged"); err != nil {
t.Fatalf("failed to mark managed integration onboarded: %v", err)
}
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if runner.onboardCalls != 1 {
t.Fatalf("expected stale onboarded flag to trigger onboarding, got %d calls", runner.onboardCalls)
}
if runner.refreshCalls != 0 {
t.Fatalf("expected no runtime refresh when config is unchanged, got %d", runner.refreshCalls)
}
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to run saved model after onboarding repair, got %q", runner.ranModel)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationConfigOnlySkipsFinalRun(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherManagedRunner{
paths: nil,
}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return true, nil
}
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "stubmanaged",
ModelOverride: "gemma4",
ConfigureOnly: true,
}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if runner.ranModel != "" {
t.Fatalf("expected configure-only flow to skip final launch, got %q", runner.ranModel)
}
if runner.refreshCalls != 1 {
t.Fatalf("expected configure-only flow to refresh runtime once, got %d", runner.refreshCalls)
}
if runner.onboardCalls != 1 {
t.Fatalf("expected configure-only flow to onboard once, got %d", runner.onboardCalls)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationSkipsRewriteWhenSavedMatches(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
t.Fatalf("failed to save managed integration config: %v", err)
}
runner := &launcherManagedRunner{}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
t.Fatal("selector should not be called when saved model matches target")
return "", nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
t.Fatal("confirm prompt should not run when saved model matches target")
return false, nil
}
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if len(runner.configured) != 0 {
t.Fatalf("expected Configure to be skipped when saved matches, got %v", runner.configured)
}
if runner.refreshCalls != 0 {
t.Fatalf("expected no runtime refresh when config is unchanged, got %d", runner.refreshCalls)
}
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to run saved model, got %q", runner.ranModel)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenSavedDiffers(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
if err := config.SaveIntegration("stubmanaged", []string{"old-model"}); err != nil {
t.Fatalf("failed to save managed integration config: %v", err)
}
runner := &launcherManagedRunner{}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
t.Fatal("selector should not be called when model override is provided")
return "", nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return true, nil
}
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "stubmanaged",
ModelOverride: "gemma4",
}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
t.Fatalf("expected Configure to run when saved differs from target: %s", diff)
}
if runner.refreshCalls != 1 {
t.Fatalf("expected runtime refresh once after configure, got %d", runner.refreshCalls)
}
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationStopsWhenRuntimeRefreshFails(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherManagedRunner{
refreshErr: fmt.Errorf("boom"),
}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return true, nil
}
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "stubmanaged",
ModelOverride: "gemma4",
})
if err == nil || !strings.Contains(err.Error(), "boom") {
t.Fatalf("expected runtime refresh error, got %v", err)
}
if runner.ranModel != "" {
t.Fatalf("expected final launch to stop on runtime refresh failure, got %q", runner.ranModel)
}
if runner.refreshCalls != 1 {
t.Fatalf("expected one runtime refresh attempt, got %d", runner.refreshCalls)
}
if runner.onboardCalls != 0 {
t.Fatalf("expected onboarding to stop after refresh failure, got %d", runner.onboardCalls)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationHeadlessNeedsInteractiveOnboarding(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, false)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherManagedRunner{
paths: nil,
}
withIntegrationOverride(t, "stubmanaged", runner)
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "stubmanaged",
ModelOverride: "gemma4",
Policy: &LaunchPolicy{Confirm: LaunchConfirmAutoApprove, MissingModel: LaunchMissingModelAutoPull},
})
if err == nil {
t.Fatal("expected headless onboarding requirement to fail")
}
if !strings.Contains(err.Error(), "interactive gateway setup") {
t.Fatalf("expected interactive onboarding guidance, got %v", err)
}
if runner.ranModel != "" {
t.Fatalf("expected no final launch when onboarding is still required, got %q", runner.ranModel)
}
if runner.onboardCalls != 0 {
t.Fatalf("expected no onboarding attempts in headless mode, got %d", runner.onboardCalls)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationHeadlessAllowsNonInteractiveOnboarding(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, false)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherHeadlessManagedRunner{}
withIntegrationOverride(t, "stubmanaged", runner)
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "stubmanaged",
ModelOverride: "gemma4",
Policy: &LaunchPolicy{Confirm: LaunchConfirmAutoApprove, MissingModel: LaunchMissingModelAutoPull},
})
if err != nil {
t.Fatalf("expected non-interactive onboarding to succeed headlessly, got %v", err)
}
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
t.Fatalf("configured models mismatch: %s", diff)
}
if runner.onboardCalls != 1 {
t.Fatalf("expected onboarding to run once, got %d", runner.onboardCalls)
}
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
}
}
func TestBuildLauncherState_InstalledAndCloudDisabled(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)

View File

@@ -230,7 +230,7 @@ func pullMissingModel(ctx context.Context, client *api.Client, model string) err
// prepareEditorIntegration persists models and applies editor-managed config files.
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
if ok, err := confirmEditorEdit(runner, editor); err != nil {
if ok, err := confirmConfigEdit(runner, editor.Paths()); err != nil {
return err
} else if !ok {
return errCancelled
@@ -244,8 +244,22 @@ func prepareEditorIntegration(name string, runner Runner, editor Editor, models
return nil
}
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) {
paths := editor.Paths()
func prepareManagedSingleIntegration(name string, runner Runner, managed ManagedSingleModel, model string) error {
if ok, err := confirmConfigEdit(runner, managed.Paths()); err != nil {
return err
} else if !ok {
return errCancelled
}
if err := managed.Configure(model); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
if err := config.SaveIntegration(name, []string{model}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
return nil
}
func confirmConfigEdit(runner Runner, paths []string) (bool, error) {
if len(paths) == 0 {
return true, nil
}
@@ -345,8 +359,6 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
recRank[rec.Name] = i + 1
}
onlyLocal := hasLocalModel && !hasCloudModel
if hasLocalModel || hasCloudModel {
slices.SortStableFunc(items, func(a, b ModelItem) int {
ac, bc := checked[a.Name], checked[b.Name]
@@ -368,12 +380,6 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
}
if aRec && bRec {
if aCloud != bCloud {
if onlyLocal {
if aCloud {
return 1
}
return -1
}
if aCloud {
return -1
}

View File

@@ -186,6 +186,11 @@ func (c *Openclaw) runChannelSetupPreflight(bin string) error {
if !isInteractiveSession() {
return nil
}
// --yes is headless; channel setup spawns an interactive picker we can't
// auto-answer, so skip it. Users can run `openclaw channels add` later.
if currentLaunchConfirmPolicy.yes {
return nil
}
for {
if c.channelsConfigured() {

View File

@@ -1304,6 +1304,46 @@ func TestOpenclawChannelSetupPreflight(t *testing.T) {
}
})
t.Run("--yes skips preflight without channels configured", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
// Empty config = no channels configured. Without the --yes skip, the
// preflight would prompt and (on confirm) spawn `openclaw channels add`.
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644); err != nil {
t.Fatal(err)
}
bin := filepath.Join(tmpDir, "openclaw")
if err := os.WriteFile(bin, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \"$HOME/invocations.log\"\n"), 0o755); err != nil {
t.Fatal(err)
}
oldInteractive := isInteractiveSession
isInteractiveSession = func() bool { return true }
defer func() { isInteractiveSession = oldInteractive }()
restore := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
defer restore()
oldConfirmPrompt := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
t.Fatalf("did not expect prompt in --yes mode: %s", prompt)
return false, nil
}
defer func() { DefaultConfirmPrompt = oldConfirmPrompt }()
if err := c.runChannelSetupPreflight("openclaw"); err != nil {
t.Fatalf("runChannelSetupPreflight() error = %v", err)
}
if _, err := os.Stat(filepath.Join(tmpDir, "invocations.log")); !os.IsNotExist(err) {
t.Fatalf("expected no channels add invocation in --yes mode, got err=%v", err)
}
})
t.Run("set up later prompts once and exits", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)

View File

@@ -3,20 +3,22 @@ package launch
import (
"encoding/json"
"fmt"
"maps"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}
// OpenCode implements Runner and Editor for OpenCode integration.
// Config is passed via OPENCODE_CONFIG_CONTENT env var at launch time
// instead of writing to opencode's config files.
type OpenCode struct {
configContent string // JSON config built by Edit, passed to Run via env var
}
func (o *OpenCode) String() string { return "OpenCode" }
@@ -51,25 +53,51 @@ func (o *OpenCode) Run(model string, args []string) error {
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = os.Environ()
if content := o.resolveContent(model); content != "" {
cmd.Env = append(cmd.Env, "OPENCODE_CONFIG_CONTENT="+content)
}
return cmd.Run()
}
// resolveContent returns the inline config to send via OPENCODE_CONFIG_CONTENT.
// Returns content built by Edit if available, otherwise builds from model.json
// with the requested model as primary (e.g. re-launch with saved config).
func (o *OpenCode) resolveContent(model string) string {
if o.configContent != "" {
return o.configContent
}
models := readModelJSONModels()
if !slices.Contains(models, model) {
models = append([]string{model}, models...)
}
content, err := buildInlineConfig(model, models)
if err != nil {
return ""
}
return content
}
func (o *OpenCode) Paths() []string {
home, err := os.UserHomeDir()
sp, err := openCodeStatePath()
if err != nil {
return nil
}
var paths []string
p := filepath.Join(home, ".config", "opencode", "opencode.json")
if _, err := os.Stat(p); err == nil {
paths = append(paths, p)
}
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
if _, err := os.Stat(sp); err == nil {
paths = append(paths, sp)
return []string{sp}
}
return paths
return nil
}
// openCodeStatePath returns the path to opencode's model state file.
// TODO: this hardcodes the Linux/macOS XDG path. On Windows, opencode stores
// state under %LOCALAPPDATA% (or similar) — verify and branch on runtime.GOOS.
func openCodeStatePath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".local", "state", "opencode", "model.json"), nil
}
func (o *OpenCode) Edit(modelList []string) error {
@@ -77,110 +105,17 @@ func (o *OpenCode) Edit(modelList []string) error {
return nil
}
home, err := os.UserHomeDir()
content, err := buildInlineConfig(modelList[0], modelList)
if err != nil {
return err
}
o.configContent = content
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
}
config["$schema"] = "https://opencode.ai/config.json"
provider, ok := config["provider"].(map[string]any)
if !ok {
provider = make(map[string]any)
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
ollama = map[string]any{
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": map[string]any{
"baseURL": envconfig.Host().String() + "/v1",
},
}
}
// Migrate legacy provider name
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
ollama["name"] = "Ollama"
}
models, ok := ollama["models"].(map[string]any)
if !ok {
models = make(map[string]any)
}
selectedSet := make(map[string]bool)
for _, m := range modelList {
selectedSet[m] = true
}
for name, cfg := range models {
if cfgMap, ok := cfg.(map[string]any); ok {
if isOllamaModel(cfgMap) && !selectedSet[name] {
delete(models, name)
}
}
}
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
if isOllamaModel(existing) {
existing["_launch"] = true
if name, ok := existing["name"].(string); ok {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
existing["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
continue
}
entry := map[string]any{
"name": model,
"_launch": true,
}
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
entry["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
models[model] = entry
}
ollama["models"] = models
provider["ollama"] = ollama
config["provider"] = provider
config["model"] = "ollama/" + modelList[0]
configData, err := json.MarshalIndent(config, "", " ")
// Write model state file so models appear in OpenCode's model picker
statePath, err := openCodeStatePath()
if err != nil {
return err
}
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
return err
}
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
return err
}
@@ -232,33 +167,82 @@ func (o *OpenCode) Edit(modelList []string) error {
}
func (o *OpenCode) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := fileutil.ReadJSON(filepath.Join(home, ".config", "opencode", "opencode.json"))
if err != nil {
return nil
}
provider, _ := config["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
if len(models) == 0 {
return nil
}
keys := slices.Collect(maps.Keys(models))
slices.Sort(keys)
return keys
return nil
}
// isOllamaModel reports whether a model config entry is managed by us
func isOllamaModel(cfg map[string]any) bool {
if v, ok := cfg["_launch"].(bool); ok && v {
return true
// buildInlineConfig produces the JSON string for OPENCODE_CONFIG_CONTENT.
// primary is the model to launch with, models is the full list of available models.
func buildInlineConfig(primary string, models []string) (string, error) {
if primary == "" || len(models) == 0 {
return "", fmt.Errorf("buildInlineConfig: primary and models are required")
}
// previously used [Ollama] as a suffix for the model managed by ollama launch
if name, ok := cfg["name"].(string); ok {
return strings.HasSuffix(name, "[Ollama]")
config := map[string]any{
"$schema": "https://opencode.ai/config.json",
"provider": map[string]any{
"ollama": map[string]any{
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": map[string]any{
"baseURL": envconfig.Host().String() + "/v1",
},
"models": buildModelEntries(models),
},
},
"model": "ollama/" + primary,
}
return false
data, err := json.Marshal(config)
if err != nil {
return "", err
}
return string(data), nil
}
// readModelJSONModels reads ollama model IDs from the opencode model.json state file
func readModelJSONModels() []string {
statePath, err := openCodeStatePath()
if err != nil {
return nil
}
data, err := os.ReadFile(statePath)
if err != nil {
return nil
}
var state map[string]any
if err := json.Unmarshal(data, &state); err != nil {
return nil
}
recent, _ := state["recent"].([]any)
var models []string
for _, entry := range recent {
e, ok := entry.(map[string]any)
if !ok {
continue
}
if e["providerID"] != "ollama" {
continue
}
if id, ok := e["modelID"].(string); ok && id != "" {
models = append(models, id)
}
}
return models
}
func buildModelEntries(modelList []string) map[string]any {
models := make(map[string]any)
for _, model := range modelList {
entry := map[string]any{
"name": model,
}
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
entry["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
models[model] = entry
}
return models
}

File diff suppressed because it is too large

View File

@@ -33,7 +33,7 @@ type IntegrationInfo struct {
Description string
}
var launcherIntegrationOrder = []string{"opencode", "droid", "pi"}
var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi"}
var integrationSpecs = []*IntegrationSpec{
{
@@ -74,6 +74,19 @@ var integrationSpecs = []*IntegrationSpec{
Command: []string{"npm", "install", "-g", "@openai/codex"},
},
},
{
Name: "copilot",
Runner: &Copilot{},
Aliases: []string{"copilot-cli"},
Description: "GitHub's AI coding agent for the terminal",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := (&Copilot{}).findPath()
return err == nil
},
URL: "https://github.com/features/copilot/cli/",
},
},
{
Name: "droid",
Runner: &Droid{},
@@ -136,6 +149,20 @@ var integrationSpecs = []*IntegrationSpec{
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
},
},
{
Name: "hermes",
Runner: &Hermes{},
Description: "Self-improving AI agent built by Nous Research",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
return (&Hermes{}).installed()
},
EnsureInstalled: func() error {
return (&Hermes{}).ensureInstalled()
},
URL: "https://hermes-agent.nousresearch.com/docs/getting-started/installation/",
},
},
{
Name: "vscode",
Runner: &VSCode{},
@@ -255,10 +282,10 @@ func ListVisibleIntegrationSpecs() []IntegrationSpec {
return aRank - bRank
}
if aRank > 0 {
return 1
return -1
}
if bRank > 0 {
return -1
return 1
}
return strings.Compare(a.Name, b.Name)
})

View File

@@ -26,7 +26,7 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
binary: "opencode",
runner: &OpenCode{},
checkPath: func(home string) string {
return filepath.Join(home, ".config", "opencode", "opencode.json")
return filepath.Join(home, ".local", "state", "opencode", "model.json")
},
},
{

View File

@@ -45,21 +45,12 @@ type menuItem struct {
isOthers bool
}
var mainMenuItems = []menuItem{
{
title: "Chat with a model",
description: "Start an interactive chat with a model",
isRunModel: true,
},
{
integration: "openclaw",
},
{
integration: "claude",
},
{
integration: "opencode",
},
const pinnedIntegrationCount = 3
var runModelMenuItem = menuItem{
title: "Chat with a model",
description: "Start an interactive chat with a model",
isRunModel: true,
}
var othersMenuItem = menuItem{
@@ -102,20 +93,14 @@ func shouldExpandOthers(state *launch.LauncherState) bool {
}
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
items := make([]menuItem, 0, len(mainMenuItems)+1)
for _, item := range mainMenuItems {
if item.integration == "" {
items = append(items, item)
continue
}
if integrationState, ok := state.Integrations[item.integration]; ok {
items = append(items, integrationMenuItem(integrationState))
}
}
items := []menuItem{runModelMenuItem}
items = append(items, pinnedIntegrationItems(state)...)
if showOthers {
items = append(items, otherIntegrationItems(state)...)
} else {
otherItems := otherIntegrationItems(state)
switch {
case showOthers:
items = append(items, otherItems...)
case len(otherItems) > 0:
items = append(items, othersMenuItem)
}
@@ -135,17 +120,28 @@ func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
}
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
pinned := map[string]bool{
"openclaw": true,
"claude": true,
"opencode": true,
ordered := orderedIntegrationItems(state)
if len(ordered) <= pinnedIntegrationCount {
return nil
}
return ordered[pinnedIntegrationCount:]
}
func pinnedIntegrationItems(state *launch.LauncherState) []menuItem {
ordered := orderedIntegrationItems(state)
if len(ordered) <= pinnedIntegrationCount {
return ordered
}
return ordered[:pinnedIntegrationCount]
}
func orderedIntegrationItems(state *launch.LauncherState) []menuItem {
if state == nil {
return nil
}
var items []menuItem
items := make([]menuItem, 0, len(state.Integrations))
for _, info := range launch.ListIntegrationInfos() {
if pinned[info.Name] {
continue
}
integrationState, ok := state.Integrations[info.Name]
if !ok {
continue
@@ -155,6 +151,10 @@ func otherIntegrationItems(state *launch.LauncherState) []menuItem {
return items
}
func primaryMenuItemCount(state *launch.LauncherState) int {
return 1 + len(pinnedIntegrationItems(state))
}
func initialCursor(state *launch.LauncherState, items []menuItem) int {
if state == nil || state.LastSelection == "" {
return 0
@@ -190,7 +190,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.cursor > 0 {
m.cursor--
}
if m.showOthers && m.cursor < len(mainMenuItems) {
if m.showOthers && m.cursor < primaryMenuItemCount(m.state) {
m.showOthers = false
m.items = buildMenuItems(m.state, false)
m.cursor = min(m.cursor, len(m.items)-1)

View File

@@ -5,6 +5,7 @@ import (
"testing"
tea "github.com/charmbracelet/bubbletea"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/cmd/launch"
)
@@ -43,6 +44,13 @@ func launcherTestState() *launch.LauncherState {
Selectable: true,
Changeable: true,
},
"hermes": {
Name: "hermes",
DisplayName: "Hermes Agent",
Description: "Self-improving AI agent built by Nous Research",
Selectable: true,
Changeable: true,
},
"droid": {
Name: "droid",
DisplayName: "Droid",
@@ -70,8 +78,28 @@ func findMenuCursorByIntegration(items []menuItem, name string) int {
return -1
}
func integrationSequence(items []menuItem) []string {
sequence := make([]string, 0, len(items))
for _, item := range items {
switch {
case item.isRunModel:
sequence = append(sequence, "run")
case item.isOthers:
sequence = append(sequence, "more")
case item.integration != "":
sequence = append(sequence, item.integration)
}
}
return sequence
}
func compareStrings(got, want []string) string {
return cmp.Diff(want, got)
}
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
view := newModel(launcherTestState()).View()
menu := newModel(launcherTestState())
view := menu.View()
for _, want := range []string{"Chat with a model", "Launch OpenClaw", "Launch Claude Code", "Launch OpenCode", "More..."} {
if !strings.Contains(view, want) {
t.Fatalf("expected menu view to contain %q\n%s", want, view)
@@ -80,23 +108,31 @@ func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
if strings.Contains(view, "Launch Codex") {
t.Fatalf("expected Codex to be under More, not pinned\n%s", view)
}
wantOrder := []string{"run", "openclaw", "claude", "opencode", "more"}
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
t.Fatalf("unexpected pinned order: %s", diff)
}
}
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
state := launcherTestState()
state.LastSelection = "pi"
state.LastSelection = "codex"
menu := newModel(state)
if !menu.showOthers {
t.Fatal("expected others section to expand when last selection is in the overflow list")
}
view := menu.View()
if !strings.Contains(view, "Launch Pi") {
if !strings.Contains(view, "Launch Codex") {
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
}
if strings.Contains(view, "More...") {
t.Fatalf("expected expanded view to replace More... item\n%s", view)
}
wantOrder := []string{"run", "openclaw", "claude", "opencode", "hermes", "codex", "droid", "pi"}
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
t.Fatalf("unexpected expanded order: %s", diff)
}
}
func TestMenuEnterOnRunSelectsRun(t *testing.T) {

View File

@@ -120,6 +120,7 @@
"pages": [
"/integrations/claude-code",
"/integrations/codex",
"/integrations/copilot-cli",
"/integrations/opencode",
"/integrations/droid",
"/integrations/goose",

docs/images/hermes.png (new binary file, 1.4 MiB; not shown)

View File

@@ -0,0 +1,93 @@
---
title: Copilot CLI
---
GitHub Copilot CLI is GitHub's AI coding agent for the terminal. It can understand your codebase, make edits, run commands, and help you build software faster.
Open models can be used with Copilot CLI through Ollama, including `qwen3.5`, `glm-5.1:cloud`, and `kimi-k2.5:cloud`.
## Install
Install [Copilot CLI](https://github.com/features/copilot/cli/):
<CodeGroup>
```shell macOS / Linux (Homebrew)
brew install copilot-cli
```
```shell npm (all platforms)
npm install -g @github/copilot
```
```shell macOS / Linux (script)
curl -fsSL https://gh.io/copilot-install | bash
```
```powershell Windows (WinGet)
winget install GitHub.Copilot
```
</CodeGroup>
## Usage with Ollama
### Quick setup
```shell
ollama launch copilot
```
### Run directly with a model
```shell
ollama launch copilot --model kimi-k2.5:cloud
```
## Recommended Models
- `kimi-k2.5:cloud`
- `glm-5:cloud`
- `minimax-m2.7:cloud`
- `qwen3.5:cloud`
- `glm-4.7-flash`
- `qwen3.5`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
## Non-interactive (headless) mode
Run Copilot CLI without interaction for use in Docker, CI/CD, or scripts:
```shell
ollama launch copilot --model kimi-k2.5:cloud --yes -- -p "how does this repository work?"
```
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified. Arguments after `--` are passed directly to Copilot CLI.
## Manual setup
Copilot CLI connects to Ollama using the OpenAI-compatible API via environment variables.
1. Set the environment variables:
```shell
export COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1
export COPILOT_PROVIDER_API_KEY=
export COPILOT_PROVIDER_WIRE_API=responses
export COPILOT_MODEL=qwen3.5
```
1. Run Copilot CLI:
```shell
copilot
```
Or run with environment variables inline:
```shell
COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1 COPILOT_PROVIDER_API_KEY= COPILOT_PROVIDER_WIRE_API=responses COPILOT_MODEL=glm-5:cloud copilot
```
**Note:** Copilot requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.

View File

@@ -2,29 +2,66 @@
title: Hermes Agent
---
Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation, cross-session memory, and connects messaging platforms (Telegram, Discord, Slack, WhatsApp, Signal, Email) to models through a unified gateway.
Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation and cross-session memory, and ships with 70+ skills by default.
![Hermes Agent with Ollama](/images/hermes.png)
## Quick start
### Pull a model
Before running the setup wizard, make sure you have a model available. Hermes will auto-detect models downloaded through Ollama.
```bash
ollama pull kimi-k2.5:cloud
ollama launch hermes
```
See [Recommended models](#recommended-models) for more options.
Ollama handles everything automatically:
### Install
1. **Install** — If Hermes isn't installed, Ollama prompts to install it via the Nous Research install script
2. **Model** — Pick a model from the selector (local or cloud)
3. **Onboarding** — Ollama configures the Ollama provider, points Hermes at `http://127.0.0.1:11434/v1`, and sets your model as the primary
4. **Gateway** — Optionally connects a messaging platform (Telegram, Discord, Slack, WhatsApp, Signal, Email) and launches the Hermes chat
<Note>Hermes on Windows requires WSL2. Install it with `wsl --install` and re-run from inside the WSL shell.</Note>
## Recommended models
**Cloud models**:
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
- `glm-5.1:cloud` — Reasoning and code generation
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
**Local models:**
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
- `qwen3.6` — Reasoning, coding, and visual understanding locally (~24 GB VRAM)
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
## Connect messaging apps
Link Telegram, Discord, Slack, WhatsApp, Signal, or Email to chat with your models from anywhere:
```bash
hermes gateway setup
```
## Reconfigure
Re-run the full setup wizard at any time:
```bash
hermes setup
```
## Manual setup
If you'd rather drive Hermes's own wizard instead of `ollama launch hermes`, install it directly:
```bash
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
```
### Set up
After installation, Hermes launches the setup wizard automatically. Choose **Quick setup**:
Hermes launches the setup wizard automatically. Choose **Quick setup**:
```
How would you like to set up Hermes?
@@ -80,32 +117,3 @@ Connect a messaging platform? (Telegram, Discord, etc.)
Launch hermes chat now? [Y/n]: Y
```
## Recommended models
**Cloud models**:
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
- `glm-5.1:cloud` — Reasoning and code generation
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
**Local models:**
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
- `qwen3.5` — Reasoning, coding, and visual understanding locally (~11 GB VRAM)
More models at [ollama.com/search](https://ollama.com/models).
## Configure later
Re-run the setup wizard at any time:
```bash
hermes setup
```
To configure just messaging:
```bash
hermes setup gateway
```

View File

@@ -10,6 +10,7 @@ Coding assistants that can read, modify, and execute code in your projects.
- [Claude Code](/integrations/claude-code)
- [Codex](/integrations/codex)
- [Copilot CLI](/integrations/copilot-cli)
- [OpenCode](/integrations/opencode)
- [Droid](/integrations/droid)
- [Goose](/integrations/goose)

View File

@@ -28,79 +28,4 @@ To configure without launching:
ollama launch opencode --config
```
### Manual setup
Add a configuration block to `~/.config/opencode/opencode.json`:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": {
"baseURL": "http://localhost:11434/v1"
},
"models": {
"qwen3-coder": {
"name": "qwen3-coder"
}
}
}
}
}
```
## Cloud Models
`glm-4.7:cloud` is the recommended model for use with OpenCode.
Add the cloud configuration to `~/.config/opencode/opencode.json`:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": {
"baseURL": "http://localhost:11434/v1"
},
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud"
}
}
}
}
}
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
2. Update `~/.config/opencode/opencode.json` to point to ollama.com:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama Cloud",
"options": {
"baseURL": "https://ollama.com/v1"
},
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud"
}
}
}
}
}
```
Run `opencode` in a new terminal to load the new settings.
<Note>`ollama launch opencode` passes its configuration to OpenCode inline via the `OPENCODE_CONFIG_CONTENT` environment variable. OpenCode deep-merges its config sources on startup, so anything you declare in `~/.config/opencode/opencode.json` is still respected and available inside OpenCode. Models declared only in `opencode.json` won't appear in `ollama launch`'s model-selection menu.</Note>

go.mod
View File

@@ -106,5 +106,5 @@ require (
golang.org/x/term v0.36.0
golang.org/x/text v0.30.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
gopkg.in/yaml.v3 v3.0.1
)

View File

@@ -1,6 +1,6 @@
package common
// #cgo CXXFLAGS: -std=c++17
// #cgo CXXFLAGS: -std=c++17 -Wno-deprecated-declarations
// #cgo CPPFLAGS: -I${SRCDIR}/../include -I${SRCDIR}/../vendor
// #cgo CPPFLAGS: -I${SRCDIR}/../../../ml/backend/ggml/ggml/include
import "C"

View File

@@ -6,11 +6,11 @@ Subject: [PATCH] interleave multi rope
since ollama doesn't use mrope for anything else, change it to mean the
interleaved version used for qwen3vl
---
ggml/src/ggml-cpu/ops.cpp | 8 ++++----
ggml/src/ggml-cuda/rope.cu | 8 ++++----
ggml/src/ggml-metal/ggml-metal.metal | 8 ++++----
ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl | 8 ++++----
4 files changed, 16 insertions(+), 16 deletions(-)
ggml/src/ggml-cpu/ops.cpp | 8 ++++----
ggml/src/ggml-cuda/rope.cu | 8 ++++----
ggml/src/ggml-metal/ggml-metal.metal | 10 +++++-----
ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl | 8 ++++----
4 files changed, 17 insertions(+), 17 deletions(-)
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 7d1733adb..f4aae5332 100644
@@ -59,12 +59,15 @@ index 88ed79111..71ca60214 100644
} else {
if (sector < sections.v[0]) {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 236838e9e..c98d269d1 100644
index 236838e9e..18b8bb1b1 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -4242,14 +4242,14 @@ kernel void kernel_rope_multi(
@@ -4240,16 +4240,16 @@ kernel void kernel_rope_multi(
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
const int sector = ic % sect_dims;
float theta_base;
- float theta_base;
+ float theta_base = 0.0;
if (FC_rope_is_imrope) {
- if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
+ if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) { // h

View File

@@ -296,7 +296,7 @@ index e99c1763f..80864f303 100644
const size_t smem = FATTN_SMEM(nsg);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c98d269d1..d33c16079 100644
index 18b8bb1b1..114767785 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(

View File

@@ -204,7 +204,7 @@ index 902b54452..a475183d3 100644
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index d33c16079..c37447a10 100644
index 114767785..876a9eecc 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(

View File

@@ -24,7 +24,7 @@ index 4ac135603..ac5ad53db 100644
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
//switch (op->src[0]->type) {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c37447a10..4f338aa13 100644
index 876a9eecc..b14a0000c 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_

View File

@@ -342,7 +342,7 @@ index 4e5acfbe5..11457f2b1 100644
return false;
}
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 4f338aa13..8be0c1f0c 100644
index b14a0000c..398c80717 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -6276,6 +6276,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"os"
"time"
)
type Layer struct {
@@ -60,6 +61,9 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
return Layer{}, err
}
}
if err := touchLayer(blob); err != nil {
return Layer{}, err
}
return Layer{
MediaType: mediatype,
@@ -83,6 +87,9 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
if err != nil {
return Layer{}, err
}
+if err := touchLayer(blob); err != nil {
+return Layer{}, err
+}
return Layer{
MediaType: mediatype,
@@ -93,6 +100,11 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
}, nil
}
+func touchLayer(path string) error {
+now := time.Now()
+return os.Chtimes(path, now, now)
+}
func (l *Layer) Open() (io.ReadSeekCloser, error) {
if l.Digest == "" {
return nil, errors.New("opening layer with empty digest")


@@ -7122,7 +7122,7 @@ kernel void kernel_rope_multi(
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
const int sector = ic % sect_dims;
-float theta_base;
+float theta_base = 0.0;
if (FC_rope_is_imrope) {
if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) { // h
theta_base = (float) pos[i2 + args.ne02 * 1];


@@ -4300,7 +4300,7 @@ kernel void kernel_rope_multi(
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
const int sector = ic % sect_dims;
-float theta_base;
+float theta_base = 0.0;
if (FC_rope_is_imrope) {
if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) { // h
theta_base = (float) pos[i2 + args.ne02 * 1];


@@ -12,7 +12,8 @@ import (
// <|turn>/<turn|> markers, <|"|> string delimiters, and <|tool>/
// <|tool_call>/<|tool_response> tags for function calling.
type Gemma4Renderer struct {
-useImgTags bool
+useImgTags bool
+emptyBlockOnNothink bool
}
const (
@@ -124,7 +125,7 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
// Generation prompt.
if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
sb.WriteString("<|turn>model\n")
-if !hasThink {
+if r.emptyBlockOnNothink && !hasThink {
sb.WriteString("<|channel>thought\n<channel|>")
}
}


@@ -1,14 +1,18 @@
package renderers
-// TestGemma4RendererMatchesReference verifies our renderer matches the HF
-// Jinja2 chat template exactly.
+// TestGemma4RendererMatchesReference verifies our renderer matches the checked-in
+// Gemma 4 reference template.
//
-// To regenerate expected values, save gemma4Jinja2Template (below) to
-// gemma4_chat_template.jinja2 and run:
+// Current upstream Gemma 4 chat templates differ by model size. The checked-in
+// reference cases below use the small (e2b/e4b-style) baseline, with large
+// (26b/31b-style) checks covered separately in this file.
+//
+// To regenerate expected values, save the E2B template to
+// gemma4_e2b_chat_template.jinja2 and run:
//
// python3 -c "
// from jinja2 import Environment; import json
-// tmpl = Environment().from_string(open('gemma4_chat_template.jinja2').read())
+// tmpl = Environment().from_string(open('gemma4_e2b_chat_template.jinja2').read())
// msgs = [{'role':'user','content':'Hello'}]
// print(repr(tmpl.render(messages=msgs, bos_token='<bos>', add_generation_prompt=True)))
// "
@@ -26,8 +30,13 @@ import (
"github.com/stretchr/testify/assert"
)
-// The full Jinja2 template is committed as testdata/gemma4_chat_template.jinja2.
-// Run with VERIFY_JINJA2=1 to verify expected values against the template using uv + Python.
+const (
+gemma4E2BTemplate = "testdata/gemma4_e2b_chat_template.jinja2"
+gemma431BTemplate = "testdata/gemma4_31b_chat_template.jinja2"
+)
+// The upstream Gemma 4 chat templates are committed by size under testdata/.
+// Run with VERIFY_JINJA2=1 to verify expected values against the E2B template using uv + Python.
func bashRefTool() []api.Tool {
return []api.Tool{{
@@ -665,7 +674,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
{
name: "user_only",
messages: []api.Message{{Role: "user", Content: "Hello"}},
-expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
},
{
name: "system_user",
@@ -673,7 +682,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hi"},
},
-expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
name: "developer_user",
@@ -681,13 +690,13 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
{Role: "developer", Content: "You are helpful."},
{Role: "user", Content: "Hi"},
},
-expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
name: "tools_no_system",
messages: []api.Message{{Role: "user", Content: "Hi"}},
tools: bashRefTool(),
-expected: "<bos><|turn>system\n" + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+expected: "<bos><|turn>system\n" + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
name: "system_tools",
@@ -696,7 +705,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
{Role: "user", Content: "Hi"},
},
tools: bashRefTool(),
-expected: "<bos><|turn>system\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+expected: "<bos><|turn>system\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
name: "thinking_no_system",
@@ -730,6 +739,12 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
think: thinkTrue(),
expected: "<bos><|turn>system\n<|think|>\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
},
+{
+name: "thinking_explicitly_disabled",
+messages: []api.Message{{Role: "user", Content: "Hi"}},
+think: thinkFalse(),
+expected: "<bos><|turn>user\nHi<turn|>\n<|turn>model\n",
+},
// === Message loop paths ===
{
@@ -744,7 +759,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
"<|turn>user\nHi<turn|>\n" +
"<|turn>model\nHello!<turn|>\n" +
"<|turn>user\nMore<turn|>\n" +
-"<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>model\n",
},
{
// Tool call with structured args → tool response as separate <|turn>tool turn
@@ -806,7 +821,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
"<|tool_response>response:bash{value:" + q + "file1.txt\nfile2.txt" + q + "}<tool_response|>" +
"Here are the files.<turn|>\n" +
"<|turn>user\nRead file1.txt<turn|>\n" +
-"<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>model\n",
},
{
// Multiple tool calls + multiple tool responses
@@ -841,7 +856,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
expected: "<bos><|turn>user\nWhat is 2+2?<turn|>\n" +
"<|turn>model\n4<turn|>\n" +
"<|turn>user\nAnd 3+3?<turn|>\n" +
-"<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>model\n",
},
// === Additional edge cases ported from original tests ===
{
@@ -899,17 +914,17 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
messages: []api.Message{{Role: "user", Content: "Test"}},
tools: modeTool(),
expected: "<bos><|turn>system\n" + modeDeclRef + "<turn|>\n" +
-"<|turn>user\nTest<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nTest<turn|>\n<|turn>model\n",
},
{
name: "unicode_content",
messages: []api.Message{{Role: "user", Content: "こんにちは"}},
-expected: "<bos><|turn>user\nこんにちは<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+expected: "<bos><|turn>user\nこんにちは<turn|>\n<|turn>model\n",
},
{
name: "newlines_in_content",
messages: []api.Message{{Role: "user", Content: "Line 1\nLine 2\nLine 3"}},
-expected: "<bos><|turn>user\nLine 1\nLine 2\nLine 3<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+expected: "<bos><|turn>user\nLine 1\nLine 2\nLine 3<turn|>\n<|turn>model\n",
},
{
// Tool response (raw JSON) followed by user message
@@ -928,7 +943,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
"<|turn>model\n<|tool_call>call:get_weather{city:" + q + "Tokyo" + q + "}<tool_call|>" +
"<|tool_response>response:get_weather{value:" + q + `{"temperature": 15, "weather": "sunny"}` + q + "}<tool_response|>" +
"<|turn>user\nThanks!<turn|>\n" +
-"<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>model\n",
},
// === Ordering and whitespace edge cases ===
{
@@ -951,7 +966,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
// User content with whitespace is trimmed
name: "user_content_trimmed",
messages: []api.Message{{Role: "user", Content: " hello "}},
-expected: "<bos><|turn>user\nhello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+expected: "<bos><|turn>user\nhello<turn|>\n<|turn>model\n",
},
{
// Empty tool call arguments
@@ -975,7 +990,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
messages: []api.Message{{Role: "user", Content: "Create"}},
tools: nestedTool(),
expected: "<bos><|turn>system\n" + nestedDeclRef + "<turn|>\n" +
-"<|turn>user\nCreate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nCreate<turn|>\n<|turn>model\n",
},
{
// Array type in tool declaration
@@ -983,7 +998,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
messages: []api.Message{{Role: "user", Content: "Batch"}},
tools: arrayTool(),
expected: "<bos><|turn>system\n" + arrayDeclRef + "<turn|>\n" +
-"<|turn>user\nBatch<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nBatch<turn|>\n<|turn>model\n",
},
{
// Top-level typed union follows the template's odd stringified-list form.
@@ -995,8 +1010,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
<|turn>user
Hi<turn|>
<|turn>model
-<|channel>thought
-<channel|>`,
+`,
},
{
// Assistant whitespace is trimmed (strip_thinking includes | trim)
@@ -1009,7 +1023,7 @@ Hi<turn|>
expected: "<bos><|turn>user\nHi<turn|>\n" +
"<|turn>model\nspaced<turn|>\n" +
"<|turn>user\nMore<turn|>\n" +
-"<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>model\n",
},
{
// Three sequential tool responses
@@ -1064,7 +1078,7 @@ Hi<turn|>
expected: "<bos><|turn>user\nHi<turn|>\n" +
"<|turn>model\nMiddleDone<turn|>\n" +
"<|turn>user\nMore<turn|>\n" +
-"<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>model\n",
},
{
// Property with no description — just type
@@ -1072,7 +1086,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Count"}},
tools: countTool(),
expected: "<bos><|turn>system\n" + countDeclRef + "<turn|>\n" +
-"<|turn>user\nCount<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nCount<turn|>\n<|turn>model\n",
},
{
// System message with leading/trailing whitespace is trimmed
@@ -1082,7 +1096,7 @@ Hi<turn|>
{Role: "user", Content: "Hi"},
},
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n" +
-"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
// Deeply nested map in tool call arguments (3 levels)
@@ -1144,7 +1158,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Set"}},
tools: enumNoDescTool(),
expected: "<bos><|turn>system\n" + enumNoDescDeclRef + "<turn|>\n" +
-"<|turn>user\nSet<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nSet<turn|>\n<|turn>model\n",
},
{
// System message that is only whitespace (trims to empty)
@@ -1154,7 +1168,7 @@ Hi<turn|>
{Role: "user", Content: "Hi"},
},
expected: "<bos><|turn>system\n<turn|>\n" +
-"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
// Empty assistant content (empty string, not nil)
@@ -1167,7 +1181,7 @@ Hi<turn|>
expected: "<bos><|turn>user\nHi<turn|>\n" +
"<|turn>model\n<turn|>\n" +
"<|turn>user\nMore<turn|>\n" +
-"<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>model\n",
},
{
// Map argument with string keys (keys NOT escaped with <|"|>)
@@ -1193,7 +1207,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Search"}},
tools: searchTool(),
expected: "<bos><|turn>system\n" + searchDeclRef + "<turn|>\n" +
-"<|turn>user\nSearch<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nSearch<turn|>\n<|turn>model\n",
},
// === Round 3 coverage gaps ===
@@ -1221,7 +1235,7 @@ Hi<turn|>
{Role: "user", Content: "Hi"},
},
expected: "<bos><|turn>system\n<turn|>\n" +
-"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
// Nested OBJECT property with required field
@@ -1229,7 +1243,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Create"}},
tools: nestedRequiredTool(),
expected: "<bos><|turn>system\n" + nestedRequiredDeclRef + "<turn|>\n" +
-"<|turn>user\nCreate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nCreate<turn|>\n<|turn>model\n",
},
{
// Non-integer float in tool call argument
@@ -1256,7 +1270,7 @@ Hi<turn|>
},
expected: "<bos><|turn>user\nHi<turn|>\n" +
"<|turn>model\nResult<turn|>\n" +
-"<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>model\n",
},
{
// Tool content with newlines and leading/trailing whitespace trimmed
@@ -1280,7 +1294,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Raw"}},
tools: rawTool(),
expected: "<bos><|turn>system\n" + rawDeclRef + "<turn|>\n" +
-"<|turn>user\nRaw<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nRaw<turn|>\n<|turn>model\n",
},
{
// Multiple required fields at top level
@@ -1288,7 +1302,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Move"}},
tools: moveTool(),
expected: "<bos><|turn>system\n" + moveDeclRef + "<turn|>\n" +
-"<|turn>user\nMove<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nMove<turn|>\n<|turn>model\n",
},
{
// Assistant content that is ONLY thinking (strips to empty)
@@ -1301,7 +1315,7 @@ Hi<turn|>
expected: "<bos><|turn>user\nHi<turn|>\n" +
"<|turn>model\n<turn|>\n" +
"<|turn>user\nMore<turn|>\n" +
-"<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>model\n",
},
// === Round 4: final coverage gaps ===
@@ -1334,7 +1348,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Tag"}},
tools: arrayNoItemsTool(),
expected: "<bos><|turn>system\n" + arrayNoItemsDeclRef + "<turn|>\n" +
-"<|turn>user\nTag<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nTag<turn|>\n<|turn>model\n",
},
{
// OBJECT property without description but with nested properties
@@ -1342,7 +1356,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Update"}},
tools: objectNoDescTool(),
expected: "<bos><|turn>system\n" + objectNoDescDeclRef + "<turn|>\n" +
-"<|turn>user\nUpdate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>user\nUpdate<turn|>\n<|turn>model\n",
},
// === Round 5: coding agent patterns ===
@@ -1372,7 +1386,7 @@ Hi<turn|>
"<|tool_response>response:bash{value:" + q + q + "}<tool_response|>" +
"Done.<turn|>\n" +
"<|turn>user\nThanks<turn|>\n" +
-"<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>model\n",
},
{
// Tool call with thinking that strips to real remaining content
@@ -1392,7 +1406,7 @@ Hi<turn|>
"<|tool_response>response:bash{value:" + q + "main.go\ngo.mod" + q + "}<tool_response|>" +
"Let me list the files.<turn|>\n" +
"<|turn>user\nOK<turn|>\n" +
-"<|turn>model\n<|channel>thought\n<channel|>",
+"<|turn>model\n",
},
{
// Argument value containing newlines (multi-line script)
@@ -1460,6 +1474,47 @@ Hi<turn|>
}
}
+func TestGemma4RendererVariantsMatchExpectedGenerationPrompt(t *testing.T) {
+messages := []api.Message{{Role: "user", Content: "Hello"}}
+tests := []struct {
+name string
+rendererName string
+expected string
+}{
+{
+name: "legacy_alias",
+rendererName: "gemma4",
+expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
+},
+{
+name: "small",
+rendererName: "gemma4-small",
+expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
+},
+{
+name: "large",
+rendererName: "gemma4-large",
+expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+},
+}
+for _, tt := range tests {
+t.Run(tt.name, func(t *testing.T) {
+got, err := RenderWithRenderer(tt.rendererName, messages, nil, nil)
+assert.NoError(t, err)
+assert.Equal(t, tt.expected, got)
+})
+}
+}
+func TestGemma4LargeRendererOmitsEmptyThoughtBlockWhenThinkingEnabled(t *testing.T) {
+got, err := RenderWithRenderer("gemma4-large", []api.Message{{Role: "user", Content: "Hello"}}, nil, thinkTrue())
+assert.NoError(t, err)
+assert.Equal(t, "<bos><|turn>system\n<|think|>\n<turn|>\n<|turn>user\nHello<turn|>\n<|turn>model\n", got)
+assert.NotContains(t, got, "<|channel>thought\n<channel|>")
+}
func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
if os.Getenv("VERIFY_JINJA2") == "" {
t.Skip("set VERIFY_JINJA2=1 to run expanded Jinja2 parity checks")
@@ -1602,15 +1657,35 @@ func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
},
}
-for _, tt := range tests {
-t.Run(tt.name, func(t *testing.T) {
-renderer := &Gemma4Renderer{useImgTags: RenderImgTags}
-got, err := renderer.Render(tt.messages, tt.tools, tt.think)
-assert.NoError(t, err)
+variants := []struct {
+name string
+renderer *Gemma4Renderer
+templateRel string
+}{
+{
+name: "small",
+renderer: &Gemma4Renderer{useImgTags: RenderImgTags},
+templateRel: gemma4E2BTemplate,
+},
+{
+name: "large",
+renderer: &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true},
+templateRel: gemma431BTemplate,
+},
+}
-jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, tt.think)
-assert.Equal(t, jinja2Output, got,
-"renderer output doesn't match Jinja2 template output")
+for _, variant := range variants {
+t.Run(variant.name, func(t *testing.T) {
+for _, tt := range tests {
+t.Run(tt.name, func(t *testing.T) {
+got, err := variant.renderer.Render(tt.messages, tt.tools, tt.think)
+assert.NoError(t, err)
+jinja2Output := renderWithJinja2Template(t, variant.templateRel, tt.messages, tt.tools, tt.think)
+assert.Equal(t, jinja2Output, got,
+"renderer output doesn't match Jinja2 template output")
+})
+}
+})
+}
}
@@ -1720,12 +1795,35 @@ func TestGemma4RendererToolResponseWithoutNameOrIDUsesUnknown(t *testing.T) {
assert.NotContains(t, got, `response:read{value:<|"|>payload<|"|>}`)
}
+func TestGemma4SizeTemplateFixturesDifferAtGenerationPrompt(t *testing.T) {
+e2b, err := os.ReadFile(gemma4E2BTemplate)
+if err != nil {
+t.Fatalf("failed to read %s: %v", gemma4E2BTemplate, err)
+}
+thirtyOneB, err := os.ReadFile(gemma431BTemplate)
+if err != nil {
+t.Fatalf("failed to read %s: %v", gemma431BTemplate, err)
+}
+assert.Contains(t, string(e2b), "{{- '<|turn>model\\n' -}}")
+assert.NotContains(t, string(e2b), "{{- '<|channel>thought\\n<channel|>' -}}")
+assert.Contains(t, string(thirtyOneB), "{{- '<|turn>model\\n' -}}")
+assert.Contains(t, string(thirtyOneB), "{{- '<|channel>thought\\n<channel|>' -}}")
+}
// renderWithJinja2 shells out to uv + Python to render messages through the
-// Jinja2 chat template. Returns the rendered string.
+// E2B Jinja2 chat template. Returns the rendered string.
func renderWithJinja2(t *testing.T, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string {
+return renderWithJinja2Template(t, gemma4E2BTemplate, messages, tools, think)
+}
+// renderWithJinja2Template shells out to uv + Python to render messages through
+// the named Jinja2 chat template. Returns the rendered string.
+func renderWithJinja2Template(t *testing.T, templateRelPath string, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string {
t.Helper()
-templatePath, err := filepath.Abs("testdata/gemma4_chat_template.jinja2")
+templatePath, err := filepath.Abs(templateRelPath)
if err != nil {
t.Fatalf("failed to get template path: %v", err)
}
@@ -1814,3 +1912,7 @@ print(tmpl.render(**kwargs), end="")
func thinkTrue() *api.ThinkValue {
return &api.ThinkValue{Value: true}
}
+func thinkFalse() *api.ThinkValue {
+return &api.ThinkValue{Value: false}
+}


@@ -81,8 +81,10 @@ func rendererForName(name string) Renderer {
return renderer
case "nemotron-3-nano":
return &Nemotron3NanoRenderer{}
-case "gemma4":
+case "gemma4", "gemma4-small":
return &Gemma4Renderer{useImgTags: RenderImgTags}
+case "gemma4-large":
+return &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true}
case "functiongemma":
return &FunctionGemmaRenderer{}
case "glm-4.7":


@@ -0,0 +1,344 @@
{%- macro format_parameters(properties, required) -%}
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in properties | dictsort -%}
{%- set add_comma = false -%}
{%- if key not in standard_keys -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{{ key }}:{
{%- if value['description'] -%}
description:<|"|>{{ value['description'] }}<|"|>
{%- set add_comma = true -%}
{%- endif -%}
{%- if value['type'] | upper == 'STRING' -%}
{%- if value['enum'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
enum:{{ format_argument(value['enum']) }}
{%- endif -%}
{%- elif value['type'] | upper == 'ARRAY' -%}
{%- if value['items'] is mapping and value['items'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
items:{
{%- set ns_items = namespace(found_first=false) -%}
{%- for item_key, item_value in value['items'] | dictsort -%}
{%- if item_value is not none -%}
{%- if ns_items.found_first %},{% endif -%}
{%- set ns_items.found_first = true -%}
{%- if item_key == 'properties' -%}
properties:{
{%- if item_value is mapping -%}
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
{%- endif -%}
}
{%- elif item_key == 'required' -%}
required:[
{%- for req_item in item_value -%}
<|"|>{{- req_item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- elif item_key == 'type' -%}
{%- if item_value is string -%}
type:{{ format_argument(item_value | upper) }}
{%- else -%}
type:{{ format_argument(item_value | map('upper') | list) }}
{%- endif -%}
{%- else -%}
{{ item_key }}:{{ format_argument(item_value) }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
}
{%- endif -%}
{%- endif -%}
{%- if value['nullable'] %}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
nullable:true
{%- endif -%}
{%- if value['type'] | upper == 'OBJECT' -%}
{%- if value['properties'] is defined and value['properties'] is mapping -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
properties:{
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
}
{%- elif value is mapping -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
properties:{
{{- format_parameters(value, value['required'] | default([])) -}}
}
{%- endif -%}
{%- if value['required'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
required:[
{%- for item in value['required'] | default([]) -%}
<|"|>{{- item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- endif -%}
{%- endif -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
type:<|"|>{{ value['type'] | upper }}<|"|>}
{%- endif -%}
{%- endfor -%}
{%- endmacro -%}
{%- macro format_function_declaration(tool_data) -%}
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
{%- set params = tool_data['function']['parameters'] -%}
{%- if params -%}
,parameters:{
{%- if params['properties'] -%}
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
{%- endif -%}
{%- if params['required'] -%}
required:[
{%- for item in params['required'] -%}
<|"|>{{- item -}}<|"|>
{{- ',' if not loop.last -}}
{%- endfor -%}
],
{%- endif -%}
{%- if params['type'] -%}
type:<|"|>{{- params['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
{%- if 'response' in tool_data['function'] -%}
{%- set response_declaration = tool_data['function']['response'] -%}
,response:{
{%- if response_declaration['description'] -%}
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
{%- endif -%}
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
}
{%- endmacro -%}
{%- macro format_argument(argument, escape_keys=True) -%}
{%- if argument is string -%}
{{- '<|"|>' + argument + '<|"|>' -}}
{%- elif argument is boolean -%}
{{- 'true' if argument else 'false' -}}
{%- elif argument is mapping -%}
{{- '{' -}}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in argument | dictsort -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{%- if escape_keys -%}
{{- '<|"|>' + key + '<|"|>' -}}
{%- else -%}
{{- key -}}
{%- endif -%}
:{{- format_argument(value, escape_keys=escape_keys) -}}
{%- endfor -%}
{{- '}' -}}
{%- elif argument is sequence -%}
{{- '[' -}}
{%- for item in argument -%}
{{- format_argument(item, escape_keys=escape_keys) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- ']' -}}
{%- else -%}
{{- argument -}}
{%- endif -%}
{%- endmacro -%}
{%- macro strip_thinking(text) -%}
{%- set ns = namespace(result='') -%}
{%- for part in text.split('<channel|>') -%}
{%- if '<|channel>' in part -%}
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
{%- else -%}
{%- set ns.result = ns.result + part -%}
{%- endif -%}
{%- endfor -%}
{{- ns.result | trim -}}
{%- endmacro -%}
{%- macro format_tool_response_block(tool_name, response) -%}
{{- '<|tool_response>' -}}
{%- if response is mapping -%}
{{- 'response:' + tool_name + '{' -}}
{%- for key, value in response | dictsort -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- '}' -}}
{%- else -%}
{{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}}
{%- endif -%}
{{- '<tool_response|>' -}}
{%- endmacro -%}
{%- set ns = namespace(prev_message_type=None) -%}
{%- set loop_messages = messages -%}
{{- bos_token -}}
{#- Handle System/Tool Definitions Block -#}
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
{{- '<|turn>system\n' -}}
{#- Inject Thinking token at the very top of the FIRST system turn -#}
{%- if enable_thinking is defined and enable_thinking -%}
{{- '<|think|>\n' -}}
{%- set ns.prev_message_type = 'think' -%}
{%- endif -%}
{%- if messages[0]['role'] in ['system', 'developer'] -%}
{{- messages[0]['content'] | trim -}}
{%- set loop_messages = messages[1:] -%}
{%- endif -%}
{%- if tools -%}
{%- for tool in tools %}
{{- '<|tool>' -}}
{{- format_function_declaration(tool) | trim -}}
{{- '<tool|>' -}}
{%- endfor %}
{%- set ns.prev_message_type = 'tool' -%}
{%- endif -%}
{{- '<turn|>\n' -}}
{%- endif %}
{#- Pre-scan: find last user message index for reasoning guard -#}
{%- set ns_turn = namespace(last_user_idx=-1) -%}
{%- for i in range(loop_messages | length) -%}
{%- if loop_messages[i]['role'] == 'user' -%}
{%- set ns_turn.last_user_idx = i -%}
{%- endif -%}
{%- endfor -%}
{#- Loop through messages -#}
{%- for message in loop_messages -%}
{%- if message['role'] != 'tool' -%}
{%- set ns.prev_message_type = None -%}
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
{#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#}
{%- set prev_nt = namespace(role=None, found=false) -%}
{%- if loop.index0 > 0 -%}
{%- for j in range(loop.index0 - 1, -1, -1) -%}
{%- if not prev_nt.found -%}
{%- if loop_messages[j]['role'] != 'tool' -%}
{%- set prev_nt.role = loop_messages[j]['role'] -%}
{%- set prev_nt.found = true -%}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%}
{%- if not continue_same_model_turn -%}
{{- '<|turn>' + role + '\n' }}
{%- endif -%}
{#- Render reasoning/reasoning_content as thinking channel -#}
{%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%}
{%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%}
{{- '<|channel>thought\n' + thinking_text + '\n<channel|>' -}}
{%- endif -%}
{%- if message['tool_calls'] -%}
{%- for tool_call in message['tool_calls'] -%}
{%- set function = tool_call['function'] -%}
{{- '<|tool_call>call:' + function['name'] + '{' -}}
{%- if function['arguments'] is mapping -%}
{%- set ns_args = namespace(found_first=false) -%}
{%- for key, value in function['arguments'] | dictsort -%}
{%- if ns_args.found_first %},{% endif -%}
{%- set ns_args.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{%- elif function['arguments'] is string -%}
{{- function['arguments'] -}}
{%- endif -%}
{{- '}<tool_call|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_call' -%}
{%- endif -%}
{%- set ns_tr_out = namespace(flag=false) -%}
{%- if message.get('tool_responses') -%}
{#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#}
{%- for tool_response in message['tool_responses'] -%}
{{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}}
{%- set ns_tr_out.flag = true -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endfor -%}
{%- elif message.get('tool_calls') -%}
{#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#}
{%- set ns_tool_scan = namespace(stopped=false) -%}
{%- for k in range(loop.index0 + 1, loop_messages | length) -%}
{%- if ns_tool_scan.stopped -%}
{%- elif loop_messages[k]['role'] != 'tool' -%}
{%- set ns_tool_scan.stopped = true -%}
{%- else -%}
{%- set follow = loop_messages[k] -%}
{#- Resolve tool_call_id to function name -#}
{%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%}
{%- for tc in message['tool_calls'] -%}
{%- if tc.get('id') == follow.get('tool_call_id') -%}
{%- set ns_tname.name = tc['function']['name'] -%}
{%- endif -%}
{%- endfor -%}
{#- Handle content as string or content-parts array -#}
{%- set tool_body = follow.get('content') -%}
{%- if tool_body is string -%}
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
{%- elif tool_body is sequence and tool_body is not string -%}
{%- set ns_txt = namespace(s='') -%}
{%- for part in tool_body -%}
{%- if part.get('type') == 'text' -%}
{%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%}
{%- endif -%}
{%- endfor -%}
{{- format_tool_response_block(ns_tname.name, ns_txt.s) -}}
{%- else -%}
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
{%- endif -%}
{%- set ns_tr_out.flag = true -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if message['content'] is string -%}
{%- if role == 'model' -%}
{{- strip_thinking(message['content']) -}}
{%- else -%}
{{- message['content'] | trim -}}
{%- endif -%}
{%- elif message['content'] is sequence -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'text' -%}
{%- if role == 'model' -%}
{{- strip_thinking(item['text']) -}}
{%- else -%}
{{- item['text'] | trim -}}
{%- endif -%}
{%- elif item['type'] == 'image' -%}
{{- '<|image|>' -}}
{%- set ns.prev_message_type = 'image' -%}
{%- elif item['type'] == 'audio' -%}
{{- '<|audio|>' -}}
{%- set ns.prev_message_type = 'audio' -%}
{%- elif item['type'] == 'video' -%}
{{- '<|video|>' -}}
{%- set ns.prev_message_type = 'video' -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%}
{{- '<|tool_response>' -}}
{%- elif not (ns_tr_out.flag and not message.get('content')) -%}
{{- '<turn|>\n' -}}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%}
{{- '<|turn>model\n' -}}
{%- endif -%}
{%- endif -%}


@@ -523,7 +523,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
arch := layer.GGML.KV().Architecture()
switch arch {
case "gemma4":
-config.Renderer = cmp.Or(config.Renderer, "gemma4")
+config.Renderer = cmp.Or(config.Renderer, gemma4RendererLegacy)
config.Parser = cmp.Or(config.Parser, "gemma4")
if _, ok := r.Parameters["stop"]; !ok {
if r.Parameters == nil {

server/gemma4_test.go

@@ -0,0 +1,78 @@
package server
import "testing"
func TestResolveGemma4Renderer(t *testing.T) {
tests := []struct {
name string
model *Model
want string
}{
{
name: "nil model falls back to legacy alias",
model: nil,
want: gemma4RendererLegacy,
},
{
name: "explicit small passes through",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererSmall),
},
want: gemma4RendererSmall,
},
{
name: "explicit large passes through",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererLarge),
},
want: gemma4RendererLarge,
},
{
name: "legacy e4b tag resolves small",
model: &Model{
Name: "gemma4:e4b",
ShortName: "gemma4:e4b",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererSmall,
},
{
name: "legacy 31b tag resolves large",
model: &Model{
Name: "gemma4:31b-cloud",
ShortName: "gemma4:31b-cloud",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererLarge,
},
{
name: "legacy model type resolves small",
model: &Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "4.3B"),
},
want: gemma4RendererSmall,
},
{
name: "legacy model type resolves large",
model: &Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
},
want: gemma4RendererLarge,
},
{
name: "legacy unknown defaults small",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererSmall,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveGemma4Renderer(tt.model); got != tt.want {
t.Fatalf("resolveGemma4Renderer() = %q, want %q", got, tt.want)
}
})
}
}


@@ -19,6 +19,7 @@ import (
"slices"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
@@ -33,6 +34,10 @@ import (
"github.com/ollama/ollama/x/imagegen/transfer"
)
// Blobs newer than this may belong to another process that has not written its
// manifest yet. They become eligible for the normal mark-and-sweep pass later.
const layerPruneGracePeriod = time.Hour
var (
errCapabilities = errors.New("does not support")
errCapabilityCompletion = errors.New("completion")
@@ -156,7 +161,7 @@ func (m *Model) Capabilities() []model.Capability {
// Temporary workaround — suppress vision/audio for gemma4 MLX models
// until multimodal runtime pipeline lands. Remove when imageproc.go is wired up.
-if m.Config.ModelFormat == "safetensors" && m.Config.Renderer == "gemma4" {
+if m.Config.ModelFormat == "safetensors" && isGemma4Renderer(m.Config.Renderer) {
capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool {
return c == model.CapabilityVision || c == "audio"
})
@@ -478,10 +483,23 @@ func PruneLayers() error {
}
for _, blob := range blobs {
if blob.IsDir() {
continue
}
info, err := blob.Info()
if err != nil {
slog.Error("couldn't stat blob", "blob", blob.Name(), "error", err)
continue
}
if time.Since(info.ModTime()) < layerPruneGracePeriod {
continue
}
name := blob.Name()
name = strings.ReplaceAll(name, "-", ":")
-_, err := manifest.BlobsPath(name)
+_, err = manifest.BlobsPath(name)
if err != nil {
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads)


@@ -5,14 +5,58 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
func TestPruneLayersSkipsRecentOrphans(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
recentDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000001"
oldDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000002"
for _, digest := range []string{recentDigest, oldDigest} {
p, err := manifest.BlobsPath(digest)
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(p, nil, 0o644); err != nil {
t.Fatal(err)
}
}
oldPath, err := manifest.BlobsPath(oldDigest)
if err != nil {
t.Fatal(err)
}
oldTime := time.Now().Add(-layerPruneGracePeriod - time.Hour)
if err := os.Chtimes(oldPath, oldTime, oldTime); err != nil {
t.Fatal(err)
}
if err := PruneLayers(); err != nil {
t.Fatal(err)
}
recentPath, err := manifest.BlobsPath(recentDigest)
if err != nil {
t.Fatal(err)
}
if _, err := os.Stat(recentPath); err != nil {
t.Fatalf("recent orphan was pruned: %v", err)
}
if _, err := os.Stat(oldPath); !os.IsNotExist(err) {
t.Fatalf("old orphan still exists: %v", err)
}
}
func TestModelCapabilities(t *testing.T) {
// Create completion model (llama architecture without vision)
completionModelPath, _ := createBinFile(t, ggml.KV{
@@ -118,6 +162,39 @@ func TestModelCapabilities(t *testing.T) {
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
{
name: "gemma4 small safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererSmall,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
{
name: "gemma4 large safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererLarge,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
{
name: "legacy gemma4 safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererLegacy,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
}
// compare two slices of model.Capability regardless of order


@@ -115,7 +115,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
if m.Config.Renderer != "" {
-rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
+rendererName := resolveRendererName(m)
+rendered, err := renderers.RenderWithRenderer(rendererName, msgs, tools, think)
if err != nil {
return "", err
}


@@ -13,6 +13,14 @@ import (
"github.com/ollama/ollama/types/model"
)
func testConfigWithRenderer(renderer string) model.ConfigV2 {
return model.ConfigV2{Renderer: renderer}
}
func testConfigWithRendererAndType(renderer, modelType string) model.ConfigV2 {
return model.ConfigV2{Renderer: renderer, ModelType: modelType}
}
func TestChatPrompt(t *testing.T) {
type expect struct {
prompt string
@@ -397,3 +405,43 @@ func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
}
}
func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) {
msgs := []api.Message{{Role: "user", Content: "Hello"}}
tests := []struct {
name string
model Model
want string
}{
{
name: "small from name",
model: Model{
Name: "gemma4:e4b",
ShortName: "gemma4:e4b",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
},
{
name: "large from model type",
model: Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
},
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := renderPrompt(&tt.model, msgs, nil, nil)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Fatalf("rendered prompt mismatch (-got +want):\n%s", diff)
}
})
}
}


@@ -0,0 +1,110 @@
package server
import (
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
const (
gemma4RendererLegacy = "gemma4"
gemma4RendererSmall = "gemma4-small"
gemma4RendererLarge = "gemma4-large"
// Gemma 4 small templates cover the e2b/e4b family, while 26b/31b use the
// large template. Default to the small prompt unless the model is clearly in
// the large range.
gemma4LargeMinParameterCount = 16_000_000_000
)
func resolveRendererName(m *Model) string {
if m == nil || m.Config.Renderer == "" {
return ""
}
switch m.Config.Renderer {
case gemma4RendererLegacy:
return resolveGemma4Renderer(m)
default:
return m.Config.Renderer
}
}
func resolveGemma4Renderer(m *Model) string {
if m == nil || m.Config.Renderer != gemma4RendererLegacy {
if m == nil {
return gemma4RendererLegacy
}
return m.Config.Renderer
}
if renderer, ok := gemma4RendererFromName(m.ShortName); ok {
return renderer
}
if renderer, ok := gemma4RendererFromName(m.Name); ok {
return renderer
}
if parameterCount, ok := parseHumanParameterCount(m.Config.ModelType); ok {
return gemma4RendererForParameterCount(parameterCount)
}
return gemma4RendererSmall
}
func gemma4RendererForParameterCount(parameterCount uint64) string {
if parameterCount >= gemma4LargeMinParameterCount {
return gemma4RendererLarge
}
return gemma4RendererSmall
}
func gemma4RendererFromName(name string) (string, bool) {
lower := strings.ToLower(name)
switch {
case strings.Contains(lower, "e2b"), strings.Contains(lower, "e4b"):
return gemma4RendererSmall, true
case strings.Contains(lower, "26b"), strings.Contains(lower, "31b"):
return gemma4RendererLarge, true
default:
return "", false
}
}
func parseHumanParameterCount(s string) (uint64, bool) {
if s == "" {
return 0, false
}
unit := strings.ToUpper(s[len(s)-1:])
var multiplier float64
switch unit {
case "B":
multiplier = float64(format.Billion)
case "M":
multiplier = float64(format.Million)
case "K":
multiplier = float64(format.Thousand)
default:
return 0, false
}
value, err := strconv.ParseFloat(s[:len(s)-1], 64)
if err != nil {
return 0, false
}
return uint64(value * multiplier), true
}
func isGemma4Renderer(renderer string) bool {
switch renderer {
case gemma4RendererLegacy, gemma4RendererSmall, gemma4RendererLarge:
return true
default:
return false
}
}
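The size-resolution helpers above can be exercised in isolation. The following is a minimal, self-contained Go sketch (not the repository's code): it mirrors `parseHumanParameterCount` and the 16B cutoff, with `format.Thousand`/`format.Million`/`format.Billion` replaced by plain literals, which is an assumption about their values:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// Local stand-ins for format.Thousand/Million/Billion (assumed values).
const (
	thousand float64 = 1e3
	million  float64 = 1e6
	billion  float64 = 1e9
)

// Threshold copied from the diff: models at or above 16B use the large template.
const largeMin uint64 = 16_000_000_000

// parseHumanParameterCount mirrors the helper above: "25.2B" -> ~2.52e10.
func parseHumanParameterCount(s string) (uint64, bool) {
	if s == "" {
		return 0, false
	}
	var multiplier float64
	switch strings.ToUpper(s[len(s)-1:]) {
	case "B":
		multiplier = billion
	case "M":
		multiplier = million
	case "K":
		multiplier = thousand
	default:
		return 0, false
	}
	value, err := strconv.ParseFloat(s[:len(s)-1], 64)
	if err != nil {
		return 0, false
	}
	return uint64(value * multiplier), true
}

// rendererForModelType picks the template family from a human-readable size.
func rendererForModelType(modelType string) string {
	if n, ok := parseHumanParameterCount(modelType); ok && n >= largeMin {
		return "gemma4-large"
	}
	return "gemma4-small"
}

func main() {
	fmt.Println(rendererForModelType("25.2B")) // gemma4-large
	fmt.Println(rendererForModelType("4.3B"))  // gemma4-small
}
```

Note the fallback direction matches the test table above: anything unparseable resolves small.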


@@ -928,6 +928,59 @@ func TestCreateDetectTemplate(t *testing.T) {
})
}
func TestCreateGemma4KeepsDynamicRendererAlias(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "gemma4",
"general.parameter_count": uint64(25_200_000_000),
}, nil)
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test",
Files: map[string]string{"test.gguf": digest},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
mf, err := manifest.ParseNamedManifest(model.ParseName("test"))
if err != nil {
t.Fatalf("parse manifest: %v", err)
}
if mf.Config.Digest == "" {
t.Fatalf("unexpected empty config digest for manifest")
}
configPath, err := manifest.BlobsPath(mf.Config.Digest)
if err != nil {
t.Fatalf("config blob path: %v", err)
}
cfgFile, err := os.Open(configPath)
if err != nil {
t.Fatalf("open config blob: %v", err)
}
defer cfgFile.Close()
var cfg model.ConfigV2
if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
t.Fatalf("decode config: %v", err)
}
if cfg.Renderer != gemma4RendererLegacy {
t.Fatalf("expected renderer %q, got %q", gemma4RendererLegacy, cfg.Renderer)
}
if cfg.Parser != "gemma4" {
t.Fatalf("expected parser %q, got %q", "gemma4", cfg.Parser)
}
}
func TestDetectModelTypeFromFiles(t *testing.T) {
t.Run("gguf file", func(t *testing.T) {
_, digest := createBinFile(t, nil, nil)


@@ -191,6 +191,10 @@ func inferSafetensorsCapabilities(modelDir string) []string {
capabilities = append(capabilities, "vision")
}
if supportsAudio(modelDir) {
capabilities = append(capabilities, "audio")
}
if supportsThinking(modelDir) {
capabilities = append(capabilities, "thinking")
}
@@ -496,32 +500,38 @@ func supportsThinking(modelDir string) bool {
return false
}
-// supportsVision checks if the model supports image input based on its architecture.
-// Qwen3.5 multimodal checkpoints are published as ConditionalGeneration architectures.
+// supportsVision checks if the model has a vision encoder by looking for
+// vision_config in config.json.
func supportsVision(modelDir string) bool {
-configPath := filepath.Join(modelDir, "config.json")
-data, err := os.ReadFile(configPath)
+data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
if err != nil {
return false
}
var cfg struct {
-Architectures []string `json:"architectures"`
-ModelType string `json:"model_type"`
+VisionConfig *map[string]any `json:"vision_config"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
-for _, arch := range cfg.Architectures {
-archLower := strings.ToLower(arch)
-if strings.Contains(archLower, "qwen3") && strings.Contains(archLower, "conditionalgeneration") {
-return true
-}
-}
-typeLower := strings.ToLower(cfg.ModelType)
-return strings.Contains(typeLower, "qwen3") && strings.Contains(typeLower, "conditionalgeneration")
+return cfg.VisionConfig != nil
}
+func supportsAudio(modelDir string) bool {
+data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
+if err != nil {
+return false
+}
+var cfg struct {
+AudioConfig *map[string]any `json:"audio_config"`
+}
+if err := json.Unmarshal(data, &cfg); err != nil {
+return false
+}
+return cfg.AudioConfig != nil
+}
// getParserName returns the parser name for a model based on its architecture.
@@ -550,6 +560,9 @@ func getParserName(modelDir string) string {
if strings.Contains(archLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(archLower, "gemma4") {
return "gemma4"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3"
}
@@ -564,6 +577,9 @@ func getParserName(modelDir string) string {
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "gemma4") {
return "gemma4"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3"
}
@@ -592,6 +608,9 @@ func getRendererName(modelDir string) string {
// Check architectures for known renderers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "gemma4") {
return "gemma4"
}
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7"
}
@@ -606,6 +625,9 @@ func getRendererName(modelDir string) string {
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "gemma4") {
return "gemma4"
}
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}


@@ -311,10 +311,30 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
name: "qwen3.5 multimodal model",
configJSON: `{
"architectures": ["Qwen3_5ForConditionalGeneration"],
-"model_type": "qwen3"
+"model_type": "qwen3",
+"vision_config": {"hidden_size": 1024}
}`,
want: []string{"completion", "vision", "thinking"},
},
{
name: "model with audio config",
configJSON: `{
"architectures": ["Gemma4ForConditionalGeneration"],
"model_type": "gemma4",
"vision_config": {"hidden_size": 1024},
"audio_config": {"num_mel_bins": 128}
}`,
want: []string{"completion", "vision", "audio"},
},
{
name: "model with audio but no vision",
configJSON: `{
"architectures": ["SomeAudioModel"],
"model_type": "other",
"audio_config": {"num_mel_bins": 128}
}`,
want: []string{"completion", "audio"},
},
{
name: "non-qwen conditional generation model",
configJSON: `{
@@ -339,6 +359,74 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
}
}
func TestParsePerExpertInputs(t *testing.T) {
makeInput := func(name, quantize string) create.PackedTensorInput {
return create.PackedTensorInput{Name: name, Quantize: quantize}
}
t.Run("uniform quant across projections", func(t *testing.T) {
inputs := []create.PackedTensorInput{
makeInput("layer.moe.experts.0.gate_proj.weight", "int4"),
makeInput("layer.moe.experts.1.gate_proj.weight", "int4"),
makeInput("layer.moe.experts.0.down_proj.weight", "int4"),
makeInput("layer.moe.experts.1.down_proj.weight", "int4"),
}
groups, projQ := parsePerExpertInputs("layer.moe.experts", inputs)
if groups == nil {
t.Fatal("expected non-nil groups")
}
if len(groups) != 2 {
t.Fatalf("expected 2 projection groups, got %d", len(groups))
}
if projQ["gate_proj.weight"] != "int4" {
t.Errorf("gate_proj quant = %q, want int4", projQ["gate_proj.weight"])
}
if projQ["down_proj.weight"] != "int4" {
t.Errorf("down_proj quant = %q, want int4", projQ["down_proj.weight"])
}
})
t.Run("mixed quant across projections", func(t *testing.T) {
inputs := []create.PackedTensorInput{
makeInput("layer.moe.experts.0.gate_proj.weight", "int4"),
makeInput("layer.moe.experts.1.gate_proj.weight", "int4"),
makeInput("layer.moe.experts.0.down_proj.weight", "int8"),
makeInput("layer.moe.experts.1.down_proj.weight", "int8"),
}
groups, projQ := parsePerExpertInputs("layer.moe.experts", inputs)
if groups == nil {
t.Fatal("expected non-nil groups for mixed cross-projection quant")
}
if projQ["gate_proj.weight"] != "int4" {
t.Errorf("gate_proj quant = %q, want int4", projQ["gate_proj.weight"])
}
if projQ["down_proj.weight"] != "int8" {
t.Errorf("down_proj quant = %q, want int8", projQ["down_proj.weight"])
}
})
t.Run("mixed quant within same projection rejected", func(t *testing.T) {
inputs := []create.PackedTensorInput{
makeInput("layer.moe.experts.0.down_proj.weight", "int4"),
makeInput("layer.moe.experts.1.down_proj.weight", "int8"),
}
groups, _ := parsePerExpertInputs("layer.moe.experts", inputs)
if groups != nil {
t.Fatal("expected nil for mixed quant within same projection")
}
})
t.Run("non-experts group rejected", func(t *testing.T) {
inputs := []create.PackedTensorInput{
makeInput("layer.mlp.gate_proj.weight", "int4"),
}
groups, _ := parsePerExpertInputs("layer.mlp", inputs)
if groups != nil {
t.Fatal("expected nil for non-experts group")
}
})
}
func TestQuantizeSupported(t *testing.T) {
// This just verifies the function exists and returns a boolean
// The actual value depends on build tags (mlx vs non-mlx)


@@ -97,6 +97,20 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
groupSize, bits, mode := model.QuantizationParams(quantize)
qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
// Validate quantization produced non-empty output. MLX quantize may return
// empty arrays for unsupported mode/bits combinations without raising an error.
mlx.Eval(qweight, scales)
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
name, quantize, groupSize, bits, mode)
}
if len(scales.Dims()) == 0 || scales.Dims()[0] == 0 {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty scales for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
name, quantize, groupSize, bits, mode)
}
qweight = mlx.Contiguous(qweight, false)
scales = mlx.Contiguous(scales, false)
arrays[name] = qweight
@@ -174,8 +188,8 @@ func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quanti
// Returns the blob bytes.
func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
// Check if inputs are per-expert tensors that should be stacked into 3D
-if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
-return stackAndQuantizeExpertGroup(groupName, projGroups, quantize)
+if projGroups, projQuantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
+return stackAndQuantizeExpertGroup(groupName, projGroups, projQuantize)
}
allArrays := make(map[string]*mlx.Array)
@@ -224,6 +238,17 @@ func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([
mlx.Pin(finalArrays...)
pinned = append(pinned, finalArrays...)
// Record per-tensor quant type so the model can resolve params at load time.
if input.Quantize != "" {
if groupSize, _, _ := model.QuantizationParams(input.Quantize); groupSize > 0 {
if metadata == nil {
metadata = make(map[string]string)
}
metadata[input.Name+".quant_type"] = input.Quantize
metadata[input.Name+".group_size"] = strconv.Itoa(groupSize)
}
}
if st != nil {
st.Free()
}
@@ -279,57 +304,60 @@ type expertTensorInfo struct {
}
// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
-// and returns the uniform quantization type shared by all inputs.
-// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D)
-// or if the inputs have mixed quantization types.
+// and returns per-projection quantization types. Different projections may use
+// different quant types (e.g., gate_up=int4, down=int8) but all experts within
+// a projection must share the same type.
+// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D).
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
-func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) {
+func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, map[string]string) {
if !strings.HasSuffix(groupName, ".experts") {
-return nil, ""
+return nil, nil
}
-quantize := inputs[0].Quantize
groups := make(map[string][]expertTensorInfo)
+projQuantize := make(map[string]string) // projection -> quant type
for _, input := range inputs {
-if input.Quantize != quantize {
-return nil, "" // mixed quantization types
-}
suffix := strings.TrimPrefix(input.Name, groupName)
m := perExpertSuffix.FindStringSubmatch(suffix)
if m == nil {
-return nil, "" // not a per-expert pattern
+return nil, nil // not a per-expert pattern
}
index, err := strconv.Atoi(m[1])
if err != nil {
-return nil, ""
+return nil, nil
}
-groups[m[2]] = append(groups[m[2]], expertTensorInfo{
+proj := m[2]
+if existing, ok := projQuantize[proj]; ok {
+if input.Quantize != existing {
+return nil, nil // mixed quant within same projection
+}
+} else {
+projQuantize[proj] = input.Quantize
+}
+groups[proj] = append(groups[proj], expertTensorInfo{
index: index,
-proj: m[2],
+proj: proj,
input: input,
})
}
if len(groups) == 0 {
-return nil, ""
+return nil, nil
}
-return groups, quantize
+return groups, projQuantize
}
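The grouping step in `parsePerExpertInputs` can be sketched standalone. This is a minimal Go illustration, not the repository's code: the `perExpertSuffix` pattern shown here (`.<index>.<projection>`) is an assumption inferred from the tensor names in the tests, and quantization bookkeeping is omitted:

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// Assumed shape of the perExpertSuffix pattern: ".<index>.<projection>",
// e.g. ".0.gate_proj.weight" -> index "0", projection "gate_proj.weight".
var perExpertSuffix = regexp.MustCompile(`^\.(\d+)\.(.+)$`)

// groupByProjection sketches the grouping step: names under groupName are
// bucketed by projection, rejecting anything not shaped like a per-expert tensor.
func groupByProjection(groupName string, names []string) map[string][]string {
	groups := make(map[string][]string)
	for _, name := range names {
		suffix := strings.TrimPrefix(name, groupName)
		m := perExpertSuffix.FindStringSubmatch(suffix)
		if m == nil {
			return nil // not a per-expert pattern
		}
		groups[m[2]] = append(groups[m[2]], m[1])
	}
	return groups
}

func main() {
	names := []string{
		"layer.moe.experts.0.gate_proj.weight",
		"layer.moe.experts.1.gate_proj.weight",
		"layer.moe.experts.0.down_proj.weight",
	}
	g := groupByProjection("layer.moe.experts", names)
	fmt.Println(len(g["gate_proj.weight"]), len(g["down_proj.weight"])) // 2 1
}
```

Each projection bucket is later stacked into one 3D `switch_mlp` tensor, which is why mixed quant types are rejected per projection but allowed across projections.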
// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
-func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) {
+// projQuantize maps projection name to its quantization type (may differ per projection).
+func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, projQuantize map[string]string) ([]byte, error) {
groupBase := strings.TrimSuffix(groupName, ".experts")
allArrays := make(map[string]*mlx.Array)
var pinned []*mlx.Array
-var metadata map[string]string
-if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" {
-metadata = map[string]string{
-"quant_type": quantize,
-"group_size": strconv.Itoa(groupSize),
-}
-}
+// Build metadata: if all projections use the same quant type, set global metadata.
+// Otherwise record per-tensor quant info.
+metadata := make(map[string]string)
// Sort projection names for deterministic output
projNames := make([]string, 0, len(projGroups))
@@ -339,7 +367,11 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
sort.Strings(projNames)
cleanup := func() {
-mlx.Unpin(pinned...)
+for _, p := range pinned {
+if p != nil {
+mlx.Unpin(p)
+}
+}
mlx.Sweep()
}
@@ -382,11 +414,27 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
mlx.Pin(stacked)
pinned = append(pinned, stacked)
-// Free individual decoded arrays
+// Free individual decoded arrays (remove from pinned to avoid double-unpin in cleanup)
for i, p := range pinned {
for _, d := range decoded {
if p == d {
pinned[i] = nil
}
}
}
mlx.Unpin(decoded...)
mlx.Sweep()
stackedName := groupBase + ".switch_mlp." + proj
quantize := projQuantize[proj]
// Record per-tensor quant metadata so the model can resolve params at load time.
if quantize != "" {
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 {
metadata[stackedName+".quant_type"] = quantize
metadata[stackedName+".group_size"] = strconv.Itoa(groupSize)
}
}
// Quantize the stacked tensor
if quantize != "" {
@@ -394,6 +442,14 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)
// Validate quantization produced non-empty output.
mlx.Eval(qweight, scales)
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
cleanup()
return nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
stackedName, quantize, groupSize, bits, mode)
}
qweight = mlx.Contiguous(qweight, false)
scales = mlx.Contiguous(scales, false)
allArrays[stackedName] = qweight
@@ -409,12 +465,19 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
mlx.Pin(toEval...)
pinned = append(pinned, toEval...)
-// Free stacked source array
+// Free stacked source array (remove from pinned to avoid double-unpin in cleanup)
for i, p := range pinned {
if p == stacked {
pinned[i] = nil
}
}
mlx.Unpin(stacked)
mlx.Sweep()
} else {
stacked = mlx.Contiguous(stacked, false)
mlx.Eval(stacked)
mlx.Pin(stacked)
pinned = append(pinned, stacked)
allArrays[stackedName] = stacked
}
}
@@ -529,7 +592,7 @@ func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
padBottom := blockRows*scaleShape[0] - rows
padSide := blockCols*scaleShape[1] - cols
if padBottom > 0 || padSide > 0 {
-decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
+decoded = mlx.PadConstant(decoded, []int{0, 1}, []int{0, 0}, []int{padBottom, padSide})
}
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))


@@ -246,6 +246,11 @@ func ShouldQuantize(name, component string) bool {
return false
}
// Skip audio encoder tensors (highly sensitive to quantization)
if strings.Contains(name, "audio_tower") || strings.Contains(name, "embed_audio") {
return false
}
// Skip embeddings
if strings.Contains(name, "embed") {
return false
@@ -291,6 +296,22 @@ func normalizeQuantType(quantize string) string {
}
}
// isAligned checks if a tensor's last dimension is divisible by the
// group size required for the given quantization type.
func isAligned(shape []int32, quantType string) bool {
if len(shape) == 0 {
return false
}
groupSize := int32(32)
switch normalizeQuantType(quantType) {
case "nvfp4":
groupSize = 16
case "int4", "int8":
groupSize = 64
}
return shape[len(shape)-1]%groupSize == 0
}
func isStackedExpertWeight(name string) bool {
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
// or "...proj" (pre-stacked packed tensor).
@@ -300,16 +321,16 @@ func isStackedExpertWeight(name string) bool {
return strings.Contains(name, ".mlp.switch_mlp.") ||
strings.Contains(name, ".mlp.experts.") ||
-strings.Contains(name, ".mlp.shared_experts.")
+strings.Contains(name, ".mlp.shared_experts.") ||
+strings.Contains(name, ".moe.experts.")
}
// GetTensorQuantization returns the appropriate quantization type for a tensor.
// Returns "" if the tensor should not be quantized.
// This implements mixed-precision quantization:
-// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
-// - Output projection, gate/up weights: int4 (less sensitive)
-// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
+// - v_proj, k_proj, down_proj: promoted to INT8 when base is INT4
+// - Norms, embeddings, biases, routing gates: no quantization
+// - All other eligible weights: use requested quantization type
func GetTensorQuantization(name string, shape []int32, quantize string) string {
stackedExpert := isStackedExpertWeight(name)
@@ -336,60 +357,35 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
// Normalize quantization type to canonical form
quantNorm := normalizeQuantType(quantize)
-// MLX quantization requires last dimension to be divisible by group size
-// nvfp4: 16, mxfp4/mxfp8: 32, int4/int8: 64
-groupSize := int32(32)
-switch quantNorm {
-case "nvfp4":
-groupSize = 16
-case "int4", "int8":
-groupSize = 64
-}
-if shape[len(shape)-1]%groupSize != 0 {
-return ""
-}
// Skip routing gate weights (should stay high precision)
// In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
return ""
}
+// MLX quantization requires last dimension to be divisible by group size.
+if !isAligned(shape, quantNorm) {
+return ""
+}
// For non-affine modes, use the same quantization for all eligible tensors.
if quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8" {
return quantNorm
}
-// Attention MLA weights - keep unquantized (bf16)
-// These are highly sensitive: errors accumulate in the KV cache over time
-// q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj
-if strings.Contains(name, "q_a_proj") ||
-strings.Contains(name, "q_b_proj") ||
-strings.Contains(name, "kv_a_proj") ||
-strings.Contains(name, "kv_b_proj") {
-return "" // No quantization - keep bf16
-}
+// Value projection weights directly determine attention output quality.
+// Down projection weights feed directly into the residual stream where
+// errors accumulate across layers. Both benefit from higher precision.
+// Promote to INT8 when base is INT4 (same affine mode, compatible with
+// GatherQMM for MoE expert tensors).
+if quantNorm == "int4" {
+if strings.Contains(name, ".v_proj") || strings.Contains(name, ".k_proj") || strings.Contains(name, "down_proj") {
+if isAligned(shape, "int8") {
+return "int8"
+}
+}
+}
-// Down projection weights - use INT8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
-// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
-if strings.Contains(name, "down_proj") {
-return "int8"
-}
-// Output projection, gate/up weights - use requested quantization (INT4)
-// o_proj, gate_proj, up_proj
-if strings.Contains(name, "o_proj") ||
-strings.Contains(name, "gate_proj") ||
-strings.Contains(name, "up_proj") {
-return quantNorm
-}
-// LM head - use requested quantization
-if strings.Contains(name, "lm_head") {
-return quantNorm
-}
// Default to requested quantization for other weights
return quantNorm
}
@@ -411,6 +407,7 @@ func ExpertGroupPrefix(tensorName string) string {
".mlp.experts.",
".mlp.shared_experts.",
".mlp.switch_mlp.",
".moe.experts.",
} {
idx := strings.Index(tensorName, marker)
if idx == -1 {
@@ -637,6 +634,8 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{
"Qwen3_5MoeForConditionalGeneration": newQwen35ImportTransform,
"Qwen3NextMoeForCausalLM": newQwen35ImportTransform,
"Qwen3NextMoeForConditionalGeneration": newQwen35ImportTransform,
"Gemma4ForCausalLM": newGemma4ImportTransform,
"Gemma4ForConditionalGeneration": newGemma4ImportTransform,
}
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {


@@ -1169,6 +1169,11 @@ func TestShouldQuantize(t *testing.T) {
{"ln prefix", "ln_1.weight", "", false},
{"layernorm in name", "input_layernorm.weight", "", false},
// Audio encoder tensors should not be quantized
{"audio tower weight", "model.audio_tower.layers.0.weight", "", false},
{"audio tower norm", "model.audio_tower.norm.weight", "", false},
{"embed audio weight", "embed_audio.weight", "", false},
// Biases should not be quantized
{"bias tensor", "attention.bias", "", false},
{"proj bias", "o_proj.bias", "", false},
@@ -1262,6 +1267,11 @@ func TestExpertGroupPrefix(t *testing.T) {
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
// MoE expert tensors (Gemma-style .moe.experts.)
{"model.layers.0.moe.experts.0.gate_proj.weight", "model.layers.0.moe.experts"},
{"model.layers.1.moe.experts.42.down_proj.weight", "model.layers.1.moe.experts"},
{"language_model.model.layers.2.moe.experts.127.up_proj.weight", "language_model.model.layers.2.moe.experts"},
// Expert tensors with language_model prefix should also match
{"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"},
{"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"},
@@ -1369,6 +1379,94 @@ func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
}
}
func TestIsAligned(t *testing.T) {
tests := []struct {
name string
shape []int32
quantType string
want bool
}{
// int4/int8: group_size=64
{"int4 aligned", []int32{1024, 4096}, "int4", true},
{"int4 unaligned", []int32{1024, 48}, "int4", false},
{"int8 aligned", []int32{1024, 128}, "int8", true},
{"int8 unaligned", []int32{1024, 32}, "int8", false},
// nvfp4: group_size=16
{"nvfp4 aligned", []int32{1024, 48}, "nvfp4", true},
{"nvfp4 unaligned", []int32{1024, 24}, "nvfp4", false},
{"nvfp4 aligned 16", []int32{1024, 16}, "nvfp4", true},
// mxfp4/mxfp8: group_size=32
{"mxfp4 aligned", []int32{1024, 64}, "mxfp4", true},
{"mxfp4 unaligned", []int32{1024, 48}, "mxfp4", false},
{"mxfp8 aligned", []int32{1024, 32}, "mxfp8", true},
{"mxfp8 unaligned", []int32{1024, 24}, "mxfp8", false},
// Edge cases
{"empty shape", []int32{}, "int4", false},
{"1D tensor", []int32{4096}, "int4", true},
{"3D stacked expert", []int32{128, 4096, 2816}, "int4", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isAligned(tt.shape, tt.quantType)
if got != tt.want {
t.Errorf("isAligned(%v, %q) = %v, want %v", tt.shape, tt.quantType, got, tt.want)
}
})
}
}
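The group sizes these cases encode (64 for the affine int types, 16 for nvfp4, 32 for the mx family) reduce to a divisibility check on the trailing dimension. A minimal sketch inferred purely from the test table above; the real `isAligned` may differ in details:

```go
package main

// isAlignedSketch is a reconstruction of the alignment rule the
// TestIsAligned table exercises: each quant type has a group size, and
// the trailing dimension must be a multiple of it. Empty shapes fail;
// 1D tensors pass through.
func isAlignedSketch(shape []int32, quantType string) bool {
	groupSize := map[string]int32{
		"int4":  64,
		"int8":  64,
		"nvfp4": 16,
		"mxfp4": 32,
		"mxfp8": 32,
	}[quantType]
	if len(shape) == 0 || groupSize == 0 {
		return false
	}
	if len(shape) == 1 {
		return true
	}
	return shape[len(shape)-1]%groupSize == 0
}
```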
func TestGetTensorQuantization_MixedPrecisionPromotion(t *testing.T) {
aligned := []int32{4096, 4096} // divisible by 64
tests := []struct {
name string
tensor string
shape []int32
quantize string
want string
}{
// int4 → int8 promotion for sensitive tensors
{"v_proj int4 promoted", "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"},
{"k_proj int4 promoted", "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
{"down_proj int4 promoted", "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"},
// Non-sensitive int4 tensors stay int4
{"q_proj int4 stays", "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"},
{"o_proj int4 stays", "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"},
{"gate_proj int4 stays", "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"},
{"up_proj int4 stays", "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"},
// nvfp4/mxfp4/mxfp8: no promotion (uniform quantization)
{"v_proj nvfp4 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"},
{"down_proj mxfp4 uniform", "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"},
{"v_proj mxfp8 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"},
// int8: already 8-bit, no promotion
{"v_proj int8 stays", "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"},
// Expert tensors: down_proj also promoted for int4
{"expert down_proj int4", "model.layers.0.mlp.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"},
{"moe expert down_proj int4", "model.layers.0.moe.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"},
// Unaligned: falls back to bf16 (empty string)
{"v_proj int4 unaligned", "model.layers.0.self_attn.v_proj.weight", []int32{1024, 48}, "int4", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetTensorQuantization(tt.tensor, tt.shape, tt.quantize)
if got != tt.want {
t.Errorf("GetTensorQuantization(%q, %v, %q) = %q, want %q",
tt.tensor, tt.shape, tt.quantize, got, tt.want)
}
})
}
}
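Condensed, the uniform rule this table checks: only an int4 base promotes, and only for the sensitive projections (v_proj, k_proj, down_proj); unaligned shapes fall back to bf16. A sketch under those assumptions, with an illustrative name (not the package's actual helper):

```go
package main

import "strings"

// uniformPromotionSketch condenses the behavior checked by
// TestGetTensorQuantization_MixedPrecisionPromotion: int4 is lifted to
// int8 for quantization-sensitive projections, every other base type
// stays uniform, and unaligned tensors return "" (bf16 fallback).
func uniformPromotionSketch(tensor string, aligned bool, quantize string) string {
	if !aligned {
		return ""
	}
	sensitive := strings.Contains(tensor, "v_proj") ||
		strings.Contains(tensor, "k_proj") ||
		strings.Contains(tensor, "down_proj")
	if quantize == "int4" && sensitive {
		return "int8"
	}
	return quantize
}
```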
func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) {
dir := t.TempDir()

x/create/gemma4.go (new file, 277 lines)

@@ -0,0 +1,277 @@
package create
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"github.com/ollama/ollama/x/safetensors"
)
type gemma4ImportTransform struct {
numLayers int
numExperts int
}
// gemma4Config is a minimal subset of the Gemma 4 config.json used for quant decisions.
type gemma4Config struct {
NumHiddenLayers int `json:"num_hidden_layers"`
NumExperts int `json:"num_experts"`
TextConfig struct {
NumHiddenLayers int `json:"num_hidden_layers"`
NumExperts int `json:"num_experts"`
} `json:"text_config"`
}
func newGemma4ImportTransform(modelDir string, _ sourceModelConfig) (tensorImportTransform, error) {
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
if err != nil {
return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic
}
var cfg gemma4Config
if err := json.Unmarshal(data, &cfg); err != nil {
return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic
}
numLayers := cfg.NumHiddenLayers
if numLayers == 0 {
numLayers = cfg.TextConfig.NumHiddenLayers
}
numExperts := cfg.NumExperts
if numExperts == 0 {
numExperts = cfg.TextConfig.NumExperts
}
return gemma4ImportTransform{numLayers: numLayers, numExperts: numExperts}, nil
}
func (t gemma4ImportTransform) skipTensor(name string) bool {
return false
}
// layerIndexRe extracts the layer index from tensor names like
// "model.language_model.layers.5.self_attn.v_proj.weight" or
// "model.language_model.layers.5.moe.experts.42.down_proj.weight"
var layerIndexRe = regexp.MustCompile(`\.layers\.(\d+)\.`)
// useMoreBits returns true for layers where quantization-sensitive tensors
// should use higher precision: the first and last 1/8 of layers (which handle
// input grounding and final output refinement), plus every 3rd layer in between
// to limit error accumulation through the residual stream.
func useMoreBits(layerIdx, numLayers int) bool {
return layerIdx < numLayers/8 ||
layerIdx >= 7*numLayers/8 ||
(layerIdx-numLayers/8)%3 == 2
}
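Concretely, for the 30-layer 26B model this selects layers 0-2 (first 1/8), every third middle layer starting at 5, and layers 26-29 (7*30/8 = 26 in integer math), i.e. 14 of 30 layers. A standalone enumeration of the same formula:

```go
package main

// promotedLayers enumerates every layer index useMoreBits selects,
// applying the same formula as the function above.
func promotedLayers(numLayers int) []int {
	var out []int
	for i := 0; i < numLayers; i++ {
		if i < numLayers/8 ||
			i >= 7*numLayers/8 ||
			(i-numLayers/8)%3 == 2 {
			out = append(out, i)
		}
	}
	return out
}
```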
func (t gemma4ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
quantNorm := normalizeQuantType(quantize)
// Embedding: quantize to 8-bit variant for bandwidth efficiency.
// The embedding serves double duty: lookup (via QuantizedEmbedding) and
// lm_head projection (via AsLinear). Using 8-bit matches GGUF Q6_K quality
// (strictly higher at 8 bpw vs 6.5 bpw) while saving ~2.8 GB on 31B vs bf16.
if isEmbedTokensWeight(name) {
switch quantNorm {
case "int4", "int8":
if isAligned(shape, "int8") {
return "int8"
}
case "mxfp4", "nvfp4", "mxfp8":
if isAligned(shape, "mxfp8") {
return "mxfp8"
}
}
if isAligned(shape, quantNorm) {
return quantNorm
}
return ""
}
// MoE router logits choose the top-k expert set. Quantization noise here
// can flip expert selection, after which downstream activations diverge
// sharply. The tensor is small, so leave it in source precision.
if isGemma4RouterProjection(name) {
return ""
}
// Mixed-precision quantization: sensitive tensors get higher precision.
//
// Value projections (v_proj) directly determine attention output quality.
// Down projections (down_proj) are the final MLP output and errors there
// propagate directly to the residual stream. Both benefit from higher
// precision at early layers, late layers, and periodically in between
// (the "useMoreBits" heuristic).
//
// For int4: promote → int8 (same affine family, GatherQMM compatible).
// For mxfp4/nvfp4: promote → mxfp8. MLX quantized_matmul handles mixed
// nvfp4+mxfp8 modes within the same model — each tensor carries its own
// quant metadata and the kernel dispatches per-tensor.
if t.numLayers > 0 {
layerIdx := -1
if m := layerIndexRe.FindStringSubmatch(name); m != nil {
if idx, err := strconv.Atoi(m[1]); err == nil {
layerIdx = idx
}
}
// Determine promotion target for sensitive tensors.
// "int8" = int4 base → int8 (affine family)
// "mxfp8" = mxfp4/nvfp4 base → mxfp8
// "" = no promotion (int8/mxfp8, already 8-bit)
promote := ""
switch quantNorm {
case "int4":
promote = "int8"
case "mxfp4", "nvfp4":
promote = "mxfp8"
}
// Only apply to language model tensors; audio/vision tower tensors
// pass through to GetTensorQuantization, which leaves the audio tower
// unquantized and applies its uniform promotion to the vision tower.
isModelTensor := !strings.Contains(name, "audio_tower") &&
!strings.Contains(name, "vision_tower")
isSensitive := isModelTensor &&
(strings.Contains(name, ".v_proj") || strings.Contains(name, "down_proj"))
isSensitiveK := isModelTensor && strings.Contains(name, "k_proj")
if promote != "" && (isSensitive || isSensitiveK) {
shouldPromote := false
// 8-expert models: v_proj and k_proj share very few KV heads,
// so quantization errors are amplified. Always promote.
if t.numExperts == 8 && (strings.Contains(name, ".v_proj") || isSensitiveK) {
shouldPromote = true
}
// Layer-position heuristic for v_proj and down_proj.
if isSensitive && layerIdx >= 0 && useMoreBits(layerIdx, t.numLayers) {
shouldPromote = true
}
if shouldPromote && isAligned(shape, promote) {
return promote
}
// Sensitive tensor at a non-promoted layer: use base quant type.
// Return directly to bypass GetTensorQuantization's uniform
// promotion — the layer-position heuristic is authoritative here.
if !isAligned(shape, quantNorm) {
return ""
}
return quantNorm
}
}
return GetTensorQuantization(name, shape, quantize)
}
// isEmbedTokensWeight returns true for the main token embedding weight.
func isEmbedTokensWeight(name string) bool {
return strings.HasSuffix(name, "embed_tokens.weight") &&
!strings.Contains(name, "per_layer")
}
func isGemma4RouterProjection(name string) bool {
return strings.HasSuffix(name, ".router.proj.weight") &&
!strings.Contains(name, "audio_tower") &&
!strings.Contains(name, "vision_tower")
}
func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
if td == nil {
return nil, nil
}
// Split pre-stacked MoE expert tensors [N, out, in] into per-expert
// [out, in] tensors so they go through the standard expert packing and
// quantization flow (ExpertGroupPrefix matching, per-expert quantize).
if isGemma4StackedMoETensor(td.Name, td.Shape) {
return splitStackedMoETensor(td)
}
return []*safetensors.TensorData{td}, nil
}
// isGemma4StackedMoETensor checks if this is a pre-stacked MoE expert weight.
// Gemma 4 HF weights come in two layouts depending on the model version:
// - Older: model.language_model.layers.N.moe.{gate,up,down}_proj [experts, dim1, dim2]
// - Newer: model.language_model.layers.N.experts.{gate_up,down}_proj [experts, dim1, dim2]
//
// The newer layout has gate and up already fused into gate_up_proj. We keep
// them fused within each expert (splitting only along the expert dimension)
// so the tensors flow through the standard expert packing and quantization path.
func isGemma4StackedMoETensor(name string, shape []int32) bool {
if len(shape) != 3 {
return false
}
if strings.Contains(name, ".moe.") || strings.Contains(name, ".experts.") {
return strings.HasSuffix(name, "_proj") || strings.HasSuffix(name, "_proj.weight")
}
return false
}
// splitStackedMoETensor splits a [N, out, in] stacked expert tensor into
// N individual [out, in] tensors named with the per-expert convention that
// ExpertGroupPrefix expects: prefix.moe.experts.{E}.{proj}.weight
func splitStackedMoETensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
raw, err := io.ReadAll(td.Reader())
if err != nil {
return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err)
}
numExperts := int(td.Shape[0])
rows := int(td.Shape[1]) // out_features in HF layout
cols := int(td.Shape[2]) // in_features in HF layout
elemSize, err := DTypeSize(td.Dtype)
if err != nil {
return nil, fmt.Errorf("failed to get dtype size for %s: %w", td.Dtype, err)
}
perExpertBytes := rows * cols * elemSize
if len(raw) != numExperts*perExpertBytes {
return nil, fmt.Errorf("tensor %s: raw byte length %d does not match shape %v and dtype %s",
td.Name, len(raw), td.Shape, td.Dtype)
}
// Determine the per-expert name pattern.
// Two source layouts:
// Old: model.language_model.layers.N.moe.gate_proj
// -> model.language_model.layers.N.moe.experts.E.gate_proj.weight
// New: model.language_model.layers.N.experts.gate_up_proj
// -> model.language_model.layers.N.moe.experts.E.gate_up_proj.weight
baseName := td.Name
baseName = strings.TrimSuffix(baseName, ".weight")
lastDot := strings.LastIndex(baseName, ".")
if lastDot < 0 {
return nil, fmt.Errorf("tensor %s: unexpected name format", td.Name)
}
parentPrefix := baseName[:lastDot] // "...layers.N.moe" or "...layers.N.experts"
projName := baseName[lastDot+1:] // "gate_proj" or "gate_up_proj"
// Normalize: if parent already ends with ".experts", use the grandparent + ".moe"
// so we get a consistent "layers.N.moe.experts.E" pattern.
var moePrefix string
if cut, ok := strings.CutSuffix(parentPrefix, ".experts"); ok {
moePrefix = cut + ".moe"
} else {
moePrefix = parentPrefix
}
perExpertShape := []int32{td.Shape[1], td.Shape[2]}
results := make([]*safetensors.TensorData, numExperts)
for e := range numExperts {
expertName := fmt.Sprintf("%s.experts.%d.%s.weight", moePrefix, e, projName)
start := e * perExpertBytes
end := start + perExpertBytes
results[e] = safetensors.NewTensorDataFromBytes(expertName, td.Dtype, perExpertShape, raw[start:end])
}
return results, nil
}
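Isolating just the name rewrite above: both source layouts normalize onto the same `layers.N.moe.experts.E` pattern. A self-contained sketch of the string handling only (no tensor data):

```go
package main

import (
	"fmt"
	"strings"
)

// expertNameSketch reproduces the per-expert naming from
// splitStackedMoETensor: trim ".weight", split off the projection name,
// and rewrite an ".experts" parent to ".moe" so both HF layouts land on
// the "layers.N.moe.experts.E.<proj>.weight" convention.
func expertNameSketch(stackedName string, expert int) string {
	base := strings.TrimSuffix(stackedName, ".weight")
	lastDot := strings.LastIndex(base, ".")
	parent, proj := base[:lastDot], base[lastDot+1:]
	if cut, ok := strings.CutSuffix(parent, ".experts"); ok {
		parent = cut + ".moe"
	}
	return fmt.Sprintf("%s.experts.%d.%s.weight", parent, expert, proj)
}
```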

x/create/gemma4_test.go (new file, 196 lines)

@@ -0,0 +1,196 @@
package create
import (
"testing"
)
func TestGemma4QuantizationType(t *testing.T) {
// 26B MoE: 30 layers, 128 experts
transform26B := gemma4ImportTransform{numLayers: 30, numExperts: 128}
// 8-expert model (hypothetical)
transform8E := gemma4ImportTransform{numLayers: 30, numExperts: 8}
aligned := []int32{2816, 2816} // divisible by 64 (int4/int8 group size) and 16 (nvfp4)
tests := []struct {
name string
transform gemma4ImportTransform
tensor string
shape []int32
quantize string
want string
}{
// === embed_tokens: quantize to 8-bit variant (serves as both embed and lm_head) ===
{"embed_tokens int4", transform26B, "model.embed_tokens.weight", aligned, "int4", "int8"},
{"embed_tokens nvfp4", transform26B, "model.embed_tokens.weight", aligned, "nvfp4", "mxfp8"},
{"embed_tokens mxfp4", transform26B, "model.embed_tokens.weight", aligned, "mxfp4", "mxfp8"},
{"embed_tokens int8", transform26B, "model.embed_tokens.weight", aligned, "int8", "int8"},
{"embed_tokens mxfp8", transform26B, "model.embed_tokens.weight", aligned, "mxfp8", "mxfp8"},
// === v_proj: layer-position heuristic for int4/nvfp4 ===
// Layer 0 is in first 1/8 (30/8=3) → promoted
{"v_proj int4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"},
// Layer 4 is NOT in useMoreBits → base quant
{"v_proj int4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "int4", "int4"},
// Layer 29 is in last 1/8 → promoted
{"v_proj int4 last layer promoted", transform26B, "model.layers.29.self_attn.v_proj.weight", aligned, "int4", "int8"},
// nvfp4: promote to mxfp8 (cross-family, validated by MLX quantized_matmul)
{"v_proj nvfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "mxfp8"},
{"v_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"},
// mxfp4: promoted to mxfp8 at promoted layers (same mxfp family)
{"v_proj mxfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp8"},
{"v_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp4"},
// int8/mxfp8: no promotion (already 8-bit)
{"v_proj int8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"},
{"v_proj mxfp8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"},
// === down_proj (dense MLP): same heuristic as v_proj ===
{"dense down_proj int4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"},
{"dense down_proj int4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "int4", "int4"},
{"dense down_proj nvfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "nvfp4", "mxfp8"},
{"dense down_proj nvfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "nvfp4", "nvfp4"},
{"dense down_proj mxfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp8"},
{"dense down_proj mxfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"},
// === Expert down_proj: int4→int8, nvfp4→mxfp8 at promoted layers ===
{"expert down_proj int4 promoted", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "int4", "int8"},
{"expert down_proj int4 non-promoted", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "int4", "int4"},
// nvfp4 experts: promote to mxfp8 (all experts at a layer get same treatment,
// so GatherQMM sees uniform quant per projection per layer)
{"expert down_proj nvfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "nvfp4", "mxfp8"},
{"expert down_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "nvfp4", "nvfp4"},
// mxfp4 experts: promote to mxfp8 (same mxfp family, GatherQMM compatible)
{"expert down_proj mxfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp8"},
{"expert down_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp4"},
// === Expert gate_up_proj: always base quant (not a sensitive tensor) ===
{"expert gate_up int4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "int4", "int4"},
{"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"},
{"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"},
// === Router projection: expert selection is sensitive; keep source precision ===
{"router proj int4", transform26B, "model.layers.0.router.proj.weight", aligned, "int4", ""},
{"router proj nvfp4", transform26B, "model.layers.0.router.proj.weight", aligned, "nvfp4", ""},
{"router proj mxfp4", transform26B, "model.layers.0.router.proj.weight", aligned, "mxfp4", ""},
// === k_proj: promoted only for 8-expert models ===
{"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"},
{"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
{"k_proj 8 experts nvfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "nvfp4", "mxfp8"},
{"k_proj 8 experts mxfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "mxfp4", "mxfp8"},
// === q_proj, o_proj, gate_proj, up_proj: always base quant ===
{"q_proj int4", transform26B, "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"},
{"o_proj int4", transform26B, "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"},
{"gate_proj int4", transform26B, "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"},
{"up_proj int4", transform26B, "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"},
// === Non-quantizable tensors: always bf16 ===
{"embed_tokens per_layer skip", transform26B, "model.embed_tokens_per_layer.weight", aligned, "int4", ""},
{"norm", transform26B, "model.layers.0.input_layernorm.weight", []int32{2816}, "int4", ""},
{"router scale", transform26B, "model.layers.0.router.scale", []int32{2816}, "int4", ""},
// === Audio/vision tower tensors: must pass through unquantized for all quant types ===
// These contain .v_proj and down_proj but should NOT be intercepted by
// the sensitive-tensor promotion logic.
{"audio norm int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int4", ""},
{"audio norm nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "nvfp4", ""},
{"audio norm int8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int8", ""},
{"audio norm mxfp8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "mxfp8", ""},
{"audio conv int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "int4", ""},
{"audio conv nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "nvfp4", ""},
{"audio linear int4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "int4", ""},
{"audio linear nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "nvfp4", ""},
// Audio tower v_proj — must NOT be promoted despite containing .v_proj
{"audio v_proj int4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", ""},
{"audio v_proj nvfp4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", ""},
// Vision tower v_proj — vision tower IS quantized (unlike audio tower),
// but not intercepted by gemma4's layer-position heuristic.
// Falls through to GetTensorQuantization which applies uniform promotion.
{"vision v_proj int4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", "int8"},
{"vision v_proj nvfp4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", "nvfp4"},
// Audio tower down_proj
{"audio down_proj int4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "int4", ""},
{"audio down_proj nvfp4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "nvfp4", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.transform.quantizationType(tt.tensor, tt.shape, tt.quantize)
if got != tt.want {
t.Errorf("quantizationType(%q, %v, %q) = %q, want %q",
tt.tensor, tt.shape, tt.quantize, got, tt.want)
}
})
}
}
func TestUseMoreBits(t *testing.T) {
// 30 layers: first 1/8 = layers 0-2, last 1/8 = layers 26-29 (7*30/8 = 26)
// In between: every 3rd from offset (i - n/8) % 3 == 2
n := 30
promoted := map[int]bool{}
for i := range n {
if useMoreBits(i, n) {
promoted[i] = true
}
}
// First 1/8 (30/8 = 3): layers 0, 1, 2
for _, i := range []int{0, 1, 2} {
if !promoted[i] {
t.Errorf("layer %d should be promoted (first 1/8)", i)
}
}
// Last 1/8: layers 26, 27, 28, 29 (>= 7*30/8 = 26)
for _, i := range []int{26, 27, 28, 29} {
if !promoted[i] {
t.Errorf("layer %d should be promoted (last 1/8)", i)
}
}
// Some middle layers should NOT be promoted
for _, i := range []int{3, 4, 6, 7} {
if promoted[i] {
t.Errorf("layer %d should NOT be promoted", i)
}
}
// Layer 5 should be promoted: (5 - 3) % 3 == 2
if !promoted[5] {
t.Errorf("layer 5 should be promoted (periodic)")
}
}
func TestIsGemma4StackedMoETensor(t *testing.T) {
tests := []struct {
label string
tensorName string
shape []int32
want bool
}{
// New-style: .experts.gate_up_proj
{"experts gate_up_proj 3D", "model.layers.0.experts.gate_up_proj", []int32{128, 1408, 2816}, true},
{"experts down_proj 3D", "model.layers.0.experts.down_proj", []int32{128, 2816, 704}, true},
// Old-style: .moe.gate_proj
{"moe gate_proj 3D", "model.layers.0.moe.gate_proj", []int32{128, 2112, 2816}, true},
{"moe down_proj 3D", "model.layers.0.moe.down_proj.weight", []int32{128, 2816, 2112}, true},
// Not stacked: 2D
{"2D weight", "model.layers.0.experts.gate_up_proj", []int32{1408, 2816}, false},
// Not expert
{"non-expert 3D", "model.layers.0.mlp.gate_proj", []int32{3, 2816, 2816}, false},
// Not a projection
{"expert non-proj", "model.layers.0.experts.scale", []int32{128, 1, 1}, false},
}
for _, tt := range tests {
t.Run(tt.label, func(t *testing.T) {
got := isGemma4StackedMoETensor(tt.tensorName, tt.shape)
if got != tt.want {
t.Errorf("isGemma4StackedMoETensor(%q, %v) = %v, want %v",
tt.tensorName, tt.shape, got, tt.want)
}
})
}
}


@@ -115,36 +115,7 @@ func (s *Server) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", s.modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
if runtime.GOOS == "linux" {
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
// Append existing LD_LIBRARY_PATH if set
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add LD_LIBRARY_PATH in cmd.Env
found := false
for i := range cmd.Env {
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
configureMLXSubprocessEnv(cmd, ml.LibraryPaths(gpus))
s.cmd = cmd
@@ -200,6 +171,53 @@ func (s *Server) Ping(ctx context.Context) error {
return nil
}
func mlxLibraryPathEnv() string {
switch runtime.GOOS {
case "windows":
return "PATH"
case "darwin":
return "DYLD_LIBRARY_PATH"
default:
return "LD_LIBRARY_PATH"
}
}
func configureMLXSubprocessEnv(cmd *exec.Cmd, libraryPaths []string) {
if len(libraryPaths) == 0 {
return
}
// Search order for the imagegen runner is:
// 1. bundled lib/ollama root
// 2. backend-specific library dirs selected during GPU discovery
// 3. any existing caller-provided library path values
pathEnv := mlxLibraryPathEnv()
pathEnvPaths := append([]string{}, libraryPaths...)
if existingPath, ok := os.LookupEnv(pathEnv); ok {
pathEnvPaths = append(pathEnvPaths, filepath.SplitList(existingPath)...)
}
setSubprocessEnv(cmd, pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
slog.Debug("mlx subprocess library path", pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
ollamaLibraryPaths := append([]string{}, libraryPaths...)
if existingPath, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
ollamaLibraryPaths = append(ollamaLibraryPaths, filepath.SplitList(existingPath)...)
}
setSubprocessEnv(cmd, "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
slog.Debug("mlx subprocess library path", "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
}
func setSubprocessEnv(cmd *exec.Cmd, key, value string) {
for i := range cmd.Env {
name, _, ok := strings.Cut(cmd.Env[i], "=")
if ok && strings.EqualFold(name, key) {
cmd.Env[i] = key + "=" + value
return
}
}
cmd.Env = append(cmd.Env, key+"="+value)
}
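The EqualFold match matters on Windows, where the inherited variable may be spelled `Path` rather than `PATH`. The same replace-or-append logic on a plain env slice (a sketch, not the exec.Cmd-based helper above):

```go
package main

import "strings"

// setEnvSketch mirrors setSubprocessEnv on a plain KEY=VALUE slice:
// replace an existing entry case-insensitively, else append.
func setEnvSketch(env []string, key, value string) []string {
	for i := range env {
		if name, _, ok := strings.Cut(env[i], "="); ok && strings.EqualFold(name, key) {
			env[i] = key + "=" + value
			return env
		}
	}
	return append(env, key+"="+value)
}
```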
// getLastErr returns the last stderr line.
func (s *Server) getLastErr() string {
s.lastErrLock.Lock()


@@ -254,8 +254,23 @@ func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV
mlx.Pin(c.keys, c.values)
} else {
if c.idx < c.keys.Dim(2) {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
if c.offset <= c.maxSize {
// Not yet wrapped: slots [c.idx, Dim) are grow padding
// or stale post-rewind data, not live window content.
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
} else {
// Wrapped: logical order is slots[idx..Dim) then slots[0..idx).
// Linearize so the trim + concat below operate on contiguous
// positions and preserve the last (maxSize - 1) old tokens.
tailK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.keys.Dim(2)), mlx.Slice())
tailV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.values.Dim(2)), mlx.Slice())
headK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
headV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
c.keys.Set(tailK.Concatenate(2, headK))
c.values.Set(tailV.Concatenate(2, headV))
c.idx = c.keys.Dim(2)
}
}
// Trim to max_size to maintain sliding window
@@ -322,9 +337,10 @@ func (c *RotatingKVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return nil
}
liveLen := min(c.offset, c.keys.Dim(2))
return []*mlx.Array{
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
}
}
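The wrapped-branch linearization in concat can be pictured on a plain ring slice: physical slots [idx, len) hold the oldest live tokens and [0, idx) the newest, so logical order is tail followed by head. A sketch independent of mlx tensors:

```go
package main

// linearizeRing reorders a wrapped ring buffer into logical token order,
// mirroring what concat does along the sequence axis with tensor slices.
func linearizeRing(slots []int, idx int) []int {
	out := make([]int, 0, len(slots))
	out = append(out, slots[idx:]...) // tail: oldest live tokens
	out = append(out, slots[:idx]...) // head: newest tokens
	return out
}
```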


@@ -0,0 +1,338 @@
package cache
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// singleTokenKV and multiTokenKV fabricate [B=1, H=1, L, D=2] key/value
// tensors whose channel value is the token id, so stateIDs can recover
// which ids survived in the cache.
func singleTokenKV(id float32) (*mlx.Array, *mlx.Array) {
k := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
v := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
return k, v
}
func multiTokenKV(ids []float32) (*mlx.Array, *mlx.Array) {
data := make([]float32, 0, 2*len(ids))
for _, id := range ids {
data = append(data, id, id)
}
k := mlx.FromValues(data, 1, 1, len(ids), 2)
v := mlx.FromValues(data, 1, 1, len(ids), 2)
return k, v
}
// stateIDs returns the ids currently in the cache in slot order (logical
// after a concat, physical/rotated after a single-token update).
func stateIDs(t *testing.T, c *RotatingKVCache) []float32 {
t.Helper()
state := c.State()
if state == nil {
return nil
}
mlx.Eval(state[0])
flat := state[0].Floats()
n := state[0].Dim(2)
out := make([]float32, n)
for i := range n {
out[i] = flat[i*2]
}
return out
}
func equalSlice(a, b []float32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func feedMulti(c *RotatingKVCache, startID float32, n int) float32 {
ids := make([]float32, n)
for i := range ids {
ids[i] = startID + float32(i)
}
k, v := multiTokenKV(ids)
c.Update(k, v)
return startID + float32(n)
}
func feedSingle(c *RotatingKVCache, id float32) {
k, v := singleTokenKV(id)
c.Update(k, v)
}
// TestRotatingKVCacheConcatMidRotationPreservesContext: after the buffer
// has wrapped, a multi-token concat must keep the (maxSize-1) most recent
// pre-existing tokens in logical order so the first Q of the new batch
// has a full sliding window.
func TestRotatingKVCacheConcatMidRotationPreservesContext(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
nextID := feedMulti(c, 1, 3)
for range 6 {
feedSingle(c, nextID)
nextID++
}
if c.Offset() != 9 {
t.Fatalf("setup: offset=%d want 9", c.Offset())
}
if c.idx >= c.maxSize {
t.Fatalf("setup: expected mid-rotation idx (<%d), got %d", c.maxSize, c.idx)
}
feedMulti(c, 10, 2)
got := stateIDs(t, c)
want := []float32{7, 8, 9, 10, 11}
if !equalSlice(got, want) {
t.Fatalf("post-concat window=%v want %v", got, want)
}
if c.Offset() != 11 {
t.Fatalf("offset=%d want 11", c.Offset())
}
}
// TestRotatingKVCacheConcatAlignedInvariant: with an aligned buffer
// (c.idx == Dim), an L>1 concat keeps the last (maxSize-1) pre-existing
// tokens plus the full new batch. This is the chunked-prefill contract
// x/mlxrunner/pipeline.go relies on.
func TestRotatingKVCacheConcatAlignedInvariant(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
// Chunk 1 fills past maxSize, leaving Dim == maxSize aligned.
feedMulti(c, 1, 6)
// Chunk 2: the buffer is intentionally oversized to (maxSize-1) + L
// so the first new Q has its full window in scope for this forward.
feedMulti(c, 7, 3)
got := stateIDs(t, c)
want := []float32{4, 5, 6, 7, 8, 9}
if !equalSlice(got, want) {
t.Fatalf("post-chunk-2 buffer=%v want %v", got, want)
}
// The next decode trims oversize back to maxSize; order may be
// physical (rotated), so check as a set.
feedSingle(c, 10)
got = stateIDs(t, c)
if len(got) != window {
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
}
seen := map[float32]bool{}
for _, v := range got {
seen[v] = true
}
for _, w := range []float32{7, 8, 9, 10} {
if !seen[w] {
t.Fatalf("post-decode window missing %v (got %v)", w, got)
}
}
}
// TestRotatingKVCacheConcatAfterDecodeGrowsBuffer: update() grows the
// underlying buffer by `step` slots via mlx.Zeros before writing, so
// after one decode on a short prefill c.idx < Dim even though the cache
// has not wrapped. Those trailing slots are zero padding and must not
// be pulled back into the live window on the next concat.
func TestRotatingKVCacheConcatAfterDecodeGrowsBuffer(t *testing.T) {
skipIfNoMLX(t)
const window = 512
c := NewRotatingKVCache(window)
feedMulti(c, 1, 3)
feedSingle(c, 4)
feedMulti(c, 5, 3)
got := stateIDs(t, c)
want := []float32{1, 2, 3, 4, 5, 6, 7}
if !equalSlice(got, want) {
t.Fatalf("growing-buffer concat=%v want %v", got, want)
}
}
// TestRotatingKVCacheConcatAfterLiveRewind: x/mlxrunner/cache.go calls
// Restore(nil, target) between conversation turns to rewind the cache to
// the matched prefix. Restore moves c.offset/c.idx without trimming the
// underlying buffer, so slots [c.idx, Dim) still hold stale pre-rewind
// tokens. A subsequent concat must drop those, not treat them as wrapped
// window content.
func TestRotatingKVCacheConcatAfterLiveRewind(t *testing.T) {
skipIfNoMLX(t)
const window = 8
c := NewRotatingKVCache(window)
// Grow the buffer to exactly maxSize without wrapping.
feedMulti(c, 1, 2)
for id := float32(3); id <= 8; id++ {
feedSingle(c, id)
}
if c.Offset() != window {
t.Fatalf("setup: offset=%d want %d", c.Offset(), window)
}
if !c.Restore(nil, 2) {
t.Fatalf("live rewind to 2 failed")
}
if c.Offset() != 2 {
t.Fatalf("post-rewind offset=%d want 2", c.Offset())
}
feedMulti(c, 9, 3)
got := stateIDs(t, c)
want := []float32{1, 2, 9, 10, 11}
if !equalSlice(got, want) {
t.Fatalf("post-rewind concat=%v want %v", got, want)
}
if c.Offset() != 5 {
t.Fatalf("offset=%d want 5", c.Offset())
}
}
// TestRotatingKVCacheConcatGrowingBuffer: when oldLen < maxSize the trim
// count is non-positive, so all pre-existing tokens are kept.
func TestRotatingKVCacheConcatGrowingBuffer(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
feedMulti(c, 1, 2)
feedMulti(c, 3, 2)
got := stateIDs(t, c)
want := []float32{1, 2, 3, 4}
if !equalSlice(got, want) {
t.Fatalf("growing buffer=%v want %v", got, want)
}
}
// TestRotatingKVCacheRunnerChunkedPrefill mirrors the
// x/mlxrunner/pipeline.go prefill loop: a long prompt fed through
// repeated L>1 Update() calls on a single cache. Scaled-down proxy for
// the Gemma 4 26B case (sliding_window=1024, prefillChunkSize=2048).
func TestRotatingKVCacheRunnerChunkedPrefill(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
feedMulti(c, 1, 8)
if c.Offset() != 8 {
t.Fatalf("chunk 1: offset=%d want 8", c.Offset())
}
feedMulti(c, 9, 8)
got := stateIDs(t, c)
want := []float32{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
if !equalSlice(got, want) {
t.Fatalf("chunk 2: buffer=%v want %v", got, want)
}
feedMulti(c, 17, 4)
got = stateIDs(t, c)
want = []float32{14, 15, 16, 17, 18, 19, 20}
if !equalSlice(got, want) {
t.Fatalf("chunk 3: buffer=%v want %v", got, want)
}
// Decode trims oversize back to maxSize; order may be physical.
feedSingle(c, 21)
got = stateIDs(t, c)
if len(got) != window {
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
}
seen := map[float32]bool{}
for _, v := range got {
seen[v] = true
}
for _, w := range []float32{18, 19, 20, 21} {
if !seen[w] {
t.Fatalf("post-decode window missing %v (got %v)", w, got)
}
}
}
// TestRotatingKVCacheMultiTurnChatSimulation walks a prefill → decode →
// prefill sequence and checks that each new prefill retains the last
// (maxSize-1) pre-existing tokens in logical order.
func TestRotatingKVCacheMultiTurnChatSimulation(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
nextID := feedMulti(c, 1, 2)
for range 5 {
feedSingle(c, nextID)
nextID++
}
if c.Offset() != 7 {
t.Fatalf("turn 1: offset=%d want 7", c.Offset())
}
feedMulti(c, nextID, 3)
nextID += 3
got := stateIDs(t, c)
want := []float32{5, 6, 7, 8, 9, 10}
if !equalSlice(got, want) {
t.Fatalf("turn 2 prefill buffer=%v want %v", got, want)
}
for range 4 {
feedSingle(c, nextID)
nextID++
}
if c.Offset() != 14 {
t.Fatalf("turn 2 decode: offset=%d want 14", c.Offset())
}
feedMulti(c, nextID, 2)
got = stateIDs(t, c)
want = []float32{12, 13, 14, 15, 16}
if !equalSlice(got, want) {
t.Fatalf("turn 3 prefill buffer=%v want %v", got, want)
}
}
// TestRotatingKVCacheOffsetTracking: Offset() is the monotonic logical
// token count through any mix of Update() calls — Gemma 4 uses
// donorEntry.Offset - L for the consumer's RoPE offset.
func TestRotatingKVCacheOffsetTracking(t *testing.T) {
skipIfNoMLX(t)
c := NewRotatingKVCache(4)
nextID := feedMulti(c, 1, 3)
if c.Offset() != 3 {
t.Fatalf("after prefill 3: offset=%d want 3", c.Offset())
}
for i := range 5 {
feedSingle(c, nextID)
nextID++
if c.Offset() != 3+i+1 {
t.Fatalf("after decode %d: offset=%d want %d", i, c.Offset(), 3+i+1)
}
}
nextID = feedMulti(c, nextID, 2)
if c.Offset() != 10 {
t.Fatalf("after turn-2 prefill: offset=%d want 10", c.Offset())
}
// L > maxSize concat.
feedMulti(c, nextID, 7)
if c.Offset() != 17 {
t.Fatalf("after large prefill: offset=%d want 17", c.Offset())
}
}


@@ -2,6 +2,7 @@ package mlxrunner
import (
_ "github.com/ollama/ollama/x/models/gemma3"
_ "github.com/ollama/ollama/x/models/gemma4"
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3"


@@ -1,21 +1,64 @@
package mlx
// #include "generated.h"
import "C"
import "math"
func GELUApprox(t *Array) *Array {
return t.Multiply(
FromValue[float32](0.5),
).Multiply(
t.Add(
t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
).Multiply(
FromValue(float32(math.Sqrt(2 / math.Pi))),
).Tanh().Add(FromValue[float32](1.0)),
).AsType(t.DType())
}
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
func SILU(t *Array) *Array {
return t.Multiply(t.Sigmoid()).AsType(t.DType())
}
// GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// as a fused kernel.
var GELUApprox = Compile1(
"GELUApprox",
func(x *Array) *Array {
// Dtype-matched scalars avoid implicit upcasts on bf16 inputs.
dt := x.DType()
half := FromValue[float32](0.5).AsType(dt)
coeff := FromValue(geluCoeff).AsType(dt)
c := FromValue[float32](0.044715).AsType(dt)
one := FromValue[float32](1.0).AsType(dt)
// x^3 via x*x*x (avoids general Power which is slower).
x3 := x.Multiply(x).Multiply(x)
inner := x.Add(c.Multiply(x3))
tanh := coeff.Multiply(inner).Tanh()
return half.Multiply(x).Multiply(one.Add(tanh))
},
Shapeless(),
)
// SiLU returns a * sigmoid(a) as a fused kernel.
var SiLU = Compile1(
"SiLU",
func(a *Array) *Array {
return a.Multiply(a.Sigmoid())
},
Shapeless(),
)
// SwiGLU returns silu(gate) * up as a fused kernel.
var SwiGLU = Compile2(
"SwiGLU",
func(gate, up *Array) *Array {
return SiLU(gate).Multiply(up)
},
Shapeless(),
)
// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
// geglu, used by Gemma-family MLP and MoE paths.
var GeGLU = Compile2(
"GeGLU",
func(gate, up *Array) *Array {
return GELUApprox(gate).Multiply(up)
},
Shapeless(),
)
// LogitSoftcap returns tanh(x / cap) * cap as a fused kernel. Matches
// mlx_lm's logit_softcap. cap must have the same dtype as x.
var LogitSoftcap = Compile2(
"LogitSoftcap",
func(x, cap *Array) *Array {
return x.Divide(cap).Tanh().Multiply(cap)
},
Shapeless(),
)


@@ -27,7 +27,11 @@ var arrays []*Array
func New(name string) *Array {
t := &Array{name: name}
arrays = append(arrays, t)
if tracing {
traceScratch = append(traceScratch, t)
} else {
arrays = append(arrays, t)
}
return t
}

x/mlxrunner/mlx/compile.go Normal file

@@ -0,0 +1,192 @@
package mlx
// #include <stdlib.h>
// #include "generated.h"
//
// extern int closureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
// extern void closureDestructor(void* payload);
import "C"
import (
"log/slog"
"runtime/cgo"
"sync"
"unsafe"
)
// CompileFunc is the signature of a function that can be compiled.
type CompileFunc func(inputs ...*Array) []*Array
// CompileOption configures Compile behavior.
type CompileOption func(*compileConfig)
type compileConfig struct {
shapeless bool
}
// Shapeless traces the function once against symbolic shapes so the compiled
// graph accepts any input shape afterwards. Without this option, MLX re-traces
// on each new (shape, dtype) combination and caches each specialization.
func Shapeless() CompileOption {
return func(c *compileConfig) { c.shapeless = true }
}
// Compile returns a compiled version of fn. When called during another
// compile's trace, fn is inlined directly so outer compiles can fuse through
// inner ones.
//
// Compiled functions must not have side effects outside of the function. Do
// not access data other than the arguments passed in (either Go data or MLX
// arrays) unless it is a constant.
func Compile(name string, fn CompileFunc, opts ...CompileOption) CompileFunc {
var cfg compileConfig
for _, o := range opts {
o(&cfg)
}
var closure C.mlx_closure
var once sync.Once
return func(inputs ...*Array) []*Array {
if tracing {
return fn(inputs...)
}
once.Do(func() {
payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
*payload = cgo.NewHandle(fn)
src := C.mlx_closure_new_func_payload(
(*[0]byte)(C.closureCallback),
unsafe.Pointer(payload),
(*[0]byte)(C.closureDestructor),
)
defer C.mlx_closure_free(src)
closure = C.mlx_closure_new()
mlxCheck(name+": compile failed", func() C.int {
return C.mlx_compile(&closure, src, C.bool(cfg.shapeless))
})
})
inVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(inVec)
for _, in := range inputs {
C.mlx_vector_array_append_value(inVec, in.ctx)
}
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
mlxCheck(name+": closure apply failed", func() C.int {
return C.mlx_closure_apply(&outVec, closure, inVec)
})
n := int(C.mlx_vector_array_size(outVec))
outputs := make([]*Array, n)
for i := range n {
outputs[i] = New(name)
C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
}
return outputs
}
}
// Compile1 compiles a unary function. See Compile.
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
cf := Compile(name, func(in ...*Array) []*Array {
return []*Array{fn(in[0])}
}, opts...)
return func(a *Array) *Array {
return cf(a)[0]
}
}
// Compile2 compiles a binary function. See Compile.
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
cf := Compile(name, func(in ...*Array) []*Array {
return []*Array{fn(in[0], in[1])}
}, opts...)
return func(a, b *Array) *Array {
return cf(a, b)[0]
}
}
// Compile3 compiles a ternary function. See Compile.
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
cf := Compile(name, func(in ...*Array) []*Array {
return []*Array{fn(in[0], in[1], in[2])}
}, opts...)
return func(a, b, c *Array) *Array {
return cf(a, b, c)[0]
}
}
// tracing is true while a compile callback is running. Since MLX is
// single-threaded at this level a plain Go bool suffices.
var tracing bool
// traceScratch collects arrays created during a compile trace so they can be
// freed as a group when the callback returns.
var traceScratch []*Array
//export closureCallback
func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
defer func() {
if r := recover(); r != nil {
slog.Error("mlx closure callback panicked", "panic", r)
rc = 1
}
}()
handle := *(*cgo.Handle)(payload)
fn := handle.Value().(CompileFunc)
// When tracing, we track every intermediate that is created and free them
// as a group at the end. This gives the effect of a single op: inputs are
// owned by the original caller (via the MLX layer) and outputs are
// transferred back to MLX to create new Go-side tensors.
if tracing {
panic("mlx: nested compile trace")
}
tracing = true
traceScratch = nil
defer func() {
for _, a := range traceScratch {
if a.pinned > 0 {
panic("mlx: traced array was pinned during compilation")
}
if a.Valid() {
C.mlx_array_free(a.ctx)
a.ctx.ctx = nil
}
}
tracing = false
traceScratch = nil
}()
n := int(C.mlx_vector_array_size(input))
inputs := make([]*Array, n)
for i := range n {
a := New("")
C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
inputs[i] = a
}
outputs := fn(inputs...)
var arrPtr *C.mlx_array
if len(outputs) > 0 {
handles := make([]C.mlx_array, len(outputs))
for i, out := range outputs {
handles[i] = out.ctx
}
arrPtr = &handles[0]
}
C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
return 0
}
//export closureDestructor
func closureDestructor(payload unsafe.Pointer) {
handle := *(*cgo.Handle)(payload)
handle.Delete()
C.free(payload)
}
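The Compile wrapper's control flow — build the expensive closure at most once on first call, then apply the cached result on every call — can be sketched MLX-free with sync.Once. The `lazily` helper below is illustrative; the real code builds an mlx_closure via cgo instead of a Go function:

```go
package main

import (
	"fmt"
	"sync"
)

// lazily mirrors the shape of Compile: the expensive build step runs at
// most once, on first invocation, and every call applies the cached result.
func lazily(build func() func(int) int) func(int) int {
	var (
		once     sync.Once
		compiled func(int) int
	)
	return func(x int) int {
		once.Do(func() { compiled = build() })
		return compiled(x)
	}
}

func main() {
	builds := 0
	double := lazily(func() func(int) int {
		builds++ // the expensive trace/compile stands in here
		return func(x int) int { return x * 2 }
	})
	fmt.Println(double(3), double(10), builds) // prints 6 20 1
}
```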


@@ -0,0 +1,147 @@
package mlx
import (
"testing"
)
func TestCompileFusion(t *testing.T) {
skipIfNoMLX(t)
// Compile fuses the ops inside a function body into a single kernel,
// eliminating intermediate buffers. Use a diamond-shaped graph where
// two branches must be materialized simultaneously without fusion,
// then compare peak memory against the compiled version which fuses
// everything into one kernel with no intermediates.
const n = 1024 * 1024 // 4MB per float32 array
data := make([]float32, n)
for i := range data {
data[i] = float32(i + 1)
}
// Diamond: both a*b and a+b must be live for the final multiply.
// Without fusion: peak includes both intermediates (~8MB extra).
// With fusion: single kernel, no intermediates.
body := func(a, b *Array) *Array {
return a.Multiply(b).Multiply(a.Add(b))
}
a := FromValues(data, n)
b := FromValues(data, n)
Pin(a, b)
defer Unpin(a, b)
// Compiled: ops fused into a single kernel.
EnableCompile()
fn := Compile2("diamond", body, Shapeless())
warm := fn(a, b)
Eval(warm)
Sweep()
ClearCache()
ResetPeakMemory()
y := fn(a, b)
Eval(y)
compiledPeak := PeakMemory()
Sweep()
// Uncompiled: ops evaluated individually, intermediates materialized.
ClearCache()
ResetPeakMemory()
z := body(a, b)
Eval(z)
uncompiledPeak := PeakMemory()
Sweep()
if compiledPeak == 0 && uncompiledPeak == 0 {
t.Skip("peak memory tracking not available")
}
t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
if compiledPeak >= uncompiledPeak {
t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
}
}
func TestCompileNested(t *testing.T) {
skipIfNoMLX(t)
// A compiled function that calls another compiled function should
// produce correct results. The inner function is inlined via the
// tracing flag during the outer's trace.
inner := Compile1("silu", func(a *Array) *Array {
return a.Multiply(a.Sigmoid())
}, Shapeless())
outer := Compile2("swiglu", func(gate, up *Array) *Array {
return inner(gate).Multiply(up)
}, Shapeless())
gate := FromValues([]float32{0, 1, 2}, 3)
up := FromValues([]float32{1, 1, 1}, 3)
Pin(gate, up)
defer Unpin(gate, up)
y := outer(gate, up)
Eval(y)
// silu(x) = x * sigmoid(x); for x=0 → 0, x=1 → ~0.7311, x=2 → ~1.7616
got := y.Floats()
want := []float32{0, 0.7310586, 1.7615942}
for i, v := range got {
if v-want[i] > 1e-4 || want[i]-v > 1e-4 {
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
}
}
}
func TestCompileCallbackPanicRecovers(t *testing.T) {
skipIfNoMLX(t)
boom := Compile1("boom", func(a *Array) *Array {
panic("intentional test panic")
})
x := FromValues([]float32{1}, 1)
Pin(x)
defer Unpin(x)
defer func() {
r := recover()
if r == nil {
t.Fatal("expected panic from Call, got none")
}
if _, ok := r.(string); !ok {
t.Fatalf("expected string panic, got %T: %v", r, r)
}
}()
boom(x)
}
func TestCompileNoTrackingGrowth(t *testing.T) {
skipIfNoMLX(t)
// Repeated invocations of a compiled kernel should not grow the
// tracked-arrays list — the callback's traceScratch collects
// intermediates during tracing and frees them when the callback returns.
fn := Compile2("mul_add", func(a, b *Array) *Array {
return a.Multiply(b).Add(b)
})
a := FromValues([]float32{1, 2}, 2)
b := FromValues([]float32{3, 4}, 2)
Pin(a, b)
defer Unpin(a, b)
Sweep()
before := len(arrays)
for range 100 {
_ = fn(a, b)
Sweep()
}
after := len(arrays)
if after > before+2 {
t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", before, after)
}
}


@@ -9,8 +9,8 @@ package mlx
// #include "generated.h"
// #include <string.h>
//
// static char _mlx_last_error_msg[1024] = {0};
// static int _mlx_last_error_flag = 0;
// static __thread char _mlx_last_error_msg[1024] = {0};
// static __thread int _mlx_last_error_flag = 0;
//
// static void _mlx_capture_error_handler(const char* msg, void* data) {
// (void)data;
@@ -30,15 +30,13 @@ package mlx
// _mlx_last_error_msg[0] = '\0';
// }
//
// static int mlx_had_last_error(void) {
// return _mlx_last_error_flag;
// }
//
// static const char* mlx_get_last_error(void) {
// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL;
// return _mlx_last_error_flag ? _mlx_last_error_msg : "";
// }
import "C"
import "runtime"
func init() {
// Replace the default exit(-1) error handler with one that captures
// the error message so we can surface it in Go.
@@ -53,6 +51,24 @@ func Version() string {
return C.GoString(C.mlx_string_data(str))
}
// mlxCheck locks the goroutine to its OS thread, clears the captured error
// state, calls fn, and panics with the captured message if fn returns non-zero.
// The thread lock ensures the thread-local error state is read from the same
// thread that executed the call.
func mlxCheck(fallback string, fn func() C.int) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
C.mlx_clear_last_error()
if fn() != 0 {
msg := C.GoString(C.mlx_get_last_error())
if msg == "" {
msg = fallback
}
panic("mlx: " + msg)
}
}
func doEval(outputs []*Array, async bool) {
if len(outputs) == 0 {
return
@@ -67,20 +83,12 @@ func doEval(outputs []*Array, async bool) {
}
}
C.mlx_clear_last_error()
var rc C.int
if async {
rc = C.mlx_async_eval(vector)
} else {
rc = C.mlx_eval(vector)
}
if rc != 0 {
msg := "mlx eval failed"
if C.mlx_had_last_error() != 0 {
msg = C.GoString(C.mlx_get_last_error())
mlxCheck("eval failed", func() C.int {
if async {
return C.mlx_async_eval(vector)
}
panic("mlx: " + msg)
}
return C.mlx_eval(vector)
})
}
func AsyncEval(outputs ...*Array) {
@@ -90,3 +98,10 @@ func AsyncEval(outputs ...*Array) {
func Eval(outputs ...*Array) {
doEval(outputs, false)
}
// MetalIsAvailable returns true if a Metal GPU is available.
func MetalIsAvailable() bool {
var available C._Bool
C.mlx_metal_is_available(&available)
return bool(available)
}


@@ -149,45 +149,132 @@ func Contiguous(a *Array, allowColMajor bool) *Array {
return out
}
func Pad(a *Array, paddings []int32) *Array {
numAxes := len(paddings) / 2
axes := make([]C.int, numAxes)
lowPad := make([]C.int, numAxes)
highPad := make([]C.int, numAxes)
for i := range numAxes {
axes[i] = C.int(i)
lowPad[i] = C.int(paddings[i*2])
highPad[i] = C.int(paddings[i*2+1])
// Conv2d performs 2D convolution: x [N,H,W,C_in], weight [C_out,kH,kW,C_in].
// MLX uses NHWC layout.
func Conv2d(x, weight *Array, strideH, strideW, padH, padW, dilationH, dilationW, groups int32) *Array {
out := New("CONV2D")
C.mlx_conv2d(
&out.ctx,
x.ctx,
weight.ctx,
C.int(strideH), C.int(strideW),
C.int(padH), C.int(padW),
C.int(dilationH), C.int(dilationW),
C.int(groups),
DefaultStream().ctx,
)
return out
}
// Pad pads array a along the given axes with specified low/high pad sizes.
// mode should be "constant", "edge", or "reflect".
func Pad(a *Array, axes []int, lowPad, highPad []int, padValue *Array, mode string) *Array {
cAxes := make([]C.int, len(axes))
cLow := make([]C.int, len(lowPad))
cHigh := make([]C.int, len(highPad))
for i := range axes {
cAxes[i] = C.int(axes[i])
cLow[i] = C.int(lowPad[i])
cHigh[i] = C.int(highPad[i])
}
padValue := C.mlx_array_new_float(C.float(0))
defer C.mlx_array_free(padValue)
cMode := C.CString("constant")
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("PAD")
C.mlx_pad(
&out.ctx,
a.ctx,
unsafe.SliceData(axes),
C.size_t(len(axes)),
unsafe.SliceData(lowPad),
C.size_t(len(lowPad)),
unsafe.SliceData(highPad),
C.size_t(len(highPad)),
padValue,
unsafe.SliceData(cAxes), C.size_t(len(cAxes)),
unsafe.SliceData(cLow), C.size_t(len(cLow)),
unsafe.SliceData(cHigh), C.size_t(len(cHigh)),
padValue.ctx,
cMode,
DefaultStream().ctx,
)
return out
}
// PadConstant pads with zeros along the given axes.
func PadConstant(a *Array, axes []int, lowPad, highPad []int) *Array {
zero := NewScalarArray(float32(0))
return Pad(a, axes, lowPad, highPad, zero, "constant")
}
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
groups := int32(x.Dim(x.NumDims() - 1))
return Conv1d(x, weight, bias, 1, 0, 1, groups)
}
// Maximum returns element-wise maximum of two arrays.
func Maximum(a, b *Array) *Array {
out := New("MAXIMUM")
C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Minimum returns element-wise minimum of two arrays.
func Minimum(a, b *Array) *Array {
out := New("MINIMUM")
C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Softplus computes log(1 + exp(x)) using logaddexp for numerical stability.
func Softplus(a *Array) *Array {
return Logaddexp(a, Zeros(a.DType(), a.Dims()...))
}
// ReLU computes max(0, x).
func ReLU(a *Array) *Array {
return Maximum(a, NewScalarArray(float32(0)))
}
// GLU applies Gated Linear Unit: splits x along last dim into two halves,
// returns first * sigmoid(second).
func GLU(a *Array) *Array {
lastDim := a.NumDims() - 1
halfSize := a.Dim(lastDim) / 2
first := SliceStartStop(a,
make([]int32, lastDim+1), // all zeros for start
appendDims(a, lastDim, int32(halfSize)),
)
second := SliceStartStop(a,
appendDimsStart(a, lastDim, int32(halfSize)),
appendDims(a, lastDim, int32(a.Dim(lastDim))),
)
return first.Multiply(second.Sigmoid())
}
// helper: builds stop array for SliceStartStop where the target axis = val
func appendDims(a *Array, targetAxis int, val int32) []int32 {
n := a.NumDims()
out := make([]int32, n)
for i := range n {
if i == targetAxis {
out[i] = val
} else {
out[i] = int32(a.Dim(i))
}
}
return out
}
// helper: builds start array for SliceStartStop where the target axis = val
func appendDimsStart(a *Array, targetAxis int, val int32) []int32 {
n := a.NumDims()
out := make([]int32, n)
for i := range n {
if i == targetAxis {
out[i] = val
}
}
return out
}
// Clamp clamps array values to [min, max].
func Clamp(a *Array, minVal, maxVal float32) *Array {
return Minimum(Maximum(a, NewScalarArray(minVal)), NewScalarArray(maxVal))
}
// Convenience wrappers (function-style for the model code)
func Stack(arrays []*Array, axis int) *Array {
@@ -317,26 +404,38 @@ func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *A
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
}
func SiLU(a *Array) *Array {
sig := a.Sigmoid()
return a.Multiply(sig)
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
}
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
freqs := New("")
// RoPEWithFreqs applies RoPE with optional custom frequencies.
// When freqs is non-nil, it is used instead of computing from base.
// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass
// the actual frequencies (base^(2i/dim)), not the inverse frequencies.
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offset int, freqs *Array) *Array {
var freqsCtx C.mlx_array
var optBase C.mlx_optional_float
if freqs != nil {
freqsCtx = freqs.ctx
optBase = C.mlx_optional_float{has_value: C.bool(false)}
} else {
empty := New("")
freqsCtx = empty.ctx
optBase = C.mlx_optional_float{
value: C.float(base),
has_value: C.bool(func() bool { return base != 0 }()),
}
}
out := New("FAST_ROPE")
C.mlx_fast_rope(
&out.ctx,
x.ctx,
C.int(dims),
C.bool(traditional),
C.mlx_optional_float{
value: C.float(base),
has_value: C.bool(func() bool { return base != 0 }()),
},
optBase,
C.float(scale),
C.int(offset),
freqs.ctx,
freqsCtx,
DefaultStream().ctx,
)
return out
@@ -358,6 +457,24 @@ func Log(a *Array) *Array {
return out
}
func Sin(a *Array) *Array {
out := New("SIN")
C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Cos(a *Array) *Array {
out := New("COS")
C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Clip(a, aMin, aMax *Array) *Array {
out := New("CLIP")
C.mlx_clip(&out.ctx, a.ctx, aMin.ctx, aMax.ctx, DefaultStream().ctx)
return out
}
func Logaddexp(a, b *Array) *Array {
out := New("LOGADDEXP")
C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
@@ -385,6 +502,20 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
return out
}
// ScaledDotProductAttentionMasked runs the fast SDPA kernel with an explicit
// additive mask. The mask is broadcast to [B, H, Q, K] and added to scores
// before softmax. Pass mode="array" so MLX actually consults mask_arr; the
// empty string is "no mask" and silently ignores the array argument.
func ScaledDotProductAttentionMasked(q, k, v *Array, scale float32, mask *Array) *Array {
sinks := New("")
cMode := C.CString("array")
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA")
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
return out
}
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
out := New("FAST_LAYERNORM")
var w, b C.mlx_array


@@ -131,6 +131,12 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string,
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
globalQuantType = strings.ToUpper(globalQuantType)
// Parse full metadata for per-tensor quant info
var metaMap map[string]string
if metaRaw, ok := header["__metadata__"]; ok {
json.Unmarshal(metaRaw, &metaMap)
}
mainNames := mainTensorNames(header)
infos := make(map[string]*TensorQuantInfo)
for _, name := range mainNames {
@@ -141,6 +147,18 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string,
quantType := globalQuantType
groupSize := globalGroupSize
// Check per-tensor metadata (e.g. from packed expert blobs with mixed precision)
if metaMap != nil {
if qt, ok := metaMap[name+".quant_type"]; ok && qt != "" {
quantType = strings.ToUpper(qt)
}
if gs, ok := metaMap[name+".group_size"]; ok && gs != "" {
if v, err := strconv.Atoi(gs); err == nil {
groupSize = v
}
}
}
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
if quantType == "" {
quantType = inferredType


@@ -23,15 +23,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return errors.New("model not loaded")
}
enableCompile := true
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
enableCompile = modelCompile.EnableCompile()
}
if enableCompile {
mlx.EnableCompile()
} else {
mlx.DisableCompile()
}
mlx.ResetPeakMemory()
ctx := request.Ctx
var (


@@ -79,6 +79,8 @@ func (r *Runner) Load(modelName string) error {
r.Model = m
r.Tokenizer = m.Tokenizer()
r.contextLength = m.MaxContextLength()
mlx.EnableCompile()
return nil
}

x/models/gemma4/gemma4.go Normal file

File diff suppressed because it is too large


@@ -0,0 +1,228 @@
package gemma4
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// onesLike creates a tensor of the given shape filled with a small constant.
func onesLike(shape ...int) *mlx.Array {
return mlx.AddScalar(mlx.Zeros(mlx.DTypeBFloat16, shape...), 0.01)
}
func TestMoEForward(t *testing.T) {
skipIfNoMLX(t)
// Small config matching 26b architecture pattern.
cfg := &TextConfig{
HiddenSize: 16, // tiny for testing
NumAttentionHeads: 2,
NumKeyValueHeads: 1,
NumGlobalKeyValueHeads: 1,
HeadDim: 8,
GlobalHeadDim: 8,
NumExperts: 4,
TopKExperts: 2,
ExpertIntermediateSize: 8,
EnableMoeBlock: true,
AttentionKEqV: false,
RMSNormEps: 1e-6,
SlidingScale: 1.0,
FullScale: 1.0,
}
B, L := int32(1), int32(3)
x := onesLike(int(B), int(L), int(cfg.HiddenSize))
// Test Router.Forward.
router := &Router{
Proj: linearFromWeight(onesLike(int(cfg.NumExperts), int(cfg.HiddenSize))),
Scale: onesLike(int(cfg.HiddenSize)),
}
t.Run("Router", func(t *testing.T) {
scores, inds := router.Forward(x, cfg)
mlx.Eval(scores, inds)
sDims := scores.Dims()
iDims := inds.Dims()
t.Logf("scores shape: %v, inds shape: %v", sDims, iDims)
if len(sDims) != 2 || sDims[0] != int(B*L) || sDims[1] != int(cfg.TopKExperts) {
t.Errorf("scores shape = %v, want [%d, %d]", sDims, B*L, cfg.TopKExperts)
}
if len(iDims) != 2 || iDims[0] != int(B*L) || iDims[1] != int(cfg.TopKExperts) {
t.Errorf("inds shape = %v, want [%d, %d]", iDims, B*L, cfg.TopKExperts)
}
})
// Test MoEBlock.Forward.
moe := &MoEBlock{
GateWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
UpWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
DownWeight: onesLike(int(cfg.NumExperts), int(cfg.ExpertIntermediateSize), int(cfg.HiddenSize)),
PerExpertScale: onesLike(int(cfg.NumExperts)),
}
t.Run("MoEBlock", func(t *testing.T) {
scores, inds := router.Forward(x, cfg)
mlx.Eval(scores, inds)
out := moe.Forward(x, scores, inds, cfg)
mlx.Eval(out)
outDims := out.Dims()
t.Logf("MoE output shape: %v", outDims)
if len(outDims) != 3 || outDims[0] != int(B) || outDims[1] != int(L) || outDims[2] != int(cfg.HiddenSize) {
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, B, L, cfg.HiddenSize)
}
})
// Test with larger batch to exercise the sorted GatherMM path (B*L >= 64).
t.Run("MoEBlock_sorted", func(t *testing.T) {
bigB, bigL := int32(1), int32(128)
bigX := onesLike(int(bigB), int(bigL), int(cfg.HiddenSize))
scores, inds := router.Forward(bigX, cfg)
mlx.Eval(scores, inds)
out := moe.Forward(bigX, scores, inds, cfg)
mlx.Eval(out)
outDims := out.Dims()
t.Logf("MoE sorted output shape: %v", outDims)
if len(outDims) != 3 || outDims[0] != int(bigB) || outDims[1] != int(bigL) || outDims[2] != int(cfg.HiddenSize) {
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, bigB, bigL, cfg.HiddenSize)
}
})
}
// TestRouterForwardMatchesLegacy verifies the optimized Router.Forward —
// which takes the top-k of the raw logits and softmaxes only the selected
// values — produces the same indices and (within tolerance) the same
// normalized scores as the legacy path that softmaxes over every expert
// first, gathers the top-k probabilities, then renormalizes.
func TestRouterForwardMatchesLegacy(t *testing.T) {
skipIfNoMLX(t)
cfg := &TextConfig{
HiddenSize: 8,
NumExperts: 4,
TopKExperts: 2,
RMSNormEps: 1e-6,
RouterScale: 0.5,
}
// Distinct per-expert weight rows so top-k has a well-defined ordering
// (tied scores would let argpartition pick either tied expert and make
// the index comparison below flaky).
projWeight := mlx.FromValues([]float32{
0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, // expert 0
0.30, 0.29, 0.28, 0.27, 0.26, 0.25, 0.24, 0.23, // expert 1
-0.05, -0.06, -0.07, -0.08, -0.09, -0.10, -0.11, -0.12, // expert 2
0.50, 0.48, 0.46, 0.44, 0.42, 0.40, 0.38, 0.36, // expert 3
}, int(cfg.NumExperts), int(cfg.HiddenSize))
scale := mlx.FromValues([]float32{
1.0, 0.9, 1.1, 1.0, 1.2, 0.8, 1.0, 1.05,
}, int(cfg.HiddenSize))
r := &Router{
Proj: linearFromWeight(projWeight),
Scale: scale,
}
// Varied x so different positions potentially hit different top-k.
x := mlx.FromValues([]float32{
0.2, -0.1, 0.3, 0.0, 0.4, -0.2, 0.1, 0.05,
-0.3, 0.2, -0.1, 0.4, -0.05, 0.3, 0.0, 0.2,
0.5, 0.4, -0.2, 0.1, -0.3, 0.0, 0.3, -0.1,
}, 1, 3, int(cfg.HiddenSize))
gotScores, gotInds := r.Forward(x, cfg)
wantScores, wantInds := legacyRouterForward(r, x, cfg)
mlx.Eval(gotScores, gotInds, wantScores, wantInds)
if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {
t.Fatalf("indices mismatch:\n got %v\n want %v", got, want)
}
if got, want := gotScores.Floats(), wantScores.Floats(); !floatSlicesClose(got, want, 1e-5) {
t.Fatalf("scores mismatch:\n got %v\n want %v", got, want)
}
}
// legacyRouterForward implements the pre-optimization router: full softmax
// over every expert, gather the top-k probabilities, then renormalize them
// to sum to 1. Algebraically identical to the fused form in Router.Forward.
func legacyRouterForward(r *Router, x *mlx.Array, cfg *TextConfig) (*mlx.Array, *mlx.Array) {
dims := x.Dims()
BL := int32(dims[0]) * int32(dims[1])
xFlat := mlx.Reshape(x, BL, cfg.HiddenSize)
normed := mlx.RMSNormFn(xFlat, nil, cfg.RMSNormEps)
normed = mlx.MulScalar(normed, cfg.RouterScale)
normed = mlx.Mul(normed, r.Scale)
expertScores := r.Proj.Forward(normed)
probs := mlx.SoftmaxAxis(expertScores, -1, true)
neg := mlx.Neg(expertScores)
inds := mlx.Argpartition(neg, int(cfg.TopKExperts)-1, -1)
inds = mlx.SliceStartStop(inds,
[]int32{0, 0},
[]int32{BL, cfg.TopKExperts},
)
scores := mlx.TakeAlongAxis(probs, inds, -1)
sumScores := mlx.Sum(scores, -1, true)
scores = mlx.Div(scores, sumScores)
return scores, inds
}
func intSlicesEqual(a, b []int) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func floatSlicesClose(a, b []float32, tol float32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
d := a[i] - b[i]
if d < 0 {
d = -d
}
if d > tol {
return false
}
}
return true
}
// linearFromWeight creates a simple nn.LinearLayer from a weight tensor (no bias).
func linearFromWeight(w *mlx.Array) *simpleLinear {
return &simpleLinear{weight: w}
}
type simpleLinear struct {
weight *mlx.Array
}
func (l *simpleLinear) Forward(x *mlx.Array) *mlx.Array {
return x.Matmul(mlx.Transpose(l.weight, 1, 0))
}
func (l *simpleLinear) OutputDim() int32 {
return int32(l.weight.Dims()[0])
}


@@ -0,0 +1,503 @@
package gemma4
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestParseTextConfigE2B(t *testing.T) {
skipIfNoMLX(t)
data := []byte(`{
"architectures": ["Gemma4ForConditionalGeneration"],
"text_config": {
"hidden_size": 1536,
"num_hidden_layers": 35,
"intermediate_size": 6144,
"num_attention_heads": 8,
"num_key_value_heads": 1,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 131072,
"sliding_window": 512,
"sliding_window_pattern": 5,
"final_logit_softcapping": 30.0,
"use_double_wide_mlp": true,
"num_kv_shared_layers": 20,
"hidden_size_per_layer_input": 256,
"vocab_size_per_layer_input": 262144,
"attention_k_eq_v": false,
"tie_word_embeddings": true,
"layer_types": [
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
}`)
cfg, err := parseTextConfig(data)
if err != nil {
t.Fatalf("parseTextConfig failed: %v", err)
}
// Basic fields.
if cfg.HiddenSize != 1536 {
t.Errorf("HiddenSize = %d, want 1536", cfg.HiddenSize)
}
if cfg.NumHiddenLayers != 35 {
t.Errorf("NumHiddenLayers = %d, want 35", cfg.NumHiddenLayers)
}
if cfg.GlobalHeadDim != 512 {
t.Errorf("GlobalHeadDim = %d, want 512", cfg.GlobalHeadDim)
}
if cfg.FinalLogitSoftcapping != 30.0 {
t.Errorf("FinalLogitSoftcapping = %f, want 30.0", cfg.FinalLogitSoftcapping)
}
if cfg.NumKVSharedLayers != 20 {
t.Errorf("NumKVSharedLayers = %d, want 20", cfg.NumKVSharedLayers)
}
if cfg.HiddenSizePerLayer != 256 {
t.Errorf("HiddenSizePerLayer = %d, want 256", cfg.HiddenSizePerLayer)
}
// RoPE settings.
if cfg.SlidingRopeDims != 256 {
t.Errorf("SlidingRopeDims = %d, want 256", cfg.SlidingRopeDims)
}
if cfg.FullRopeDims != 512 {
t.Errorf("FullRopeDims = %d, want 512 (GlobalHeadDim, partial rotation handled via custom freqs)", cfg.FullRopeDims)
}
if cfg.SlidingRopeBase != 10000 {
t.Errorf("SlidingRopeBase = %f, want 10000", cfg.SlidingRopeBase)
}
if cfg.FullRopeBase != 1000000 {
t.Errorf("FullRopeBase = %f, want 1000000", cfg.FullRopeBase)
}
// Attention scale.
if cfg.SlidingScale == 0 || cfg.FullScale == 0 {
t.Error("attention scales should be non-zero")
}
// KV sharing map.
// First shared layer is 35 - 20 = 15.
if donor, ok := cfg.KVShareMap[15]; !ok || donor != 13 {
t.Errorf("KVShareMap[15] = %d, ok=%v; want 13, true", donor, ok)
}
if donor, ok := cfg.KVShareMap[19]; !ok || donor != 14 {
t.Errorf("KVShareMap[19] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
}
if donor, ok := cfg.KVShareMap[34]; !ok || donor != 14 {
t.Errorf("KVShareMap[34] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
}
// Layer 14 should not be shared.
if _, ok := cfg.KVShareMap[14]; ok {
t.Error("layer 14 should not be in KVShareMap (non-shared)")
}
// Donors.
if !cfg.KVDonors[13] {
t.Error("layer 13 should be a KV donor")
}
if !cfg.KVDonors[14] {
t.Error("layer 14 should be a KV donor")
}
}
func TestParseTextConfig26B(t *testing.T) {
skipIfNoMLX(t)
data := []byte(`{
"architectures": ["Gemma4ForConditionalGeneration"],
"text_config": {
"hidden_size": 2816,
"num_hidden_layers": 30,
"intermediate_size": 2112,
"num_attention_heads": 16,
"num_key_value_heads": 8,
"num_global_key_value_heads": 2,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 131072,
"sliding_window": 1024,
"final_logit_softcapping": 30.0,
"use_double_wide_mlp": false,
"num_kv_shared_layers": 0,
"hidden_size_per_layer_input": null,
"attention_k_eq_v": true,
"enable_moe_block": true,
"num_experts": 128,
"top_k_experts": 8,
"moe_intermediate_size": 704,
"tie_word_embeddings": true,
"layer_types": [
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
}`)
cfg, err := parseTextConfig(data)
if err != nil {
t.Fatalf("parseTextConfig failed: %v", err)
}
if cfg.HiddenSize != 2816 {
t.Errorf("HiddenSize = %d, want 2816", cfg.HiddenSize)
}
if !cfg.AttentionKEqV {
t.Error("AttentionKEqV should be true")
}
if cfg.NumGlobalKeyValueHeads != 2 {
t.Errorf("NumGlobalKeyValueHeads = %d, want 2", cfg.NumGlobalKeyValueHeads)
}
if !cfg.EnableMoeBlock {
t.Error("EnableMoeBlock should be true")
}
if cfg.NumExperts != 128 {
t.Errorf("NumExperts = %d, want 128", cfg.NumExperts)
}
if cfg.TopKExperts != 8 {
t.Errorf("TopKExperts = %d, want 8", cfg.TopKExperts)
}
if cfg.ExpertIntermediateSize != 704 {
t.Errorf("ExpertIntermediateSize = %d, want 704", cfg.ExpertIntermediateSize)
}
if cfg.HiddenSizePerLayer != 0 {
t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
}
}
func TestParseTextConfig31B(t *testing.T) {
skipIfNoMLX(t)
data := []byte(`{
"architectures": ["Gemma4ForConditionalGeneration"],
"text_config": {
"hidden_size": 5376,
"num_hidden_layers": 60,
"intermediate_size": 21504,
"num_attention_heads": 32,
"num_key_value_heads": 16,
"num_global_key_value_heads": 4,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 131072,
"sliding_window": 1024,
"final_logit_softcapping": 30.0,
"use_double_wide_mlp": false,
"num_kv_shared_layers": 0,
"hidden_size_per_layer_input": null,
"attention_k_eq_v": true,
"tie_word_embeddings": true,
"layer_types": [
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
}`)
cfg, err := parseTextConfig(data)
if err != nil {
t.Fatalf("parseTextConfig failed: %v", err)
}
if cfg.HiddenSize != 5376 {
t.Errorf("HiddenSize = %d, want 5376", cfg.HiddenSize)
}
if cfg.NumHiddenLayers != 60 {
t.Errorf("NumHiddenLayers = %d, want 60", cfg.NumHiddenLayers)
}
if !cfg.AttentionKEqV {
t.Error("AttentionKEqV should be true")
}
if cfg.NumGlobalKeyValueHeads != 4 {
t.Errorf("NumGlobalKeyValueHeads = %d, want 4", cfg.NumGlobalKeyValueHeads)
}
if cfg.NumKeyValueHeads != 16 {
t.Errorf("NumKeyValueHeads = %d, want 16", cfg.NumKeyValueHeads)
}
if cfg.NumKVSharedLayers != 0 {
t.Errorf("NumKVSharedLayers = %d, want 0", cfg.NumKVSharedLayers)
}
if cfg.HiddenSizePerLayer != 0 {
t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
}
if cfg.SlidingWindow != 1024 {
t.Errorf("SlidingWindow = %d, want 1024", cfg.SlidingWindow)
}
// KV sharing should be empty (no shared layers).
if len(cfg.KVShareMap) != 0 {
t.Errorf("KVShareMap should be empty, got %d entries", len(cfg.KVShareMap))
}
// Layer types: pattern is 5 sliding + 1 full, repeating 10 times.
if !isLayerSliding(0, &cfg) {
t.Error("layer 0 should be sliding")
}
if isLayerSliding(5, &cfg) {
t.Error("layer 5 should be full attention")
}
if !isLayerSliding(6, &cfg) {
t.Error("layer 6 should be sliding")
}
if isLayerSliding(59, &cfg) {
t.Error("layer 59 should be full attention")
}
}
func TestParseTextConfigE4B(t *testing.T) {
skipIfNoMLX(t)
data := []byte(`{
"architectures": ["Gemma4ForConditionalGeneration"],
"text_config": {
"hidden_size": 2560,
"num_hidden_layers": 42,
"intermediate_size": 10240,
"num_attention_heads": 8,
"num_key_value_heads": 2,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 131072,
"sliding_window": 512,
"final_logit_softcapping": 30.0,
"use_double_wide_mlp": false,
"num_kv_shared_layers": 18,
"hidden_size_per_layer_input": 256,
"vocab_size_per_layer_input": 262144,
"attention_k_eq_v": false,
"enable_moe_block": false,
"tie_word_embeddings": true,
"layer_types": [
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
}`)
cfg, err := parseTextConfig(data)
if err != nil {
t.Fatalf("parseTextConfig failed: %v", err)
}
if cfg.HiddenSize != 2560 {
t.Errorf("HiddenSize = %d, want 2560", cfg.HiddenSize)
}
if cfg.NumHiddenLayers != 42 {
t.Errorf("NumHiddenLayers = %d, want 42", cfg.NumHiddenLayers)
}
if cfg.IntermediateSize != 10240 {
t.Errorf("IntermediateSize = %d, want 10240", cfg.IntermediateSize)
}
if cfg.NumKeyValueHeads != 2 {
t.Errorf("NumKeyValueHeads = %d, want 2", cfg.NumKeyValueHeads)
}
if cfg.UseDoubleWideMLP {
t.Error("UseDoubleWideMLP should be false")
}
if cfg.NumKVSharedLayers != 18 {
t.Errorf("NumKVSharedLayers = %d, want 18", cfg.NumKVSharedLayers)
}
if cfg.HiddenSizePerLayer != 256 {
t.Errorf("HiddenSizePerLayer = %d, want 256 (has PLE)", cfg.HiddenSizePerLayer)
}
if cfg.AttentionKEqV {
t.Error("AttentionKEqV should be false")
}
if cfg.EnableMoeBlock {
t.Error("EnableMoeBlock should be false")
}
if cfg.SlidingWindow != 512 {
t.Errorf("SlidingWindow = %d, want 512", cfg.SlidingWindow)
}
// Layer types: pattern is 5 sliding + 1 full, repeating 7 times = 42 layers.
if !isLayerSliding(0, &cfg) {
t.Error("layer 0 should be sliding")
}
if isLayerSliding(5, &cfg) {
t.Error("layer 5 should be full attention")
}
if !isLayerSliding(6, &cfg) {
t.Error("layer 6 should be sliding")
}
if isLayerSliding(41, &cfg) {
t.Error("layer 41 should be full attention")
}
// KV sharing: first shared = 42 - 18 = 24.
// Layer 24 is sliding, its donor should be the last non-shared sliding layer.
// Non-shared layers: 0-23. Last sliding in 0-23 is layer 22 (23=full).
if donor, ok := cfg.KVShareMap[24]; !ok {
t.Error("layer 24 should be in KVShareMap")
} else {
t.Logf("layer 24 donor = %d", donor)
}
// Layer 29 is full_attention (5th full), donor should be the last non-shared full layer.
// Non-shared full layers: 5, 11, 17, 23.
if donor, ok := cfg.KVShareMap[29]; !ok || donor != 23 {
t.Errorf("KVShareMap[29] = %d, ok=%v; want 23, true (full attn donor)", donor, ok)
}
// Layer 23 should NOT be shared (it's the last non-shared layer).
if _, ok := cfg.KVShareMap[23]; ok {
t.Error("layer 23 should not be in KVShareMap (non-shared)")
}
}
func TestLayerTypeDetection(t *testing.T) {
cfg := &TextConfig{
LayerTypes: []string{
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
},
}
if !isLayerSliding(0, cfg) {
t.Error("layer 0 should be sliding")
}
if !isLayerSliding(3, cfg) {
t.Error("layer 3 should be sliding")
}
if isLayerSliding(4, cfg) {
t.Error("layer 4 should be full attention")
}
}
func TestNewCachesOmitsSharedKVLayers(t *testing.T) {
m := &Model{
Layers: []*DecoderLayer{
{IsSliding: true, KVShareDonor: -1},
{IsSliding: false, KVShareDonor: -1},
{IsSliding: true, KVShareDonor: 0},
{IsSliding: false, KVShareDonor: 1},
},
TextConfig: &TextConfig{SlidingWindow: 512},
}
caches := m.NewCaches()
if got, want := len(caches), 2; got != want {
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
}
}
func TestNewCachesIncludesAllNonSharedLayers(t *testing.T) {
m := &Model{
Layers: []*DecoderLayer{
{IsSliding: true, KVShareDonor: -1},
{IsSliding: false, KVShareDonor: -1},
{IsSliding: true, KVShareDonor: -1},
},
TextConfig: &TextConfig{SlidingWindow: 512},
}
caches := m.NewCaches()
if got, want := len(caches), len(m.Layers); got != want {
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
}
}
func TestResolveWeightPrefix(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
tests := []struct {
name string
key string
wantPfx string
}{
{"bare", "embed_tokens.weight", ""},
{"language_model", "model.language_model.embed_tokens.weight", "model.language_model."},
{"with_model", "model.embed_tokens.weight", "model."},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dummy := mlx.FromValue(float32(1.0))
mlx.Eval(dummy)
tensors := map[string]*mlx.Array{tt.key: dummy}
got := resolveWeightPrefix(tensors)
if got != tt.wantPfx {
t.Errorf("resolveWeightPrefix(%q) = %q, want %q", tt.key, got, tt.wantPfx)
}
})
}
}
func skipIfNoMLX(t *testing.T) {
t.Helper()
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
}
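The donor assertions in TestParseTextConfigE2B and TestParseTextConfigE4B all follow one rule: a KV-shared layer borrows its cache from the last non-shared layer of its own attention type (sliding borrows from the last non-shared sliding layer, full from the last non-shared full layer). A minimal sketch of that mapping; `buildKVShareMap` here is a hypothetical stand-in for whatever parseTextConfig does internally, shown only to make the expected donors concrete.

```go
package main

import "fmt"

// buildKVShareMap maps each shared layer (the last numShared layers) to its
// donor: the last non-shared layer with the same attention type.
func buildKVShareMap(layerTypes []string, numShared int) map[int]int {
	firstShared := len(layerTypes) - numShared
	lastSliding, lastFull := -1, -1
	for i := 0; i < firstShared; i++ {
		if layerTypes[i] == "sliding_attention" {
			lastSliding = i
		} else {
			lastFull = i
		}
	}
	m := map[int]int{}
	for i := firstShared; i < len(layerTypes); i++ {
		if layerTypes[i] == "sliding_attention" {
			m[i] = lastSliding
		} else {
			m[i] = lastFull
		}
	}
	return m
}

func main() {
	// E2B layout: 7 blocks of (4 sliding + 1 full) = 35 layers, 20 shared.
	var types []string
	for block := 0; block < 7; block++ {
		types = append(types,
			"sliding_attention", "sliding_attention",
			"sliding_attention", "sliding_attention",
			"full_attention")
	}
	m := buildKVShareMap(types, 20)
	// First shared layer is 35-20 = 15 (sliding); last non-shared sliding
	// is 13, last non-shared full is 14 — matching the test's assertions.
	fmt.Println(m[15], m[19], m[34]) // 13 14 14
}
```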


@@ -148,9 +148,7 @@ type DenseMLP struct {
 // Forward applies the SwiGLU MLP
 func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
-	gate := mlx.SiLU(m.GateProj.Forward(x))
-	up := m.UpProj.Forward(x)
-	return m.DownProj.Forward(mlx.Mul(gate, up))
+	return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
 }
 // MoEGate implements the expert gating mechanism
@@ -242,7 +240,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
 	up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
 		nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
-	hidden = mlx.Mul(mlx.SiLU(gate), up)
+	hidden = mlx.SwiGLU(gate, up)
 	down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
 		nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
@@ -250,7 +248,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
 	gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
 	up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
-	hidden = mlx.Mul(mlx.SiLU(gate), up)
+	hidden = mlx.SwiGLU(gate, up)
 	down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
 }
@@ -273,9 +271,7 @@ type SharedExperts struct {
 // Forward applies the shared expert MLP
 func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
-	gate := mlx.SiLU(s.GateProj.Forward(x))
-	up := s.UpProj.Forward(x)
-	return s.DownProj.Forward(mlx.Mul(gate, up))
+	return s.DownProj.Forward(mlx.SwiGLU(s.GateProj.Forward(x), s.UpProj.Forward(x)))
 }
 // MoE implements the full Mixture of Experts layer


@@ -314,5 +314,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
 }
 func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
-	return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
+	return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
 }


@@ -333,5 +333,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
 }
 func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
-	return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
+	return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
 }


@@ -1253,7 +1253,7 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
 }
 func (m *DenseMLP) Forward(x *mlx.Array, _ *Config) *mlx.Array {
-	return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
+	return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
 }
 func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
@@ -1283,13 +1283,13 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
 		nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
 	up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
 		nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
-	hidden = mlx.Mul(mlx.SiLU(gate), up)
+	hidden = mlx.SwiGLU(gate, up)
 	down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
 		nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
 } else {
 	gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort)
 	up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort)
-	hidden = mlx.Mul(mlx.SiLU(gate), up)
+	hidden = mlx.SwiGLU(gate, up)
 	down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort)
 }