mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 09:03:35 -04:00
Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
123b300af6 | ||
|
|
57653b8e42 | ||
|
|
a50ce61c54 | ||
|
|
2bb7ea00d2 | ||
|
|
55fa80d07a | ||
|
|
b9cb535407 | ||
|
|
031baef094 | ||
|
|
7d271e6dc9 | ||
|
|
c88dae2d6b | ||
|
|
9e3618d663 | ||
|
|
5d920cc6bc | ||
|
|
e585ecd11f | ||
|
|
cdddea0592 | ||
|
|
43f90def04 | ||
|
|
06ae6367bd | ||
|
|
48ad7085c4 | ||
|
|
e1e3cec8d0 | ||
|
|
d3e67e305c | ||
|
|
698e04a14b | ||
|
|
1d9537bc33 | ||
|
|
120424d832 | ||
|
|
5818001610 | ||
|
|
2cba7756c5 | ||
|
|
bf2a421727 | ||
|
|
f3cf6b75fb | ||
|
|
5dfac387a6 | ||
|
|
a99e5d9c22 | ||
|
|
0abf3aca36 | ||
|
|
ee0266462a | ||
|
|
c88fb286ec | ||
|
|
d3da29cbfc | ||
|
|
1b70bb8a10 | ||
|
|
ec29ce4ce3 | ||
|
|
4d75f5da03 | ||
|
|
798fd09bfe |
2
.github/workflows/test.yaml
vendored
2
.github/workflows/test.yaml
vendored
@@ -51,7 +51,7 @@ jobs:
|
||||
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||
- preset: ROCm
|
||||
container: rocm/dev-ubuntu-22.04:7.2
|
||||
container: rocm/dev-ubuntu-22.04:7.2.1
|
||||
extra-packages: rocm-libs
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
||||
- preset: Vulkan
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
ARG FLAVOR=${TARGETARCH}
|
||||
|
||||
ARG ROCMVERSION=7.2
|
||||
ARG ROCMVERSION=7.2.1
|
||||
ARG JETPACK5VERSION=r35.4.1
|
||||
ARG JETPACK6VERSION=r36.4.0
|
||||
ARG CMAKEVERSION=3.31.2
|
||||
|
||||
@@ -55,7 +55,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
|
||||
ollama
|
||||
```
|
||||
|
||||
You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `claude`, `codex`, `openclaw` and more.
|
||||
You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `Claude Code`, `OpenClaw`, `OpenCode`, `Codex`, `Copilot`, and more.
|
||||
|
||||
### Coding
|
||||
|
||||
@@ -65,7 +65,7 @@ To launch a specific integration:
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
|
||||
Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Copilot CLI](https://docs.ollama.com/integrations/copilot-cli), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
|
||||
|
||||
### AI assistant
|
||||
|
||||
|
||||
@@ -58,6 +58,9 @@ func TestLaunchCmd(t *testing.T) {
|
||||
if cmd.Long == "" {
|
||||
t.Error("Long description should not be empty")
|
||||
}
|
||||
if !strings.Contains(cmd.Long, "hermes") {
|
||||
t.Error("Long description should mention hermes")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("flags exist", func(t *testing.T) {
|
||||
|
||||
76
cmd/launch/copilot.go
Normal file
76
cmd/launch/copilot.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Copilot implements Runner for GitHub Copilot CLI integration.
type Copilot struct{}

// String returns the human-readable name of this integration.
func (c *Copilot) String() string { return "Copilot CLI" }

// args builds the copilot command line: an optional --model flag
// followed by any extra user-supplied arguments, in that order.
func (c *Copilot) args(model string, extra []string) []string {
	var cmdArgs []string
	if model != "" {
		cmdArgs = append(cmdArgs, "--model", model)
	}
	return append(cmdArgs, extra...)
}
|
||||
|
||||
func (c *Copilot) findPath() (string, error) {
|
||||
if p, err := exec.LookPath("copilot"); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := "copilot"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "copilot.exe"
|
||||
}
|
||||
fallback := filepath.Join(home, ".local", "bin", name)
|
||||
if _, err := os.Stat(fallback); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Copilot) Run(model string, args []string) error {
|
||||
copilotPath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("copilot is not installed, install from https://docs.github.com/en/copilot/how-tos/set-up/install-copilot-cli")
|
||||
}
|
||||
|
||||
cmd := exec.Command(copilotPath, c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
cmd.Env = append(os.Environ(), c.envVars(model)...)
|
||||
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// envVars returns the environment variables that configure Copilot CLI
|
||||
// to use Ollama as its model provider.
|
||||
func (c *Copilot) envVars(model string) []string {
|
||||
env := []string{
|
||||
"COPILOT_PROVIDER_BASE_URL=" + envconfig.Host().String() + "/v1",
|
||||
"COPILOT_PROVIDER_API_KEY=",
|
||||
"COPILOT_PROVIDER_WIRE_API=responses",
|
||||
}
|
||||
|
||||
if model != "" {
|
||||
env = append(env, "COPILOT_MODEL="+model)
|
||||
}
|
||||
|
||||
return env
|
||||
}
|
||||
161
cmd/launch/copilot_test.go
Normal file
161
cmd/launch/copilot_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCopilotIntegration(t *testing.T) {
|
||||
c := &Copilot{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Copilot CLI" {
|
||||
t.Errorf("String() = %q, want %q", got, "Copilot CLI")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestCopilotFindPath(t *testing.T) {
|
||||
c := &Copilot{}
|
||||
|
||||
t.Run("finds copilot in PATH", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
name := "copilot"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "copilot.exe"
|
||||
}
|
||||
fakeBin := filepath.Join(tmpDir, name)
|
||||
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fakeBin {
|
||||
t.Errorf("findPath() = %q, want %q", got, fakeBin)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error when not in PATH", func(t *testing.T) {
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||
|
||||
_, err := c.findPath()
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to ~/.local/bin/copilot", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||
|
||||
name := "copilot"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "copilot.exe"
|
||||
}
|
||||
fallback := filepath.Join(tmpDir, ".local", "bin", name)
|
||||
os.MkdirAll(filepath.Dir(fallback), 0o755)
|
||||
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fallback {
|
||||
t.Errorf("findPath() = %q, want %q", got, fallback)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||
|
||||
_, err := c.findPath()
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCopilotArgs(t *testing.T) {
|
||||
c := &Copilot{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil, nil},
|
||||
{"with model and extra", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
|
||||
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotEnvVars(t *testing.T) {
|
||||
c := &Copilot{}
|
||||
|
||||
envMap := func(envs []string) map[string]string {
|
||||
m := make(map[string]string)
|
||||
for _, e := range envs {
|
||||
k, v, _ := strings.Cut(e, "=")
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
t.Run("sets required provider env vars with model", func(t *testing.T) {
|
||||
got := envMap(c.envVars("llama3.2"))
|
||||
if got["COPILOT_PROVIDER_BASE_URL"] == "" {
|
||||
t.Error("COPILOT_PROVIDER_BASE_URL should be set")
|
||||
}
|
||||
if !strings.HasSuffix(got["COPILOT_PROVIDER_BASE_URL"], "/v1") {
|
||||
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want /v1 suffix", got["COPILOT_PROVIDER_BASE_URL"])
|
||||
}
|
||||
if _, ok := got["COPILOT_PROVIDER_API_KEY"]; !ok {
|
||||
t.Error("COPILOT_PROVIDER_API_KEY should be set (empty)")
|
||||
}
|
||||
if got["COPILOT_PROVIDER_WIRE_API"] != "responses" {
|
||||
t.Errorf("COPILOT_PROVIDER_WIRE_API = %q, want %q", got["COPILOT_PROVIDER_WIRE_API"], "responses")
|
||||
}
|
||||
if got["COPILOT_MODEL"] != "llama3.2" {
|
||||
t.Errorf("COPILOT_MODEL = %q, want %q", got["COPILOT_MODEL"], "llama3.2")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("omits COPILOT_MODEL when model is empty", func(t *testing.T) {
|
||||
got := envMap(c.envVars(""))
|
||||
if _, ok := got["COPILOT_MODEL"]; ok {
|
||||
t.Errorf("COPILOT_MODEL should not be set for empty model, got %q", got["COPILOT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", "http://myhost:9999")
|
||||
got := envMap(c.envVars("test"))
|
||||
if !strings.Contains(got["COPILOT_PROVIDER_BASE_URL"], "myhost:9999") {
|
||||
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want custom host", got["COPILOT_PROVIDER_BASE_URL"])
|
||||
}
|
||||
})
|
||||
}
|
||||
679
cmd/launch/hermes.go
Normal file
679
cmd/launch/hermes.go
Normal file
@@ -0,0 +1,679 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
const (
|
||||
hermesInstallScript = "curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash -s -- --skip-setup"
|
||||
hermesProviderName = "Ollama"
|
||||
hermesProviderKey = "ollama-launch"
|
||||
hermesLegacyKey = "ollama"
|
||||
hermesPlaceholderKey = "ollama"
|
||||
hermesGatewaySetupHint = "hermes gateway setup"
|
||||
hermesGatewaySetupTitle = "Connect a messaging app now?"
|
||||
)
|
||||
|
||||
var (
|
||||
hermesGOOS = runtime.GOOS
|
||||
hermesLookPath = exec.LookPath
|
||||
hermesCommand = exec.Command
|
||||
hermesUserHome = os.UserHomeDir
|
||||
hermesOllamaURL = envconfig.ConnectableHost
|
||||
)
|
||||
|
||||
var hermesMessagingEnvGroups = [][]string{
|
||||
{"TELEGRAM_BOT_TOKEN"},
|
||||
{"DISCORD_BOT_TOKEN"},
|
||||
{"SLACK_BOT_TOKEN"},
|
||||
{"SIGNAL_ACCOUNT"},
|
||||
{"EMAIL_ADDRESS"},
|
||||
{"TWILIO_ACCOUNT_SID"},
|
||||
{"MATRIX_ACCESS_TOKEN", "MATRIX_PASSWORD"},
|
||||
{"MATTERMOST_TOKEN"},
|
||||
{"WHATSAPP_PHONE_NUMBER_ID"},
|
||||
{"DINGTALK_CLIENT_ID"},
|
||||
{"FEISHU_APP_ID"},
|
||||
{"WECOM_BOT_ID"},
|
||||
{"WEIXIN_ACCOUNT_ID"},
|
||||
{"BLUEBUBBLES_SERVER_URL"},
|
||||
{"WEBHOOK_ENABLED"},
|
||||
}
|
||||
|
||||
// Hermes is intentionally not an Editor integration: launch owns one primary
|
||||
// model and the local Ollama endpoint, while Hermes keeps its own discovery and
|
||||
// switching UX after startup.
|
||||
type Hermes struct{}
|
||||
|
||||
func (h *Hermes) String() string { return "Hermes Agent" }
|
||||
|
||||
func (h *Hermes) Run(_ string, args []string) error {
|
||||
// Hermes reads its primary model from config.yaml. launch configures that
|
||||
// default model ahead of time so we can keep runtime invocation simple and
|
||||
// still let Hermes discover additional models later via its own UX.
|
||||
bin, err := h.binary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.runGatewaySetupPreflight(args, func() error {
|
||||
return hermesAttachedCommand(bin, "gateway", "setup").Run()
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return hermesAttachedCommand(bin, args...).Run()
|
||||
}
|
||||
|
||||
func (h *Hermes) Paths() []string {
|
||||
configPath, err := hermesConfigPath()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return []string{configPath}
|
||||
}
|
||||
|
||||
func (h *Hermes) Configure(model string) error {
|
||||
configPath, err := hermesConfigPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg := map[string]any{}
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("parse hermes config: %w", err)
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
modelSection, _ := cfg["model"].(map[string]any)
|
||||
if modelSection == nil {
|
||||
modelSection = make(map[string]any)
|
||||
}
|
||||
models := h.listModels(model)
|
||||
applyHermesManagedProviders(cfg, hermesBaseURL(), model, models)
|
||||
|
||||
// launch writes the minimum provider/default-model settings needed to
|
||||
// bootstrap Hermes against Ollama. The active provider stays on a
|
||||
// launch-owned key so /model stays aligned with the launcher-managed entry,
|
||||
// and the Ollama endpoint lives in providers: so the picker shows one row.
|
||||
modelSection["provider"] = hermesProviderKey
|
||||
modelSection["default"] = model
|
||||
modelSection["base_url"] = hermesBaseURL()
|
||||
modelSection["api_key"] = hermesPlaceholderKey
|
||||
cfg["model"] = modelSection
|
||||
|
||||
// use Hermes' built-in web toolset for now.
|
||||
// TODO(parthsareen): move this to using Ollama web search
|
||||
cfg["toolsets"] = mergeHermesToolsets(cfg["toolsets"])
|
||||
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteWithBackup(configPath, data)
|
||||
}
|
||||
|
||||
func (h *Hermes) CurrentModel() string {
|
||||
configPath, err := hermesConfigPath()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
cfg := map[string]any{}
|
||||
if yaml.Unmarshal(data, &cfg) != nil {
|
||||
return ""
|
||||
}
|
||||
return hermesManagedCurrentModel(cfg, hermesBaseURL())
|
||||
}
|
||||
|
||||
func (h *Hermes) Onboard() error {
|
||||
return config.MarkIntegrationOnboarded("hermes")
|
||||
}
|
||||
|
||||
func (h *Hermes) RequiresInteractiveOnboarding() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *Hermes) RefreshRuntimeAfterConfigure() error {
|
||||
running, err := h.gatewayRunning()
|
||||
if err != nil {
|
||||
return fmt.Errorf("check Hermes gateway status: %w", err)
|
||||
}
|
||||
if !running {
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sRefreshing Hermes messaging gateway...%s\n", ansiGray, ansiReset)
|
||||
if err := h.restartGateway(); err != nil {
|
||||
return fmt.Errorf("restart Hermes gateway: %w", err)
|
||||
}
|
||||
fmt.Fprintln(os.Stderr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hermes) installed() bool {
|
||||
_, err := h.binary()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (h *Hermes) ensureInstalled() error {
|
||||
if h.installed() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if hermesGOOS == "windows" {
|
||||
return hermesWindowsHint()
|
||||
}
|
||||
|
||||
var missing []string
|
||||
for _, dep := range []string{"bash", "curl", "git"} {
|
||||
if _, err := hermesLookPath(dep); err != nil {
|
||||
missing = append(missing, dep)
|
||||
}
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
return fmt.Errorf("Hermes is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch hermes", strings.Join(missing, "\n "))
|
||||
}
|
||||
|
||||
ok, err := ConfirmPrompt("Hermes is not installed. Install now?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("hermes installation cancelled")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nInstalling Hermes...\n")
|
||||
if err := hermesAttachedCommand("bash", "-lc", hermesInstallScript).Run(); err != nil {
|
||||
return fmt.Errorf("failed to install hermes: %w", err)
|
||||
}
|
||||
|
||||
if !h.installed() {
|
||||
return fmt.Errorf("hermes was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sHermes installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hermes) listModels(defaultModel string) []string {
|
||||
client := hermesOllamaClient()
|
||||
resp, err := client.List(context.Background())
|
||||
if err != nil {
|
||||
return []string{defaultModel}
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(resp.Models)+1)
|
||||
seen := make(map[string]struct{}, len(resp.Models)+1)
|
||||
add := func(name string) {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
return
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
models = append(models, name)
|
||||
}
|
||||
|
||||
add(defaultModel)
|
||||
for _, entry := range resp.Models {
|
||||
add(entry.Name)
|
||||
}
|
||||
if len(models) == 0 {
|
||||
return []string{defaultModel}
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
func (h *Hermes) binary() (string, error) {
|
||||
if path, err := hermesLookPath("hermes"); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
if hermesGOOS == "windows" {
|
||||
return "", hermesWindowsHint()
|
||||
}
|
||||
|
||||
home, err := hermesUserHome()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
fallback := filepath.Join(home, ".local", "bin", "hermes")
|
||||
if _, err := os.Stat(fallback); err == nil {
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("hermes is not installed")
|
||||
}
|
||||
|
||||
func hermesConfigPath() (string, error) {
|
||||
home, err := hermesUserHome()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".hermes", "config.yaml"), nil
|
||||
}
|
||||
|
||||
func hermesBaseURL() string {
|
||||
return strings.TrimRight(hermesOllamaURL().String(), "/") + "/v1"
|
||||
}
|
||||
|
||||
func hermesEnvPath() (string, error) {
|
||||
home, err := hermesUserHome()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".hermes", ".env"), nil
|
||||
}
|
||||
|
||||
func (h *Hermes) runGatewaySetupPreflight(args []string, runSetup func() error) error {
|
||||
if len(args) > 0 || !isInteractiveSession() || currentLaunchConfirmPolicy.yes || currentLaunchConfirmPolicy.requireYesMessage {
|
||||
return nil
|
||||
}
|
||||
if h.messagingConfigured() {
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nHermes can message you on Telegram, Discord, Slack, and more.\n\n")
|
||||
ok, err := ConfirmPromptWithOptions(hermesGatewaySetupTitle, ConfirmOptions{
|
||||
YesLabel: "Yes",
|
||||
NoLabel: "Set up later",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err := runSetup(); err != nil {
|
||||
return fmt.Errorf("hermes messaging setup failed: %w\n\nTry running: %s", err, hermesGatewaySetupHint)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hermes) messagingConfigured() bool {
|
||||
envVars, err := h.gatewayEnvVars()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, group := range hermesMessagingEnvGroups {
|
||||
for _, key := range group {
|
||||
if strings.TrimSpace(envVars[key]) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
|
||||
envVars := make(map[string]string)
|
||||
|
||||
envFilePath, err := hermesEnvPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch data, err := os.ReadFile(envFilePath); {
|
||||
case err == nil:
|
||||
for key, value := range hermesParseEnvFile(data) {
|
||||
envVars[key] = value
|
||||
}
|
||||
case os.IsNotExist(err):
|
||||
// nothing persisted yet
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, group := range hermesMessagingEnvGroups {
|
||||
for _, key := range group {
|
||||
if value, ok := os.LookupEnv(key); ok {
|
||||
envVars[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return envVars, nil
|
||||
}
|
||||
|
||||
func (h *Hermes) gatewayRunning() (bool, error) {
|
||||
status, err := h.gatewayStatusOutput()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return hermesGatewayStatusRunning(status), nil
|
||||
}
|
||||
|
||||
func (h *Hermes) gatewayStatusOutput() (string, error) {
|
||||
bin, err := h.binary()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
out, err := hermesCommand(bin, "gateway", "status").CombinedOutput()
|
||||
return string(out), err
|
||||
}
|
||||
|
||||
func (h *Hermes) restartGateway() error {
|
||||
bin, err := h.binary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return hermesAttachedCommand(bin, "gateway", "restart").Run()
|
||||
}
|
||||
|
||||
// hermesGatewayStatusRunning interprets the textual output of
// `hermes gateway status` and reports whether the gateway is up.
// Negative phrases are checked before positive ones, and anything
// unrecognized is conservatively treated as not running.
func hermesGatewayStatusRunning(output string) bool {
	lowered := strings.ToLower(output)

	for _, phrase := range []string{
		"gateway is not running",
		"gateway service is stopped",
		"gateway service is not loaded",
	} {
		if strings.Contains(lowered, phrase) {
			return false
		}
	}

	for _, phrase := range []string{
		"gateway is running",
		"gateway service is running",
		"gateway service is loaded",
	} {
		if strings.Contains(lowered, phrase) {
			return true
		}
	}

	return false
}
|
||||
|
||||
// hermesParseEnvFile parses dotenv-style content into a key/value map.
// It skips blank lines and #-comments, strips a leading UTF-8 BOM and
// "export " prefix, and removes one level of surrounding single or
// double quotes from values (double quotes via strconv.Unquote, so
// escapes are honored; on unquote failure the raw text is kept).
// NOTE(review): bufio.Scanner's default 64KB token limit silently stops
// parsing at an extremely long line — confirm that is acceptable here.
func hermesParseEnvFile(data []byte) map[string]string {
	vars := make(map[string]string)
	scanner := bufio.NewScanner(bytes.NewReader(data))
	for scanner.Scan() {
		line := strings.TrimSpace(strings.TrimPrefix(scanner.Text(), "\ufeff"))
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		line = strings.TrimSpace(strings.TrimPrefix(line, "export "))

		key, value, found := strings.Cut(line, "=")
		if !found {
			continue
		}
		key = strings.TrimSpace(key)
		if key == "" {
			continue
		}

		value = strings.TrimSpace(value)
		if n := len(value); n >= 2 {
			switch {
			case value[0] == '"' && value[n-1] == '"':
				if unquoted, err := strconv.Unquote(value); err == nil {
					value = unquoted
				}
			case value[0] == '\'' && value[n-1] == '\'':
				value = value[1 : n-1]
			}
		}

		vars[key] = value
	}
	return vars
}
|
||||
|
||||
func hermesOllamaClient() *api.Client {
|
||||
// Hermes queries the same launch-resolved Ollama host that launch writes
|
||||
// into config, so model discovery follows the configured endpoint.
|
||||
return api.NewClient(hermesOllamaURL(), http.DefaultClient)
|
||||
}
|
||||
|
||||
func applyHermesManagedProviders(cfg map[string]any, baseURL string, model string, models []string) {
|
||||
providers := hermesUserProviders(cfg["providers"])
|
||||
entry := hermesManagedProviderEntry(providers)
|
||||
if entry == nil {
|
||||
entry = make(map[string]any)
|
||||
}
|
||||
entry["name"] = hermesProviderName
|
||||
entry["api"] = baseURL
|
||||
entry["default_model"] = model
|
||||
entry["models"] = hermesStringListAny(models)
|
||||
providers[hermesProviderKey] = entry
|
||||
delete(providers, hermesLegacyKey)
|
||||
cfg["providers"] = providers
|
||||
|
||||
customProviders := hermesWithoutManagedCustomProviders(cfg["custom_providers"])
|
||||
if len(customProviders) == 0 {
|
||||
delete(cfg, "custom_providers")
|
||||
return
|
||||
}
|
||||
cfg["custom_providers"] = customProviders
|
||||
}
|
||||
|
||||
func hermesManagedCurrentModel(cfg map[string]any, baseURL string) string {
|
||||
modelCfg, _ := cfg["model"].(map[string]any)
|
||||
if modelCfg == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
provider, _ := modelCfg["provider"].(string)
|
||||
if strings.TrimSpace(strings.ToLower(provider)) != hermesProviderKey {
|
||||
return ""
|
||||
}
|
||||
|
||||
configBaseURL, _ := modelCfg["base_url"].(string)
|
||||
if hermesNormalizeURL(configBaseURL) != hermesNormalizeURL(baseURL) {
|
||||
return ""
|
||||
}
|
||||
|
||||
current, _ := modelCfg["default"].(string)
|
||||
current = strings.TrimSpace(current)
|
||||
if current == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
providers := hermesUserProviders(cfg["providers"])
|
||||
entry, _ := providers[hermesProviderKey].(map[string]any)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
if hermesHasManagedCustomProvider(cfg["custom_providers"]) {
|
||||
return ""
|
||||
}
|
||||
|
||||
apiURL, _ := entry["api"].(string)
|
||||
if hermesNormalizeURL(apiURL) != hermesNormalizeURL(baseURL) {
|
||||
return ""
|
||||
}
|
||||
|
||||
defaultModel, _ := entry["default_model"].(string)
|
||||
if strings.TrimSpace(defaultModel) != current {
|
||||
return ""
|
||||
}
|
||||
|
||||
return current
|
||||
}
|
||||
|
||||
// hermesUserProviders normalizes the providers: section into a fresh
// map[string]any copy, tolerating YAML's map[any]any decoding (where
// non-string keys are dropped). Any other shape yields an empty map.
func hermesUserProviders(current any) map[string]any {
	switch typed := current.(type) {
	case map[string]any:
		out := make(map[string]any, len(typed))
		for k, v := range typed {
			out[k] = v
		}
		return out
	case map[any]any:
		out := make(map[string]any, len(typed))
		for k, v := range typed {
			if name, ok := k.(string); ok {
				out[name] = v
			}
		}
		return out
	}
	return make(map[string]any)
}
|
||||
|
||||
// hermesCustomProviders normalizes the custom_providers: section into a
// fresh []any copy ([]map[string]any entries are boxed individually).
// Unrecognized shapes yield nil.
func hermesCustomProviders(current any) []any {
	switch typed := current.(type) {
	case []any:
		return append([]any(nil), typed...)
	case []map[string]any:
		out := make([]any, 0, len(typed))
		for _, entry := range typed {
			out = append(out, entry)
		}
		return out
	}
	return nil
}
|
||||
|
||||
func hermesManagedProviderEntry(providers map[string]any) map[string]any {
|
||||
for _, key := range []string{hermesProviderKey, hermesLegacyKey} {
|
||||
if entry, _ := providers[key].(map[string]any); entry != nil {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hermesWithoutManagedCustomProviders(current any) []any {
|
||||
customProviders := hermesCustomProviders(current)
|
||||
preserved := make([]any, 0, len(customProviders))
|
||||
|
||||
for _, item := range customProviders {
|
||||
entry, _ := item.(map[string]any)
|
||||
if entry == nil {
|
||||
preserved = append(preserved, item)
|
||||
continue
|
||||
}
|
||||
if hermesManagedCustomProvider(entry) {
|
||||
continue
|
||||
}
|
||||
preserved = append(preserved, entry)
|
||||
}
|
||||
|
||||
return preserved
|
||||
}
|
||||
|
||||
func hermesHasManagedCustomProvider(current any) bool {
|
||||
for _, item := range hermesCustomProviders(current) {
|
||||
entry, _ := item.(map[string]any)
|
||||
if entry != nil && hermesManagedCustomProvider(entry) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hermesManagedCustomProvider(entry map[string]any) bool {
|
||||
name, _ := entry["name"].(string)
|
||||
return strings.EqualFold(strings.TrimSpace(name), hermesProviderName)
|
||||
}
|
||||
|
||||
// hermesNormalizeURL trims surrounding whitespace and trailing slashes
// so URLs can be compared structurally.
func hermesNormalizeURL(raw string) string {
	trimmed := strings.TrimSpace(raw)
	return strings.TrimRight(trimmed, "/")
}
|
||||
|
||||
func hermesStringListAny(models []string) []any {
|
||||
out := make([]any, 0, len(models))
|
||||
for _, model := range dedupeModelList(models) {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, model)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// mergeHermesToolsets ensures the "web" toolset is enabled while
// preserving whatever toolsets the user already configured. It accepts
// the shapes YAML decoding may produce — []any, []string, or a
// comma-separated string — and any other shape (including an empty
// string) resets to the defaults {"hermes-cli", "web"}.
func mergeHermesToolsets(current any) any {
	switch existing := current.(type) {
	case []any:
		hasWeb := false
		merged := make([]any, 0, len(existing)+1)
		for _, item := range existing {
			merged = append(merged, item)
			if name, _ := item.(string); name == "web" {
				hasWeb = true
			}
		}
		if !hasWeb {
			merged = append(merged, "web")
		}
		return merged
	case []string:
		names := append([]string(nil), existing...)
		if !slices.Contains(names, "web") {
			names = append(names, "web")
		}
		merged := make([]any, 0, len(names))
		for _, name := range names {
			merged = append(merged, name)
		}
		return merged
	case string:
		if strings.TrimSpace(existing) == "" {
			return []any{"hermes-cli", "web"}
		}
		hasWeb := false
		merged := make([]any, 0, 2)
		for _, part := range strings.Split(existing, ",") {
			part = strings.TrimSpace(part)
			if part == "" {
				continue
			}
			if part == "web" {
				hasWeb = true
			}
			merged = append(merged, part)
		}
		if !hasWeb {
			merged = append(merged, "web")
		}
		return merged
	}
	return []any{"hermes-cli", "web"}
}
|
||||
|
||||
func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
|
||||
cmd := hermesCommand(name, args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd
|
||||
}
|
||||
|
||||
// hermesWindowsHint returns the error shown when Hermes is requested on
// native Windows: it directs the user to WSL2 and the install docs.
func hermesWindowsHint() error {
	lines := []string{
		"Hermes on Windows requires WSL2. Install WSL with: wsl --install",
		"Then run 'ollama launch hermes' from inside your WSL shell.",
		"Docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/",
	}
	return fmt.Errorf("%s", strings.Join(lines, "\n"))
}
|
||||
1110
cmd/launch/hermes_test.go
Normal file
1110
cmd/launch/hermes_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
)
|
||||
|
||||
type stubEditorRunner struct {
|
||||
@@ -73,7 +74,7 @@ func TestIntegrationLookup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIntegrationRegistry(t *testing.T) {
|
||||
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
|
||||
expectedIntegrations := []string{"claude", "codex", "droid", "opencode", "hermes"}
|
||||
|
||||
for _, name := range expectedIntegrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
@@ -328,7 +329,7 @@ func TestBuildModelList_NoExistingModels(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
||||
func TestBuildModelList_OnlyLocalModels_CloudRecsStillFirst(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "qwen2.5:latest", Remote: false},
|
||||
@@ -337,10 +338,11 @@ func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// Recommended pinned at top (local recs first, then cloud recs when only-local), then installed non-recs
|
||||
want := []string{"gemma4", "qwen3.5", "kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "llama3.2", "qwen2.5"}
|
||||
// Cloud recs always come first among recommended, regardless of installed inventory.
|
||||
// Cloud disablement is handled upstream in loadSelectableModels via filterCloudItems.
|
||||
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "gemma4", "qwen3.5", "llama3.2", "qwen2.5"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("recs pinned at top, local recs before cloud recs (-want +got):\n%s", diff)
|
||||
t.Errorf("cloud recs pinned first even when no cloud models installed (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -587,7 +589,7 @@ func TestBuildModelList_MixedCase_CloudRecsFirst(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) {
|
||||
func TestBuildModelList_OnlyLocal_CloudRecsStillFirst(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
}
|
||||
@@ -595,11 +597,11 @@ func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) {
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// Local recs should sort before cloud recs in only-local case
|
||||
// Cloud recs sort before local recs regardless of installed inventory.
|
||||
localIdx := slices.Index(got, "gemma4")
|
||||
cloudIdx := slices.Index(got, "glm-5.1:cloud")
|
||||
if localIdx > cloudIdx {
|
||||
t.Errorf("local recs should be before cloud recs in only-local case, got %v", got)
|
||||
if cloudIdx > localIdx {
|
||||
t.Errorf("cloud recs should be before local recs even when only local models installed, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -722,6 +724,59 @@ func TestLauncherClientFilterDisabledCloudModels_ChecksStatusOncePerInvocation(t
|
||||
}
|
||||
}
|
||||
|
||||
func TestSavedMatchesModels(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
saved *config.IntegrationConfig
|
||||
models []string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil saved",
|
||||
saved: nil,
|
||||
models: []string{"llama3.2"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "identical order",
|
||||
saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}},
|
||||
models: []string{"llama3.2", "qwen3:8b"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "different order",
|
||||
saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}},
|
||||
models: []string{"qwen3:8b", "llama3.2"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "subset",
|
||||
saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}},
|
||||
models: []string{"llama3.2"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil models in saved with non-nil models",
|
||||
saved: &config.IntegrationConfig{Models: nil},
|
||||
models: []string{"llama3.2"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty both",
|
||||
saved: &config.IntegrationConfig{Models: nil},
|
||||
models: nil,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := savedMatchesModels(tt.saved, tt.models); got != tt.want {
|
||||
t.Fatalf("savedMatchesModels = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareEditorIntegration_SavesOnlyAfterSuccessfulEdit(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
@@ -1455,27 +1510,13 @@ func TestListIntegrationInfos(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sorted with custom order at end", func(t *testing.T) {
|
||||
// integrationOrder entries (cline, opencode) should appear last, in that order.
|
||||
// All other entries should be sorted alphabetically before them.
|
||||
orderRank := make(map[string]int)
|
||||
for i, name := range integrationOrder {
|
||||
orderRank[name] = i + 1
|
||||
t.Run("follows launcher order", func(t *testing.T) {
|
||||
got := make([]string, 0, len(infos))
|
||||
for _, info := range infos {
|
||||
got = append(got, info.Name)
|
||||
}
|
||||
for i := 1; i < len(infos); i++ {
|
||||
aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
|
||||
switch {
|
||||
case aRank == 0 && bRank == 0:
|
||||
if infos[i-1].Name >= infos[i].Name {
|
||||
t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
|
||||
}
|
||||
case aRank > 0 && bRank == 0:
|
||||
t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
|
||||
case aRank > 0 && bRank > 0:
|
||||
if aRank >= bRank {
|
||||
t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
|
||||
}
|
||||
}
|
||||
if diff := compareStrings(got, integrationOrder); diff != "" {
|
||||
t.Fatalf("launcher integration order mismatch: %s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1503,6 +1544,28 @@ func TestListIntegrationInfos(t *testing.T) {
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("includes hermes", func(t *testing.T) {
|
||||
for _, info := range infos {
|
||||
if info.Name == "hermes" {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatal("expected hermes to be included in ListIntegrationInfos")
|
||||
})
|
||||
|
||||
t.Run("hermes still resolves explicitly", func(t *testing.T) {
|
||||
name, runner, err := LookupIntegration("hermes")
|
||||
if err != nil {
|
||||
t.Fatalf("expected explicit hermes integration lookup to work, got %v", err)
|
||||
}
|
||||
if name != "hermes" {
|
||||
t.Fatalf("expected canonical name hermes, got %q", name)
|
||||
}
|
||||
if runner.String() == "" {
|
||||
t.Fatal("expected hermes integration runner to be present")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildModelList_Descriptions(t *testing.T) {
|
||||
@@ -1591,6 +1654,7 @@ func TestIntegration_AutoInstallable(t *testing.T) {
|
||||
}{
|
||||
{"openclaw", true},
|
||||
{"pi", true},
|
||||
{"hermes", true},
|
||||
{"claude", false},
|
||||
{"codex", false},
|
||||
{"opencode", false},
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
@@ -140,6 +141,36 @@ type Editor interface {
|
||||
Models() []string
|
||||
}
|
||||
|
||||
// ManagedSingleModel is the narrow launch-owned config path for integrations
|
||||
// like Hermes that have one primary model selected by launcher, need launcher
|
||||
// to persist minimal config, and still keep their own model discovery and
|
||||
// onboarding UX. This stays separate from Runner-only integrations and the
|
||||
// multi-model Editor flow so Hermes-specific behavior stays scoped to one path.
|
||||
type ManagedSingleModel interface {
|
||||
Paths() []string
|
||||
Configure(model string) error
|
||||
CurrentModel() string
|
||||
Onboard() error
|
||||
}
|
||||
|
||||
// ManagedRuntimeRefresher lets managed integrations refresh any long-lived
|
||||
// background runtime after launch rewrites their config.
|
||||
type ManagedRuntimeRefresher interface {
|
||||
RefreshRuntimeAfterConfigure() error
|
||||
}
|
||||
|
||||
// ManagedOnboardingValidator lets managed integrations re-check saved
|
||||
// onboarding state when launcher needs a stronger live readiness signal.
|
||||
type ManagedOnboardingValidator interface {
|
||||
OnboardingComplete() bool
|
||||
}
|
||||
|
||||
// ManagedInteractiveOnboarding lets a managed integration declare whether its
|
||||
// onboarding step really requires an interactive terminal. Hermes does not.
|
||||
type ManagedInteractiveOnboarding interface {
|
||||
RequiresInteractiveOnboarding() bool
|
||||
}
|
||||
|
||||
type modelInfo struct {
|
||||
Name string
|
||||
Remote bool
|
||||
@@ -175,7 +206,9 @@ Supported integrations:
|
||||
claude Claude Code
|
||||
cline Cline
|
||||
codex Codex
|
||||
copilot Copilot CLI (aliases: copilot-cli)
|
||||
droid Droid
|
||||
hermes Hermes Agent
|
||||
opencode OpenCode
|
||||
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||
pi Pi
|
||||
@@ -185,6 +218,7 @@ Examples:
|
||||
ollama launch
|
||||
ollama launch claude
|
||||
ollama launch claude --model <model>
|
||||
ollama launch hermes
|
||||
ollama launch droid --config (does not auto-launch)
|
||||
ollama launch codex -- -p myprofile (pass extra args to integration)
|
||||
ollama launch codex -- --sandbox workspace-write`,
|
||||
@@ -307,36 +341,54 @@ func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
policy := launchIntegrationPolicy(req)
|
||||
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
|
||||
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
|
||||
}
|
||||
|
||||
launchClient, saved, err := prepareIntegrationLaunch(name, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if managed, ok := runner.(ManagedSingleModel); ok {
|
||||
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||
return err
|
||||
}
|
||||
return launchClient.launchManagedSingleIntegration(ctx, name, runner, managed, saved, req)
|
||||
}
|
||||
|
||||
if !req.ConfigureOnly {
|
||||
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var policy LaunchPolicy
|
||||
// TUI does not set a policy, whereas ollama launch <app> does as it can have flags which change the behavior
|
||||
if req.Policy == nil {
|
||||
policy = defaultLaunchPolicy(isInteractiveSession(), false)
|
||||
} else {
|
||||
policy = *req.Policy
|
||||
}
|
||||
|
||||
launchClient, err := newLauncherClient(policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
saved, _ := loadStoredIntegrationConfig(name)
|
||||
// In headless --yes mode we cannot prompt, so require an explicit --model.
|
||||
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
|
||||
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
|
||||
}
|
||||
|
||||
if editor, ok := runner.(Editor); ok {
|
||||
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
|
||||
}
|
||||
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
|
||||
}
|
||||
|
||||
func launchIntegrationPolicy(req IntegrationLaunchRequest) LaunchPolicy {
|
||||
// TUI does not set a policy, whereas ollama launch <app> does as it can
|
||||
// have flags which change the behavior.
|
||||
if req.Policy != nil {
|
||||
return *req.Policy
|
||||
}
|
||||
return defaultLaunchPolicy(isInteractiveSession(), false)
|
||||
}
|
||||
|
||||
func prepareIntegrationLaunch(name string, policy LaunchPolicy) (*launcherClient, *config.IntegrationConfig, error) {
|
||||
launchClient, err := newLauncherClient(policy)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
saved, _ := loadStoredIntegrationConfig(name)
|
||||
return launchClient, saved, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||
_ = c.loadModelInventoryOnce(ctx)
|
||||
|
||||
@@ -367,9 +419,18 @@ func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info
|
||||
if err != nil {
|
||||
return LauncherIntegrationState{}, err
|
||||
}
|
||||
currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor)
|
||||
if err != nil {
|
||||
return LauncherIntegrationState{}, err
|
||||
var currentModel string
|
||||
var usable bool
|
||||
if managed, ok := integration.spec.Runner.(ManagedSingleModel); ok {
|
||||
currentModel, usable, err = c.launcherManagedModelState(ctx, info.Name, managed)
|
||||
if err != nil {
|
||||
return LauncherIntegrationState{}, err
|
||||
}
|
||||
} else {
|
||||
currentModel, usable, err = c.launcherModelState(ctx, info.Name, integration.editor)
|
||||
if err != nil {
|
||||
return LauncherIntegrationState{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return LauncherIntegrationState{
|
||||
@@ -407,6 +468,28 @@ func (c *launcherClient) launcherModelState(ctx context.Context, name string, is
|
||||
return model, usableErr == nil && usable, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) launcherManagedModelState(ctx context.Context, name string, managed ManagedSingleModel) (string, bool, error) {
|
||||
current := managed.CurrentModel()
|
||||
if current == "" {
|
||||
cfg, loadErr := loadStoredIntegrationConfig(name)
|
||||
if loadErr == nil {
|
||||
current = primaryModelFromConfig(cfg)
|
||||
}
|
||||
if current != "" {
|
||||
return current, false, nil
|
||||
}
|
||||
}
|
||||
if current == "" {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
usable, err := c.savedModelUsable(ctx, current)
|
||||
if err != nil {
|
||||
return current, false, err
|
||||
}
|
||||
return current, usable, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||
current := config.LastModel()
|
||||
if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
|
||||
@@ -443,35 +526,15 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
|
||||
}
|
||||
|
||||
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||
current := primaryModelFromConfig(saved)
|
||||
target := req.ModelOverride
|
||||
needsConfigure := req.ForceConfigure
|
||||
|
||||
if target == "" {
|
||||
target = current
|
||||
usable, err := c.savedModelUsable(ctx, target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !usable {
|
||||
needsConfigure = true
|
||||
}
|
||||
}
|
||||
|
||||
if needsConfigure {
|
||||
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
target = selected
|
||||
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
|
||||
target, _, err := c.resolveSingleIntegrationTarget(ctx, runner, primaryModelFromConfig(saved), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if target == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
current := primaryModelFromConfig(saved)
|
||||
if target != current {
|
||||
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
@@ -500,7 +563,7 @@ func (c *launcherClient) launchEditorIntegration(ctx context.Context, name strin
|
||||
return nil
|
||||
}
|
||||
|
||||
if needsConfigure || req.ModelOverride != "" {
|
||||
if (needsConfigure || req.ModelOverride != "") && !savedMatchesModels(saved, models) {
|
||||
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -509,6 +572,102 @@ func (c *launcherClient) launchEditorIntegration(ctx context.Context, name strin
|
||||
return launchAfterConfiguration(name, runner, models[0], req)
|
||||
}
|
||||
|
||||
func (c *launcherClient) launchManagedSingleIntegration(ctx context.Context, name string, runner Runner, managed ManagedSingleModel, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||
current := managed.CurrentModel()
|
||||
selectionCurrent := current
|
||||
if selectionCurrent == "" {
|
||||
selectionCurrent = primaryModelFromConfig(saved)
|
||||
}
|
||||
|
||||
target, needsConfigure, err := c.resolveSingleIntegrationTarget(ctx, runner, selectionCurrent, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if target == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if (current == "" || needsConfigure || req.ModelOverride != "" || target != current) && !savedMatchesModels(saved, []string{target}) {
|
||||
if err := prepareManagedSingleIntegration(name, runner, managed, target); err != nil {
|
||||
return err
|
||||
}
|
||||
if refresher, ok := managed.(ManagedRuntimeRefresher); ok {
|
||||
if err := refresher.RefreshRuntimeAfterConfigure(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !managedIntegrationOnboarded(saved, managed) {
|
||||
if !isInteractiveSession() && managedRequiresInteractiveOnboarding(managed) {
|
||||
return fmt.Errorf("%s still needs interactive gateway setup; run 'ollama launch %s' in a terminal to finish onboarding", runner, name)
|
||||
}
|
||||
if err := managed.Onboard(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if req.ConfigureOnly {
|
||||
return nil
|
||||
}
|
||||
|
||||
return runIntegration(runner, target, req.ExtraArgs)
|
||||
}
|
||||
|
||||
func (c *launcherClient) resolveSingleIntegrationTarget(ctx context.Context, runner Runner, current string, req IntegrationLaunchRequest) (string, bool, error) {
|
||||
target := req.ModelOverride
|
||||
needsConfigure := req.ForceConfigure
|
||||
|
||||
if target == "" {
|
||||
target = current
|
||||
usable, err := c.savedModelUsable(ctx, target)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
if !usable {
|
||||
needsConfigure = true
|
||||
}
|
||||
}
|
||||
|
||||
if needsConfigure {
|
||||
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
target = selected
|
||||
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
return target, needsConfigure, nil
|
||||
}
|
||||
|
||||
func savedIntegrationOnboarded(saved *config.IntegrationConfig) bool {
|
||||
return saved != nil && saved.Onboarded
|
||||
}
|
||||
|
||||
func managedIntegrationOnboarded(saved *config.IntegrationConfig, managed ManagedSingleModel) bool {
|
||||
if !savedIntegrationOnboarded(saved) {
|
||||
return false
|
||||
}
|
||||
validator, ok := managed.(ManagedOnboardingValidator)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
return validator.OnboardingComplete()
|
||||
}
|
||||
|
||||
// Most managed integrations treat onboarding as an interactive terminal step.
|
||||
// Hermes opts out because its launch-owned onboarding is just bookkeeping, so
|
||||
// headless launches should not be blocked once config is already prepared.
|
||||
func managedRequiresInteractiveOnboarding(managed ManagedSingleModel) bool {
|
||||
onboarding, ok := managed.(ManagedInteractiveOnboarding)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
return onboarding.RequiresInteractiveOnboarding()
|
||||
}
|
||||
|
||||
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
|
||||
if selector == nil {
|
||||
return "", fmt.Errorf("no selector configured")
|
||||
@@ -846,6 +1005,13 @@ func firstModel(models []string) string {
|
||||
return models[0]
|
||||
}
|
||||
|
||||
func savedMatchesModels(saved *config.IntegrationConfig, models []string) bool {
|
||||
if saved == nil {
|
||||
return false
|
||||
}
|
||||
return slices.Equal(saved.Models, models)
|
||||
}
|
||||
|
||||
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
|
||||
if override == "" {
|
||||
if saved == nil {
|
||||
|
||||
@@ -49,6 +49,55 @@ func (r *launcherSingleRunner) Run(model string, args []string) error {
|
||||
|
||||
func (r *launcherSingleRunner) String() string { return "StubSingle" }
|
||||
|
||||
type launcherManagedRunner struct {
|
||||
paths []string
|
||||
currentModel string
|
||||
configured []string
|
||||
ranModel string
|
||||
onboarded bool
|
||||
onboardCalls int
|
||||
onboardingComplete bool
|
||||
refreshCalls int
|
||||
refreshErr error
|
||||
}
|
||||
|
||||
func (r *launcherManagedRunner) Run(model string, args []string) error {
|
||||
r.ranModel = model
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *launcherManagedRunner) String() string { return "StubManaged" }
|
||||
|
||||
func (r *launcherManagedRunner) Paths() []string { return r.paths }
|
||||
|
||||
func (r *launcherManagedRunner) Configure(model string) error {
|
||||
r.configured = append(r.configured, model)
|
||||
r.currentModel = model
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *launcherManagedRunner) CurrentModel() string { return r.currentModel }
|
||||
|
||||
func (r *launcherManagedRunner) Onboard() error {
|
||||
r.onboardCalls++
|
||||
r.onboarded = true
|
||||
r.onboardingComplete = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *launcherManagedRunner) OnboardingComplete() bool { return r.onboardingComplete }
|
||||
|
||||
func (r *launcherManagedRunner) RefreshRuntimeAfterConfigure() error {
|
||||
r.refreshCalls++
|
||||
return r.refreshErr
|
||||
}
|
||||
|
||||
type launcherHeadlessManagedRunner struct {
|
||||
launcherManagedRunner
|
||||
}
|
||||
|
||||
func (r *launcherHeadlessManagedRunner) RequiresInteractiveOnboarding() bool { return false }
|
||||
|
||||
func setLaunchTestHome(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", dir)
|
||||
@@ -141,6 +190,451 @@ func TestDefaultLaunchPolicy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildLauncherState_ManagedSingleIntegrationUsesCurrentModel(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
runner := &launcherManagedRunner{currentModel: "gemma4"}
|
||||
withIntegrationOverride(t, "pi", runner)
|
||||
|
||||
state, err := BuildLauncherState(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("BuildLauncherState returned error: %v", err)
|
||||
}
|
||||
|
||||
if state.Integrations["pi"].CurrentModel != "gemma4" {
|
||||
t.Fatalf("expected managed current model from integration config, got %q", state.Integrations["pi"].CurrentModel)
|
||||
}
|
||||
if !state.Integrations["pi"].ModelUsable {
|
||||
t.Fatal("expected managed current model to be usable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildLauncherState_ManagedSingleIntegrationShowsSavedModelWhenLiveConfigMissing(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
if err := config.SaveIntegration("pi", []string{"gemma4"}); err != nil {
|
||||
t.Fatalf("failed to save managed integration config: %v", err)
|
||||
}
|
||||
|
||||
runner := &launcherManagedRunner{}
|
||||
withIntegrationOverride(t, "pi", runner)
|
||||
|
||||
state, err := BuildLauncherState(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("BuildLauncherState returned error: %v", err)
|
||||
}
|
||||
|
||||
if state.Integrations["pi"].CurrentModel != "gemma4" {
|
||||
t.Fatalf("expected saved model to remain visible, got %q", state.Integrations["pi"].CurrentModel)
|
||||
}
|
||||
if state.Integrations["pi"].ModelUsable {
|
||||
t.Fatal("expected missing live config to mark managed model unusable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_ManagedSingleIntegrationConfiguresOnboardsAndRuns(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withInteractiveSession(t, true)
|
||||
withLauncherHooks(t)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
runner := &launcherManagedRunner{
|
||||
paths: nil,
|
||||
}
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
return "gemma4", nil
|
||||
}
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
|
||||
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||
}
|
||||
|
||||
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
|
||||
t.Fatalf("configured models mismatch: %s", diff)
|
||||
}
|
||||
if runner.refreshCalls != 1 {
|
||||
t.Fatalf("expected runtime refresh once after configure, got %d", runner.refreshCalls)
|
||||
}
|
||||
if runner.onboardCalls != 1 {
|
||||
t.Fatalf("expected onboarding to run once, got %d", runner.onboardCalls)
|
||||
}
|
||||
if runner.ranModel != "gemma4" {
|
||||
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
|
||||
}
|
||||
|
||||
saved, err := config.LoadIntegration("stubmanaged")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload managed integration config: %v", err)
|
||||
}
|
||||
if diff := compareStrings(saved.Models, []string{"gemma4"}); diff != "" {
|
||||
t.Fatalf("saved models mismatch: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_ManagedSingleIntegrationReOnboardsWhenSavedFlagIsStale(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withInteractiveSession(t, true)
|
||||
withLauncherHooks(t)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
runner := &launcherManagedRunner{
|
||||
currentModel: "gemma4",
|
||||
onboardingComplete: false,
|
||||
}
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
|
||||
t.Fatalf("failed to save managed integration config: %v", err)
|
||||
}
|
||||
if err := config.MarkIntegrationOnboarded("stubmanaged"); err != nil {
|
||||
t.Fatalf("failed to mark managed integration onboarded: %v", err)
|
||||
}
|
||||
|
||||
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
|
||||
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||
}
|
||||
|
||||
if runner.onboardCalls != 1 {
|
||||
t.Fatalf("expected stale onboarded flag to trigger onboarding, got %d calls", runner.onboardCalls)
|
||||
}
|
||||
if runner.refreshCalls != 0 {
|
||||
t.Fatalf("expected no runtime refresh when config is unchanged, got %d", runner.refreshCalls)
|
||||
}
|
||||
if runner.ranModel != "gemma4" {
|
||||
t.Fatalf("expected launch to run saved model after onboarding repair, got %q", runner.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_ManagedSingleIntegrationConfigOnlySkipsFinalRun(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withInteractiveSession(t, true)
|
||||
withLauncherHooks(t)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
runner := &launcherManagedRunner{
|
||||
paths: nil,
|
||||
}
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||
Name: "stubmanaged",
|
||||
ModelOverride: "gemma4",
|
||||
ConfigureOnly: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||
}
|
||||
|
||||
if runner.ranModel != "" {
|
||||
t.Fatalf("expected configure-only flow to skip final launch, got %q", runner.ranModel)
|
||||
}
|
||||
if runner.refreshCalls != 1 {
|
||||
t.Fatalf("expected configure-only flow to refresh runtime once, got %d", runner.refreshCalls)
|
||||
}
|
||||
if runner.onboardCalls != 1 {
|
||||
t.Fatalf("expected configure-only flow to onboard once, got %d", runner.onboardCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_ManagedSingleIntegrationSkipsRewriteWhenSavedMatches(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withInteractiveSession(t, true)
|
||||
withLauncherHooks(t)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
|
||||
t.Fatalf("failed to save managed integration config: %v", err)
|
||||
}
|
||||
|
||||
runner := &launcherManagedRunner{}
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called when saved model matches target")
|
||||
return "", nil
|
||||
}
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
t.Fatal("confirm prompt should not run when saved model matches target")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
|
||||
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||
}
|
||||
|
||||
if len(runner.configured) != 0 {
|
||||
t.Fatalf("expected Configure to be skipped when saved matches, got %v", runner.configured)
|
||||
}
|
||||
if runner.refreshCalls != 0 {
|
||||
t.Fatalf("expected no runtime refresh when config is unchanged, got %d", runner.refreshCalls)
|
||||
}
|
||||
if runner.ranModel != "gemma4" {
|
||||
t.Fatalf("expected launch to run saved model, got %q", runner.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenSavedDiffers(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withInteractiveSession(t, true)
|
||||
withLauncherHooks(t)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
if err := config.SaveIntegration("stubmanaged", []string{"old-model"}); err != nil {
|
||||
t.Fatalf("failed to save managed integration config: %v", err)
|
||||
}
|
||||
|
||||
runner := &launcherManagedRunner{}
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called when model override is provided")
|
||||
return "", nil
|
||||
}
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||
Name: "stubmanaged",
|
||||
ModelOverride: "gemma4",
|
||||
}); err != nil {
|
||||
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||
}
|
||||
|
||||
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
|
||||
t.Fatalf("expected Configure to run when saved differs from target: %s", diff)
|
||||
}
|
||||
if runner.refreshCalls != 1 {
|
||||
t.Fatalf("expected runtime refresh once after configure, got %d", runner.refreshCalls)
|
||||
}
|
||||
if runner.ranModel != "gemma4" {
|
||||
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_ManagedSingleIntegrationStopsWhenRuntimeRefreshFails(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withInteractiveSession(t, true)
|
||||
withLauncherHooks(t)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
runner := &launcherManagedRunner{
|
||||
refreshErr: fmt.Errorf("boom"),
|
||||
}
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||
Name: "stubmanaged",
|
||||
ModelOverride: "gemma4",
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "boom") {
|
||||
t.Fatalf("expected runtime refresh error, got %v", err)
|
||||
}
|
||||
if runner.ranModel != "" {
|
||||
t.Fatalf("expected final launch to stop on runtime refresh failure, got %q", runner.ranModel)
|
||||
}
|
||||
if runner.refreshCalls != 1 {
|
||||
t.Fatalf("expected one runtime refresh attempt, got %d", runner.refreshCalls)
|
||||
}
|
||||
if runner.onboardCalls != 0 {
|
||||
t.Fatalf("expected onboarding to stop after refresh failure, got %d", runner.onboardCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_ManagedSingleIntegrationHeadlessNeedsInteractiveOnboarding(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withInteractiveSession(t, false)
|
||||
withLauncherHooks(t)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
runner := &launcherManagedRunner{
|
||||
paths: nil,
|
||||
}
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||
Name: "stubmanaged",
|
||||
ModelOverride: "gemma4",
|
||||
Policy: &LaunchPolicy{Confirm: LaunchConfirmAutoApprove, MissingModel: LaunchMissingModelAutoPull},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected headless onboarding requirement to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "interactive gateway setup") {
|
||||
t.Fatalf("expected interactive onboarding guidance, got %v", err)
|
||||
}
|
||||
if runner.ranModel != "" {
|
||||
t.Fatalf("expected no final launch when onboarding is still required, got %q", runner.ranModel)
|
||||
}
|
||||
if runner.onboardCalls != 0 {
|
||||
t.Fatalf("expected no onboarding attempts in headless mode, got %d", runner.onboardCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_ManagedSingleIntegrationHeadlessAllowsNonInteractiveOnboarding(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withInteractiveSession(t, false)
|
||||
withLauncherHooks(t)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
runner := &launcherHeadlessManagedRunner{}
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||
Name: "stubmanaged",
|
||||
ModelOverride: "gemma4",
|
||||
Policy: &LaunchPolicy{Confirm: LaunchConfirmAutoApprove, MissingModel: LaunchMissingModelAutoPull},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected non-interactive onboarding to succeed headlessly, got %v", err)
|
||||
}
|
||||
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
|
||||
t.Fatalf("configured models mismatch: %s", diff)
|
||||
}
|
||||
if runner.onboardCalls != 1 {
|
||||
t.Fatalf("expected onboarding to run once, got %d", runner.onboardCalls)
|
||||
}
|
||||
if runner.ranModel != "gemma4" {
|
||||
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildLauncherState_InstalledAndCloudDisabled(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
@@ -230,7 +230,7 @@ func pullMissingModel(ctx context.Context, client *api.Client, model string) err
|
||||
|
||||
// prepareEditorIntegration persists models and applies editor-managed config files.
|
||||
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
|
||||
if ok, err := confirmEditorEdit(runner, editor); err != nil {
|
||||
if ok, err := confirmConfigEdit(runner, editor.Paths()); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
@@ -244,8 +244,22 @@ func prepareEditorIntegration(name string, runner Runner, editor Editor, models
|
||||
return nil
|
||||
}
|
||||
|
||||
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) {
|
||||
paths := editor.Paths()
|
||||
func prepareManagedSingleIntegration(name string, runner Runner, managed ManagedSingleModel, model string) error {
|
||||
if ok, err := confirmConfigEdit(runner, managed.Paths()); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
if err := managed.Configure(model); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
if err := config.SaveIntegration(name, []string{model}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func confirmConfigEdit(runner Runner, paths []string) (bool, error) {
|
||||
if len(paths) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
@@ -345,8 +359,6 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
recRank[rec.Name] = i + 1
|
||||
}
|
||||
|
||||
onlyLocal := hasLocalModel && !hasCloudModel
|
||||
|
||||
if hasLocalModel || hasCloudModel {
|
||||
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
@@ -368,12 +380,6 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
}
|
||||
if aRec && bRec {
|
||||
if aCloud != bCloud {
|
||||
if onlyLocal {
|
||||
if aCloud {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
if aCloud {
|
||||
return -1
|
||||
}
|
||||
|
||||
@@ -186,6 +186,11 @@ func (c *Openclaw) runChannelSetupPreflight(bin string) error {
|
||||
if !isInteractiveSession() {
|
||||
return nil
|
||||
}
|
||||
// --yes is headless; channel setup spawns an interactive picker we can't
|
||||
// auto-answer, so skip it. Users can run `openclaw channels add` later.
|
||||
if currentLaunchConfirmPolicy.yes {
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
if c.channelsConfigured() {
|
||||
|
||||
@@ -1304,6 +1304,46 @@ func TestOpenclawChannelSetupPreflight(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--yes skips preflight without channels configured", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Empty config = no channels configured. Without the --yes skip, the
|
||||
// preflight would prompt and (on confirm) spawn `openclaw channels add`.
|
||||
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
bin := filepath.Join(tmpDir, "openclaw")
|
||||
if err := os.WriteFile(bin, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \"$HOME/invocations.log\"\n"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
oldInteractive := isInteractiveSession
|
||||
isInteractiveSession = func() bool { return true }
|
||||
defer func() { isInteractiveSession = oldInteractive }()
|
||||
|
||||
restore := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
|
||||
defer restore()
|
||||
|
||||
oldConfirmPrompt := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
t.Fatalf("did not expect prompt in --yes mode: %s", prompt)
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldConfirmPrompt }()
|
||||
|
||||
if err := c.runChannelSetupPreflight("openclaw"); err != nil {
|
||||
t.Fatalf("runChannelSetupPreflight() error = %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(tmpDir, "invocations.log")); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected no channels add invocation in --yes mode, got err=%v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("set up later prompts once and exits", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
@@ -3,20 +3,22 @@ package launch
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
// OpenCode implements Runner and Editor for OpenCode integration.
|
||||
// Config is passed via OPENCODE_CONFIG_CONTENT env var at launch time
|
||||
// instead of writing to opencode's config files.
|
||||
type OpenCode struct {
|
||||
configContent string // JSON config built by Edit, passed to Run via env var
|
||||
}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
@@ -51,25 +53,51 @@ func (o *OpenCode) Run(model string, args []string) error {
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = os.Environ()
|
||||
if content := o.resolveContent(model); content != "" {
|
||||
cmd.Env = append(cmd.Env, "OPENCODE_CONFIG_CONTENT="+content)
|
||||
}
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// resolveContent returns the inline config to send via OPENCODE_CONFIG_CONTENT.
|
||||
// Returns content built by Edit if available, otherwise builds from model.json
|
||||
// with the requested model as primary (e.g. re-launch with saved config).
|
||||
func (o *OpenCode) resolveContent(model string) string {
|
||||
if o.configContent != "" {
|
||||
return o.configContent
|
||||
}
|
||||
models := readModelJSONModels()
|
||||
if !slices.Contains(models, model) {
|
||||
models = append([]string{model}, models...)
|
||||
}
|
||||
content, err := buildInlineConfig(model, models)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func (o *OpenCode) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
sp, err := openCodeStatePath()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var paths []string
|
||||
p := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
paths = append(paths, p)
|
||||
}
|
||||
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if _, err := os.Stat(sp); err == nil {
|
||||
paths = append(paths, sp)
|
||||
return []string{sp}
|
||||
}
|
||||
return paths
|
||||
return nil
|
||||
}
|
||||
|
||||
// openCodeStatePath returns the path to opencode's model state file.
|
||||
// TODO: this hardcodes the Linux/macOS XDG path. On Windows, opencode stores
|
||||
// state under %LOCALAPPDATA% (or similar) — verify and branch on runtime.GOOS.
|
||||
func openCodeStatePath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".local", "state", "opencode", "model.json"), nil
|
||||
}
|
||||
|
||||
func (o *OpenCode) Edit(modelList []string) error {
|
||||
@@ -77,110 +105,17 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
content, err := buildInlineConfig(modelList[0], modelList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.configContent = content
|
||||
|
||||
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
|
||||
}
|
||||
|
||||
config["$schema"] = "https://opencode.ai/config.json"
|
||||
|
||||
provider, ok := config["provider"].(map[string]any)
|
||||
if !ok {
|
||||
provider = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": map[string]any{
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate legacy provider name
|
||||
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
|
||||
ollama["name"] = "Ollama"
|
||||
}
|
||||
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
models = make(map[string]any)
|
||||
}
|
||||
|
||||
selectedSet := make(map[string]bool)
|
||||
for _, m := range modelList {
|
||||
selectedSet[m] = true
|
||||
}
|
||||
|
||||
for name, cfg := range models {
|
||||
if cfgMap, ok := cfg.(map[string]any); ok {
|
||||
if isOllamaModel(cfgMap) && !selectedSet[name] {
|
||||
delete(models, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
if isOllamaModel(existing) {
|
||||
existing["_launch"] = true
|
||||
if name, ok := existing["name"].(string); ok {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
existing["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
entry := map[string]any{
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
models[model] = entry
|
||||
}
|
||||
|
||||
ollama["models"] = models
|
||||
provider["ollama"] = ollama
|
||||
config["provider"] = provider
|
||||
config["model"] = "ollama/" + modelList[0]
|
||||
|
||||
configData, err := json.MarshalIndent(config, "", " ")
|
||||
// Write model state file so models appear in OpenCode's model picker
|
||||
statePath, err := openCodeStatePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -232,33 +167,82 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
}
|
||||
|
||||
func (o *OpenCode) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
config, err := fileutil.ReadJSON(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
provider, _ := config["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := slices.Collect(maps.Keys(models))
|
||||
slices.Sort(keys)
|
||||
return keys
|
||||
return nil
|
||||
}
|
||||
|
||||
// isOllamaModel reports whether a model config entry is managed by us
|
||||
func isOllamaModel(cfg map[string]any) bool {
|
||||
if v, ok := cfg["_launch"].(bool); ok && v {
|
||||
return true
|
||||
// buildInlineConfig produces the JSON string for OPENCODE_CONFIG_CONTENT.
|
||||
// primary is the model to launch with, models is the full list of available models.
|
||||
func buildInlineConfig(primary string, models []string) (string, error) {
|
||||
if primary == "" || len(models) == 0 {
|
||||
return "", fmt.Errorf("buildInlineConfig: primary and models are required")
|
||||
}
|
||||
// previously used [Ollama] as a suffix for the model managed by ollama launch
|
||||
if name, ok := cfg["name"].(string); ok {
|
||||
return strings.HasSuffix(name, "[Ollama]")
|
||||
config := map[string]any{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": map[string]any{
|
||||
"ollama": map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": map[string]any{
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
},
|
||||
"models": buildModelEntries(models),
|
||||
},
|
||||
},
|
||||
"model": "ollama/" + primary,
|
||||
}
|
||||
return false
|
||||
data, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// readModelJSONModels reads ollama model IDs from the opencode model.json state file
|
||||
func readModelJSONModels() []string {
|
||||
statePath, err := openCodeStatePath()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
data, err := os.ReadFile(statePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
return nil
|
||||
}
|
||||
recent, _ := state["recent"].([]any)
|
||||
var models []string
|
||||
for _, entry := range recent {
|
||||
e, ok := entry.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if e["providerID"] != "ollama" {
|
||||
continue
|
||||
}
|
||||
if id, ok := e["modelID"].(string); ok && id != "" {
|
||||
models = append(models, id)
|
||||
}
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
func buildModelEntries(modelList []string) map[string]any {
|
||||
models := make(map[string]any)
|
||||
for _, model := range modelList {
|
||||
entry := map[string]any{
|
||||
"name": model,
|
||||
}
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
models[model] = entry
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -33,7 +33,7 @@ type IntegrationInfo struct {
|
||||
Description string
|
||||
}
|
||||
|
||||
var launcherIntegrationOrder = []string{"opencode", "droid", "pi"}
|
||||
var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi"}
|
||||
|
||||
var integrationSpecs = []*IntegrationSpec{
|
||||
{
|
||||
@@ -74,6 +74,19 @@ var integrationSpecs = []*IntegrationSpec{
|
||||
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "copilot",
|
||||
Runner: &Copilot{},
|
||||
Aliases: []string{"copilot-cli"},
|
||||
Description: "GitHub's AI coding agent for the terminal",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := (&Copilot{}).findPath()
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://github.com/features/copilot/cli/",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "droid",
|
||||
Runner: &Droid{},
|
||||
@@ -136,6 +149,20 @@ var integrationSpecs = []*IntegrationSpec{
|
||||
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "hermes",
|
||||
Runner: &Hermes{},
|
||||
Description: "Self-improving AI agent built by Nous Research",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
return (&Hermes{}).installed()
|
||||
},
|
||||
EnsureInstalled: func() error {
|
||||
return (&Hermes{}).ensureInstalled()
|
||||
},
|
||||
URL: "https://hermes-agent.nousresearch.com/docs/getting-started/installation/",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "vscode",
|
||||
Runner: &VSCode{},
|
||||
@@ -255,10 +282,10 @@ func ListVisibleIntegrationSpecs() []IntegrationSpec {
|
||||
return aRank - bRank
|
||||
}
|
||||
if aRank > 0 {
|
||||
return 1
|
||||
return -1
|
||||
}
|
||||
if bRank > 0 {
|
||||
return -1
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
@@ -26,7 +26,7 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
||||
binary: "opencode",
|
||||
runner: &OpenCode{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
return filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -45,21 +45,12 @@ type menuItem struct {
|
||||
isOthers bool
|
||||
}
|
||||
|
||||
var mainMenuItems = []menuItem{
|
||||
{
|
||||
title: "Chat with a model",
|
||||
description: "Start an interactive chat with a model",
|
||||
isRunModel: true,
|
||||
},
|
||||
{
|
||||
integration: "openclaw",
|
||||
},
|
||||
{
|
||||
integration: "claude",
|
||||
},
|
||||
{
|
||||
integration: "opencode",
|
||||
},
|
||||
const pinnedIntegrationCount = 3
|
||||
|
||||
var runModelMenuItem = menuItem{
|
||||
title: "Chat with a model",
|
||||
description: "Start an interactive chat with a model",
|
||||
isRunModel: true,
|
||||
}
|
||||
|
||||
var othersMenuItem = menuItem{
|
||||
@@ -102,20 +93,14 @@ func shouldExpandOthers(state *launch.LauncherState) bool {
|
||||
}
|
||||
|
||||
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
|
||||
items := make([]menuItem, 0, len(mainMenuItems)+1)
|
||||
for _, item := range mainMenuItems {
|
||||
if item.integration == "" {
|
||||
items = append(items, item)
|
||||
continue
|
||||
}
|
||||
if integrationState, ok := state.Integrations[item.integration]; ok {
|
||||
items = append(items, integrationMenuItem(integrationState))
|
||||
}
|
||||
}
|
||||
items := []menuItem{runModelMenuItem}
|
||||
items = append(items, pinnedIntegrationItems(state)...)
|
||||
|
||||
if showOthers {
|
||||
items = append(items, otherIntegrationItems(state)...)
|
||||
} else {
|
||||
otherItems := otherIntegrationItems(state)
|
||||
switch {
|
||||
case showOthers:
|
||||
items = append(items, otherItems...)
|
||||
case len(otherItems) > 0:
|
||||
items = append(items, othersMenuItem)
|
||||
}
|
||||
|
||||
@@ -135,17 +120,28 @@ func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
|
||||
}
|
||||
|
||||
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||
pinned := map[string]bool{
|
||||
"openclaw": true,
|
||||
"claude": true,
|
||||
"opencode": true,
|
||||
ordered := orderedIntegrationItems(state)
|
||||
if len(ordered) <= pinnedIntegrationCount {
|
||||
return nil
|
||||
}
|
||||
return ordered[pinnedIntegrationCount:]
|
||||
}
|
||||
|
||||
func pinnedIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||
ordered := orderedIntegrationItems(state)
|
||||
if len(ordered) <= pinnedIntegrationCount {
|
||||
return ordered
|
||||
}
|
||||
return ordered[:pinnedIntegrationCount]
|
||||
}
|
||||
|
||||
func orderedIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var items []menuItem
|
||||
items := make([]menuItem, 0, len(state.Integrations))
|
||||
for _, info := range launch.ListIntegrationInfos() {
|
||||
if pinned[info.Name] {
|
||||
continue
|
||||
}
|
||||
integrationState, ok := state.Integrations[info.Name]
|
||||
if !ok {
|
||||
continue
|
||||
@@ -155,6 +151,10 @@ func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||
return items
|
||||
}
|
||||
|
||||
func primaryMenuItemCount(state *launch.LauncherState) int {
|
||||
return 1 + len(pinnedIntegrationItems(state))
|
||||
}
|
||||
|
||||
func initialCursor(state *launch.LauncherState, items []menuItem) int {
|
||||
if state == nil || state.LastSelection == "" {
|
||||
return 0
|
||||
@@ -190,7 +190,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
}
|
||||
if m.showOthers && m.cursor < len(mainMenuItems) {
|
||||
if m.showOthers && m.cursor < primaryMenuItemCount(m.state) {
|
||||
m.showOthers = false
|
||||
m.items = buildMenuItems(m.state, false)
|
||||
m.cursor = min(m.cursor, len(m.items)-1)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
@@ -43,6 +44,13 @@ func launcherTestState() *launch.LauncherState {
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
"hermes": {
|
||||
Name: "hermes",
|
||||
DisplayName: "Hermes Agent",
|
||||
Description: "Self-improving AI agent built by Nous Research",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
"droid": {
|
||||
Name: "droid",
|
||||
DisplayName: "Droid",
|
||||
@@ -70,8 +78,28 @@ func findMenuCursorByIntegration(items []menuItem, name string) int {
|
||||
return -1
|
||||
}
|
||||
|
||||
func integrationSequence(items []menuItem) []string {
|
||||
sequence := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
switch {
|
||||
case item.isRunModel:
|
||||
sequence = append(sequence, "run")
|
||||
case item.isOthers:
|
||||
sequence = append(sequence, "more")
|
||||
case item.integration != "":
|
||||
sequence = append(sequence, item.integration)
|
||||
}
|
||||
}
|
||||
return sequence
|
||||
}
|
||||
|
||||
func compareStrings(got, want []string) string {
|
||||
return cmp.Diff(want, got)
|
||||
}
|
||||
|
||||
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
||||
view := newModel(launcherTestState()).View()
|
||||
menu := newModel(launcherTestState())
|
||||
view := menu.View()
|
||||
for _, want := range []string{"Chat with a model", "Launch OpenClaw", "Launch Claude Code", "Launch OpenCode", "More..."} {
|
||||
if !strings.Contains(view, want) {
|
||||
t.Fatalf("expected menu view to contain %q\n%s", want, view)
|
||||
@@ -80,23 +108,31 @@ func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
||||
if strings.Contains(view, "Launch Codex") {
|
||||
t.Fatalf("expected Codex to be under More, not pinned\n%s", view)
|
||||
}
|
||||
wantOrder := []string{"run", "openclaw", "claude", "opencode", "more"}
|
||||
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
|
||||
t.Fatalf("unexpected pinned order: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
state.LastSelection = "pi"
|
||||
state.LastSelection = "codex"
|
||||
|
||||
menu := newModel(state)
|
||||
if !menu.showOthers {
|
||||
t.Fatal("expected others section to expand when last selection is in the overflow list")
|
||||
}
|
||||
view := menu.View()
|
||||
if !strings.Contains(view, "Launch Pi") {
|
||||
if !strings.Contains(view, "Launch Codex") {
|
||||
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
|
||||
}
|
||||
if strings.Contains(view, "More...") {
|
||||
t.Fatalf("expected expanded view to replace More... item\n%s", view)
|
||||
}
|
||||
wantOrder := []string{"run", "openclaw", "claude", "opencode", "hermes", "codex", "droid", "pi"}
|
||||
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
|
||||
t.Fatalf("unexpected expanded order: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
|
||||
|
||||
@@ -120,6 +120,7 @@
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/codex",
|
||||
"/integrations/copilot-cli",
|
||||
"/integrations/opencode",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose",
|
||||
|
||||
BIN
docs/images/hermes.png
Normal file
BIN
docs/images/hermes.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.4 MiB |
93
docs/integrations/copilot-cli.mdx
Normal file
93
docs/integrations/copilot-cli.mdx
Normal file
@@ -0,0 +1,93 @@
|
||||
---
|
||||
title: Copilot CLI
|
||||
---
|
||||
|
||||
GitHub Copilot CLI is GitHub's AI coding agent for the terminal. It can understand your codebase, make edits, run commands, and help you build software faster.
|
||||
|
||||
Open models can be used with Copilot CLI through Ollama, enabling you to use models such as `qwen3.5`, `glm-5.1:cloud`, `kimi-k2.5:cloud`.
|
||||
|
||||
## Install
|
||||
|
||||
Install [Copilot CLI](https://github.com/features/copilot/cli/):
|
||||
|
||||
<CodeGroup>
|
||||
|
||||
```shell macOS / Linux (Homebrew)
|
||||
brew install copilot-cli
|
||||
```
|
||||
|
||||
```shell npm (all platforms)
|
||||
npm install -g @github/copilot
|
||||
```
|
||||
|
||||
```shell macOS / Linux (script)
|
||||
curl -fsSL https://gh.io/copilot-install | bash
|
||||
```
|
||||
|
||||
```powershell Windows (WinGet)
|
||||
winget install GitHub.Copilot
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```shell
|
||||
ollama launch copilot
|
||||
```
|
||||
|
||||
### Run directly with a model
|
||||
|
||||
```shell
|
||||
ollama launch copilot --model kimi-k2.5:cloud
|
||||
```
|
||||
|
||||
## Recommended Models
|
||||
|
||||
- `kimi-k2.5:cloud`
|
||||
- `glm-5:cloud`
|
||||
- `minimax-m2.7:cloud`
|
||||
- `qwen3.5:cloud`
|
||||
- `glm-4.7-flash`
|
||||
- `qwen3.5`
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
|
||||
## Non-interactive (headless) mode
|
||||
|
||||
Run Copilot CLI without interaction for use in Docker, CI/CD, or scripts:
|
||||
|
||||
```shell
|
||||
ollama launch copilot --model kimi-k2.5:cloud --yes -- -p "how does this repository work?"
|
||||
```
|
||||
|
||||
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified. Arguments after `--` are passed directly to Copilot CLI.
|
||||
|
||||
## Manual setup
|
||||
|
||||
Copilot CLI connects to Ollama using the OpenAI-compatible API via environment variables.
|
||||
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1
|
||||
export COPILOT_PROVIDER_API_KEY=
|
||||
export COPILOT_PROVIDER_WIRE_API=responses
|
||||
export COPILOT_MODEL=qwen3.5
|
||||
```
|
||||
|
||||
1. Run Copilot CLI:
|
||||
|
||||
```shell
|
||||
copilot
|
||||
```
|
||||
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1 COPILOT_PROVIDER_API_KEY= COPILOT_PROVIDER_WIRE_API=responses COPILOT_MODEL=glm-5:cloud copilot
|
||||
```
|
||||
|
||||
**Note:** Copilot requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||
@@ -2,29 +2,66 @@
|
||||
title: Hermes Agent
|
||||
---
|
||||
|
||||
Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation, cross-session memory, and connects messaging platforms (Telegram, Discord, Slack, WhatsApp, Signal, Email) to models through a unified gateway.
|
||||
Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation, cross-session memory, and 70+ skills that it ships with by default.
|
||||
|
||||

|
||||
|
||||
## Quick start
|
||||
|
||||
### Pull a model
|
||||
|
||||
Before running the setup wizard, make sure you have a model available. Hermes will auto-detect models downloaded through Ollama.
|
||||
|
||||
```bash
|
||||
ollama pull kimi-k2.5:cloud
|
||||
ollama launch hermes
|
||||
```
|
||||
|
||||
See [Recommended models](#recommended-models) for more options.
|
||||
Ollama handles everything automatically:
|
||||
|
||||
### Install
|
||||
1. **Install** — If Hermes isn't installed, Ollama prompts to install it via the Nous Research install script
|
||||
2. **Model** — Pick a model from the selector (local or cloud)
|
||||
3. **Onboarding** — Ollama configures the Ollama provider, points Hermes at `http://127.0.0.1:11434/v1`, and sets your model as the primary
|
||||
4. **Gateway** — Optionally connects a messaging platform (Telegram, Discord, Slack, WhatsApp, Signal, Email) and launches the Hermes chat
|
||||
|
||||
<Note>Hermes on Windows requires WSL2. Install it with `wsl --install` and re-run from inside the WSL shell.</Note>
|
||||
|
||||
## Recommended models
|
||||
|
||||
**Cloud models**:
|
||||
|
||||
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
|
||||
- `glm-5.1:cloud` — Reasoning and code generation
|
||||
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
|
||||
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
|
||||
|
||||
**Local models:**
|
||||
|
||||
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
|
||||
- `qwen3.6` — Reasoning, coding, and visual understanding locally (~24 GB VRAM)
|
||||
|
||||
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
|
||||
|
||||
## Connect messaging apps
|
||||
|
||||
Link Telegram, Discord, Slack, WhatsApp, Signal, or Email to chat with your models from anywhere:
|
||||
|
||||
```bash
|
||||
hermes gateway setup
|
||||
```
|
||||
|
||||
## Reconfigure
|
||||
|
||||
Re-run the full setup wizard at any time:
|
||||
|
||||
```bash
|
||||
hermes setup
|
||||
```
|
||||
|
||||
## Manual setup
|
||||
|
||||
If you'd rather drive Hermes's own wizard instead of `ollama launch hermes`, install it directly:
|
||||
|
||||
```bash
|
||||
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
### Set up
|
||||
|
||||
After installation, Hermes launches the setup wizard automatically. Choose **Quick setup**:
|
||||
Hermes launches the setup wizard automatically. Choose **Quick setup**:
|
||||
|
||||
```
|
||||
How would you like to set up Hermes?
|
||||
@@ -80,32 +117,3 @@ Connect a messaging platform? (Telegram, Discord, etc.)
|
||||
Launch hermes chat now? [Y/n]: Y
|
||||
```
|
||||
|
||||
## Recommended models
|
||||
|
||||
**Cloud models**:
|
||||
|
||||
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
|
||||
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
|
||||
- `glm-5.1:cloud` — Reasoning and code generation
|
||||
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
|
||||
|
||||
**Local models:**
|
||||
|
||||
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
|
||||
- `qwen3.5` — Reasoning, coding, and visual understanding locally (~11 GB VRAM)
|
||||
|
||||
More models at [ollama.com/search](https://ollama.com/models).
|
||||
|
||||
## Configure later
|
||||
|
||||
Re-run the setup wizard at any time:
|
||||
|
||||
```bash
|
||||
hermes setup
|
||||
```
|
||||
|
||||
To configure just messaging:
|
||||
|
||||
```bash
|
||||
hermes setup gateway
|
||||
```
|
||||
|
||||
@@ -10,6 +10,7 @@ Coding assistants that can read, modify, and execute code in your projects.
|
||||
|
||||
- [Claude Code](/integrations/claude-code)
|
||||
- [Codex](/integrations/codex)
|
||||
- [Copilot CLI](/integrations/copilot-cli)
|
||||
- [OpenCode](/integrations/opencode)
|
||||
- [Droid](/integrations/droid)
|
||||
- [Goose](/integrations/goose)
|
||||
|
||||
@@ -28,79 +28,4 @@ To configure without launching:
|
||||
ollama launch opencode --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Add a configuration block to `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": {
|
||||
"baseURL": "http://localhost:11434/v1"
|
||||
},
|
||||
"models": {
|
||||
"qwen3-coder": {
|
||||
"name": "qwen3-coder"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Cloud Models
|
||||
|
||||
`glm-4.7:cloud` is the recommended model for use with OpenCode.
|
||||
|
||||
Add the cloud configuration to `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": {
|
||||
"baseURL": "http://localhost:11434/v1"
|
||||
},
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
|
||||
2. Update `~/.config/opencode/opencode.json` to point to ollama.com:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama Cloud",
|
||||
"options": {
|
||||
"baseURL": "https://ollama.com/v1"
|
||||
},
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Run `opencode` in a new terminal to load the new settings.
|
||||
<Note>`ollama launch opencode` passes its configuration to OpenCode inline via the `OPENCODE_CONFIG_CONTENT` environment variable. OpenCode deep-merges its config sources on startup, so anything you declare in `~/.config/opencode/opencode.json` is still respected and available inside OpenCode. Models declared only in `opencode.json` won't appear in `ollama launch`'s model-selection menu.</Note>
|
||||
|
||||
2
go.mod
2
go.mod
@@ -106,5 +106,5 @@ require (
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/text v0.30.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package common
|
||||
|
||||
// #cgo CXXFLAGS: -std=c++17
|
||||
// #cgo CXXFLAGS: -std=c++17 -Wno-deprecated-declarations
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/../include -I${SRCDIR}/../vendor
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/../../../ml/backend/ggml/ggml/include
|
||||
import "C"
|
||||
|
||||
@@ -6,11 +6,11 @@ Subject: [PATCH] interleave multi rope
|
||||
since ollama doesn't use mrope for anything else, change it to mean the
|
||||
interleaved version used for qwen3vl
|
||||
---
|
||||
ggml/src/ggml-cpu/ops.cpp | 8 ++++----
|
||||
ggml/src/ggml-cuda/rope.cu | 8 ++++----
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 8 ++++----
|
||||
ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl | 8 ++++----
|
||||
4 files changed, 16 insertions(+), 16 deletions(-)
|
||||
ggml/src/ggml-cpu/ops.cpp | 8 ++++----
|
||||
ggml/src/ggml-cuda/rope.cu | 8 ++++----
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 10 +++++-----
|
||||
ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl | 8 ++++----
|
||||
4 files changed, 17 insertions(+), 17 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
||||
index 7d1733adb..f4aae5332 100644
|
||||
@@ -59,12 +59,15 @@ index 88ed79111..71ca60214 100644
|
||||
} else {
|
||||
if (sector < sections.v[0]) {
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index 236838e9e..c98d269d1 100644
|
||||
index 236838e9e..18b8bb1b1 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -4242,14 +4242,14 @@ kernel void kernel_rope_multi(
|
||||
@@ -4240,16 +4240,16 @@ kernel void kernel_rope_multi(
|
||||
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float theta_base;
|
||||
- float theta_base;
|
||||
+ float theta_base = 0.0;
|
||||
if (FC_rope_is_imrope) {
|
||||
- if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
|
||||
+ if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) { // h
|
||||
|
||||
@@ -296,7 +296,7 @@ index e99c1763f..80864f303 100644
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index c98d269d1..d33c16079 100644
|
||||
index 18b8bb1b1..114767785 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
@@ -204,7 +204,7 @@ index 902b54452..a475183d3 100644
|
||||
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index d33c16079..c37447a10 100644
|
||||
index 114767785..876a9eecc 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
|
||||
|
||||
@@ -24,7 +24,7 @@ index 4ac135603..ac5ad53db 100644
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
//switch (op->src[0]->type) {
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index c37447a10..4f338aa13 100644
|
||||
index 876a9eecc..b14a0000c 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
|
||||
|
||||
@@ -342,7 +342,7 @@ index 4e5acfbe5..11457f2b1 100644
|
||||
return false;
|
||||
}
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index 4f338aa13..8be0c1f0c 100644
|
||||
index b14a0000c..398c80717 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -6276,6 +6276,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Layer struct {
|
||||
@@ -60,6 +61,9 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
return Layer{}, err
|
||||
}
|
||||
}
|
||||
if err := touchLayer(blob); err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
|
||||
return Layer{
|
||||
MediaType: mediatype,
|
||||
@@ -83,6 +87,9 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
||||
if err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
if err := touchLayer(blob); err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
|
||||
return Layer{
|
||||
MediaType: mediatype,
|
||||
@@ -93,6 +100,11 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func touchLayer(path string) error {
|
||||
now := time.Now()
|
||||
return os.Chtimes(path, now, now)
|
||||
}
|
||||
|
||||
func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
||||
if l.Digest == "" {
|
||||
return nil, errors.New("opening layer with empty digest")
|
||||
|
||||
@@ -7122,7 +7122,7 @@ kernel void kernel_rope_multi(
|
||||
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float theta_base;
|
||||
float theta_base = 0.0;
|
||||
if (FC_rope_is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) { // h
|
||||
theta_base = (float) pos[i2 + args.ne02 * 1];
|
||||
|
||||
@@ -4300,7 +4300,7 @@ kernel void kernel_rope_multi(
|
||||
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float theta_base;
|
||||
float theta_base = 0.0;
|
||||
if (FC_rope_is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) { // h
|
||||
theta_base = (float) pos[i2 + args.ne02 * 1];
|
||||
|
||||
@@ -12,7 +12,8 @@ import (
|
||||
// <|turn>/<turn|> markers, <|"|> string delimiters, and <|tool>/
|
||||
// <|tool_call>/<|tool_response> tags for function calling.
|
||||
type Gemma4Renderer struct {
|
||||
useImgTags bool
|
||||
useImgTags bool
|
||||
emptyBlockOnNothink bool
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -124,7 +125,7 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
||||
// Generation prompt.
|
||||
if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
|
||||
sb.WriteString("<|turn>model\n")
|
||||
if !hasThink {
|
||||
if r.emptyBlockOnNothink && !hasThink {
|
||||
sb.WriteString("<|channel>thought\n<channel|>")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
package renderers
|
||||
|
||||
// TestGemma4RendererMatchesReference verifies our renderer matches the HF
|
||||
// Jinja2 chat template exactly.
|
||||
// TestGemma4RendererMatchesReference verifies our renderer matches the checked-in
|
||||
// Gemma 4 reference template.
|
||||
//
|
||||
// To regenerate expected values, save gemma4Jinja2Template (below) to
|
||||
// gemma4_chat_template.jinja2 and run:
|
||||
// Current upstream Gemma 4 chat templates differ by model size. The checked-in
|
||||
// reference cases below use the small (e2b/e4b-style) baseline, with large
|
||||
// (26b/31b-style) checks covered separately in this file.
|
||||
//
|
||||
// To regenerate expected values, save the E2B template to
|
||||
// gemma4_e2b_chat_template.jinja2 and run:
|
||||
//
|
||||
// python3 -c "
|
||||
// from jinja2 import Environment; import json
|
||||
// tmpl = Environment().from_string(open('gemma4_chat_template.jinja2').read())
|
||||
// tmpl = Environment().from_string(open('gemma4_e2b_chat_template.jinja2').read())
|
||||
// msgs = [{'role':'user','content':'Hello'}]
|
||||
// print(repr(tmpl.render(messages=msgs, bos_token='<bos>', add_generation_prompt=True)))
|
||||
// "
|
||||
@@ -26,8 +30,13 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// The full Jinja2 template is committed as testdata/gemma4_chat_template.jinja2.
|
||||
// Run with VERIFY_JINJA2=1 to verify expected values against the template using uv + Python.
|
||||
const (
|
||||
gemma4E2BTemplate = "testdata/gemma4_e2b_chat_template.jinja2"
|
||||
gemma431BTemplate = "testdata/gemma4_31b_chat_template.jinja2"
|
||||
)
|
||||
|
||||
// The upstream Gemma 4 chat templates are committed by size under testdata/.
|
||||
// Run with VERIFY_JINJA2=1 to verify expected values against the E2B template using uv + Python.
|
||||
|
||||
func bashRefTool() []api.Tool {
|
||||
return []api.Tool{{
|
||||
@@ -665,7 +674,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
{
|
||||
name: "user_only",
|
||||
messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "system_user",
|
||||
@@ -673,7 +682,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "developer_user",
|
||||
@@ -681,13 +690,13 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
{Role: "developer", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "tools_no_system",
|
||||
messages: []api.Message{{Role: "user", Content: "Hi"}},
|
||||
tools: bashRefTool(),
|
||||
expected: "<bos><|turn>system\n" + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>system\n" + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "system_tools",
|
||||
@@ -696,7 +705,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
tools: bashRefTool(),
|
||||
expected: "<bos><|turn>system\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>system\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "thinking_no_system",
|
||||
@@ -730,6 +739,12 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
think: thinkTrue(),
|
||||
expected: "<bos><|turn>system\n<|think|>\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "thinking_explicitly_disabled",
|
||||
messages: []api.Message{{Role: "user", Content: "Hi"}},
|
||||
think: thinkFalse(),
|
||||
expected: "<bos><|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
|
||||
// === Message loop paths ===
|
||||
{
|
||||
@@ -744,7 +759,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
"<|turn>user\nHi<turn|>\n" +
|
||||
"<|turn>model\nHello!<turn|>\n" +
|
||||
"<|turn>user\nMore<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Tool call with structured args → tool response as separate <|turn>tool turn
|
||||
@@ -806,7 +821,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
"<|tool_response>response:bash{value:" + q + "file1.txt\nfile2.txt" + q + "}<tool_response|>" +
|
||||
"Here are the files.<turn|>\n" +
|
||||
"<|turn>user\nRead file1.txt<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Multiple tool calls + multiple tool responses
|
||||
@@ -841,7 +856,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
expected: "<bos><|turn>user\nWhat is 2+2?<turn|>\n" +
|
||||
"<|turn>model\n4<turn|>\n" +
|
||||
"<|turn>user\nAnd 3+3?<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
// === Additional edge cases ported from original tests ===
|
||||
{
|
||||
@@ -899,17 +914,17 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
messages: []api.Message{{Role: "user", Content: "Test"}},
|
||||
tools: modeTool(),
|
||||
expected: "<bos><|turn>system\n" + modeDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nTest<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nTest<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "unicode_content",
|
||||
messages: []api.Message{{Role: "user", Content: "こんにちは"}},
|
||||
expected: "<bos><|turn>user\nこんにちは<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>user\nこんにちは<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "newlines_in_content",
|
||||
messages: []api.Message{{Role: "user", Content: "Line 1\nLine 2\nLine 3"}},
|
||||
expected: "<bos><|turn>user\nLine 1\nLine 2\nLine 3<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>user\nLine 1\nLine 2\nLine 3<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Tool response (raw JSON) followed by user message
|
||||
@@ -928,7 +943,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
"<|turn>model\n<|tool_call>call:get_weather{city:" + q + "Tokyo" + q + "}<tool_call|>" +
|
||||
"<|tool_response>response:get_weather{value:" + q + `{"temperature": 15, "weather": "sunny"}` + q + "}<tool_response|>" +
|
||||
"<|turn>user\nThanks!<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
// === Ordering and whitespace edge cases ===
|
||||
{
|
||||
@@ -951,7 +966,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
// User content with whitespace is trimmed
|
||||
name: "user_content_trimmed",
|
||||
messages: []api.Message{{Role: "user", Content: " hello "}},
|
||||
expected: "<bos><|turn>user\nhello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>user\nhello<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Empty tool call arguments
|
||||
@@ -975,7 +990,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
messages: []api.Message{{Role: "user", Content: "Create"}},
|
||||
tools: nestedTool(),
|
||||
expected: "<bos><|turn>system\n" + nestedDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nCreate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nCreate<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Array type in tool declaration
|
||||
@@ -983,7 +998,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
messages: []api.Message{{Role: "user", Content: "Batch"}},
|
||||
tools: arrayTool(),
|
||||
expected: "<bos><|turn>system\n" + arrayDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nBatch<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nBatch<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Top-level typed union follows the template's odd stringified-list form.
|
||||
@@ -995,8 +1010,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
<|turn>user
|
||||
Hi<turn|>
|
||||
<|turn>model
|
||||
<|channel>thought
|
||||
<channel|>`,
|
||||
`,
|
||||
},
|
||||
{
|
||||
// Assistant whitespace is trimmed (strip_thinking includes | trim)
|
||||
@@ -1009,7 +1023,7 @@ Hi<turn|>
|
||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||
"<|turn>model\nspaced<turn|>\n" +
|
||||
"<|turn>user\nMore<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Three sequential tool responses
|
||||
@@ -1064,7 +1078,7 @@ Hi<turn|>
|
||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||
"<|turn>model\nMiddleDone<turn|>\n" +
|
||||
"<|turn>user\nMore<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Property with no description — just type
|
||||
@@ -1072,7 +1086,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Count"}},
|
||||
tools: countTool(),
|
||||
expected: "<bos><|turn>system\n" + countDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nCount<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nCount<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// System message with leading/trailing whitespace is trimmed
|
||||
@@ -1082,7 +1096,7 @@ Hi<turn|>
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n" +
|
||||
"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Deeply nested map in tool call arguments (3 levels)
|
||||
@@ -1144,7 +1158,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Set"}},
|
||||
tools: enumNoDescTool(),
|
||||
expected: "<bos><|turn>system\n" + enumNoDescDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nSet<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nSet<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// System message that is only whitespace (trims to empty)
|
||||
@@ -1154,7 +1168,7 @@ Hi<turn|>
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
expected: "<bos><|turn>system\n<turn|>\n" +
|
||||
"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Empty assistant content (empty string, not nil)
|
||||
@@ -1167,7 +1181,7 @@ Hi<turn|>
|
||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||
"<|turn>model\n<turn|>\n" +
|
||||
"<|turn>user\nMore<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Map argument with string keys (keys NOT escaped with <|"|>)
|
||||
@@ -1193,7 +1207,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Search"}},
|
||||
tools: searchTool(),
|
||||
expected: "<bos><|turn>system\n" + searchDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nSearch<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nSearch<turn|>\n<|turn>model\n",
|
||||
},
|
||||
|
||||
// === Round 3 coverage gaps ===
|
||||
@@ -1221,7 +1235,7 @@ Hi<turn|>
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
expected: "<bos><|turn>system\n<turn|>\n" +
|
||||
"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Nested OBJECT property with required field
|
||||
@@ -1229,7 +1243,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Create"}},
|
||||
tools: nestedRequiredTool(),
|
||||
expected: "<bos><|turn>system\n" + nestedRequiredDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nCreate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nCreate<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Non-integer float in tool call argument
|
||||
@@ -1256,7 +1270,7 @@ Hi<turn|>
|
||||
},
|
||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||
"<|turn>model\nResult<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Tool content with newlines and leading/trailing whitespace trimmed
|
||||
@@ -1280,7 +1294,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Raw"}},
|
||||
tools: rawTool(),
|
||||
expected: "<bos><|turn>system\n" + rawDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nRaw<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nRaw<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Multiple required fields at top level
|
||||
@@ -1288,7 +1302,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Move"}},
|
||||
tools: moveTool(),
|
||||
expected: "<bos><|turn>system\n" + moveDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nMove<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nMove<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Assistant content that is ONLY thinking (strips to empty)
|
||||
@@ -1301,7 +1315,7 @@ Hi<turn|>
|
||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||
"<|turn>model\n<turn|>\n" +
|
||||
"<|turn>user\nMore<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
|
||||
// === Round 4: final coverage gaps ===
|
||||
@@ -1334,7 +1348,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Tag"}},
|
||||
tools: arrayNoItemsTool(),
|
||||
expected: "<bos><|turn>system\n" + arrayNoItemsDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nTag<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nTag<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// OBJECT property without description but with nested properties
|
||||
@@ -1342,7 +1356,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Update"}},
|
||||
tools: objectNoDescTool(),
|
||||
expected: "<bos><|turn>system\n" + objectNoDescDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nUpdate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nUpdate<turn|>\n<|turn>model\n",
|
||||
},
|
||||
|
||||
// === Round 5: coding agent patterns ===
|
||||
@@ -1372,7 +1386,7 @@ Hi<turn|>
|
||||
"<|tool_response>response:bash{value:" + q + q + "}<tool_response|>" +
|
||||
"Done.<turn|>\n" +
|
||||
"<|turn>user\nThanks<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Tool call with thinking that strips to real remaining content
|
||||
@@ -1392,7 +1406,7 @@ Hi<turn|>
|
||||
"<|tool_response>response:bash{value:" + q + "main.go\ngo.mod" + q + "}<tool_response|>" +
|
||||
"Let me list the files.<turn|>\n" +
|
||||
"<|turn>user\nOK<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Argument value containing newlines (multi-line script)
|
||||
@@ -1460,6 +1474,47 @@ Hi<turn|>
|
||||
}
|
||||
}
|
||||
|
||||
func TestGemma4RendererVariantsMatchExpectedGenerationPrompt(t *testing.T) {
|
||||
messages := []api.Message{{Role: "user", Content: "Hello"}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rendererName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "legacy_alias",
|
||||
rendererName: "gemma4",
|
||||
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "small",
|
||||
rendererName: "gemma4-small",
|
||||
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "large",
|
||||
rendererName: "gemma4-large",
|
||||
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := RenderWithRenderer(tt.rendererName, messages, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGemma4LargeRendererOmitsEmptyThoughtBlockWhenThinkingEnabled(t *testing.T) {
|
||||
got, err := RenderWithRenderer("gemma4-large", []api.Message{{Role: "user", Content: "Hello"}}, nil, thinkTrue())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "<bos><|turn>system\n<|think|>\n<turn|>\n<|turn>user\nHello<turn|>\n<|turn>model\n", got)
|
||||
assert.NotContains(t, got, "<|channel>thought\n<channel|>")
|
||||
}
|
||||
|
||||
func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
|
||||
if os.Getenv("VERIFY_JINJA2") == "" {
|
||||
t.Skip("set VERIFY_JINJA2=1 to run expanded Jinja2 parity checks")
|
||||
@@ -1602,15 +1657,35 @@ func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
renderer := &Gemma4Renderer{useImgTags: RenderImgTags}
|
||||
got, err := renderer.Render(tt.messages, tt.tools, tt.think)
|
||||
assert.NoError(t, err)
|
||||
variants := []struct {
|
||||
name string
|
||||
renderer *Gemma4Renderer
|
||||
templateRel string
|
||||
}{
|
||||
{
|
||||
name: "small",
|
||||
renderer: &Gemma4Renderer{useImgTags: RenderImgTags},
|
||||
templateRel: gemma4E2BTemplate,
|
||||
},
|
||||
{
|
||||
name: "large",
|
||||
renderer: &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true},
|
||||
templateRel: gemma431BTemplate,
|
||||
},
|
||||
}
|
||||
|
||||
jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, tt.think)
|
||||
assert.Equal(t, jinja2Output, got,
|
||||
"renderer output doesn't match Jinja2 template output")
|
||||
for _, variant := range variants {
|
||||
t.Run(variant.name, func(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := variant.renderer.Render(tt.messages, tt.tools, tt.think)
|
||||
assert.NoError(t, err)
|
||||
|
||||
jinja2Output := renderWithJinja2Template(t, variant.templateRel, tt.messages, tt.tools, tt.think)
|
||||
assert.Equal(t, jinja2Output, got,
|
||||
"renderer output doesn't match Jinja2 template output")
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1720,12 +1795,35 @@ func TestGemma4RendererToolResponseWithoutNameOrIDUsesUnknown(t *testing.T) {
|
||||
assert.NotContains(t, got, `response:read{value:<|"|>payload<|"|>}`)
|
||||
}
|
||||
|
||||
func TestGemma4SizeTemplateFixturesDifferAtGenerationPrompt(t *testing.T) {
|
||||
e2b, err := os.ReadFile(gemma4E2BTemplate)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read %s: %v", gemma4E2BTemplate, err)
|
||||
}
|
||||
|
||||
thirtyOneB, err := os.ReadFile(gemma431BTemplate)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read %s: %v", gemma431BTemplate, err)
|
||||
}
|
||||
|
||||
assert.Contains(t, string(e2b), "{{- '<|turn>model\\n' -}}")
|
||||
assert.NotContains(t, string(e2b), "{{- '<|channel>thought\\n<channel|>' -}}")
|
||||
assert.Contains(t, string(thirtyOneB), "{{- '<|turn>model\\n' -}}")
|
||||
assert.Contains(t, string(thirtyOneB), "{{- '<|channel>thought\\n<channel|>' -}}")
|
||||
}
|
||||
|
||||
// renderWithJinja2 shells out to uv + Python to render messages through the
|
||||
// Jinja2 chat template. Returns the rendered string.
|
||||
// E2B Jinja2 chat template. Returns the rendered string.
|
||||
func renderWithJinja2(t *testing.T, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string {
|
||||
return renderWithJinja2Template(t, gemma4E2BTemplate, messages, tools, think)
|
||||
}
|
||||
|
||||
// renderWithJinja2Template shells out to uv + Python to render messages through
|
||||
// the named Jinja2 chat template. Returns the rendered string.
|
||||
func renderWithJinja2Template(t *testing.T, templateRelPath string, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string {
|
||||
t.Helper()
|
||||
|
||||
templatePath, err := filepath.Abs("testdata/gemma4_chat_template.jinja2")
|
||||
templatePath, err := filepath.Abs(templateRelPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get template path: %v", err)
|
||||
}
|
||||
@@ -1814,3 +1912,7 @@ print(tmpl.render(**kwargs), end="")
|
||||
func thinkTrue() *api.ThinkValue {
|
||||
return &api.ThinkValue{Value: true}
|
||||
}
|
||||
|
||||
func thinkFalse() *api.ThinkValue {
|
||||
return &api.ThinkValue{Value: false}
|
||||
}
|
||||
|
||||
@@ -81,8 +81,10 @@ func rendererForName(name string) Renderer {
|
||||
return renderer
|
||||
case "nemotron-3-nano":
|
||||
return &Nemotron3NanoRenderer{}
|
||||
case "gemma4":
|
||||
case "gemma4", "gemma4-small":
|
||||
return &Gemma4Renderer{useImgTags: RenderImgTags}
|
||||
case "gemma4-large":
|
||||
return &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true}
|
||||
case "functiongemma":
|
||||
return &FunctionGemmaRenderer{}
|
||||
case "glm-4.7":
|
||||
|
||||
344
model/renderers/testdata/gemma4_e2b_chat_template.jinja2
vendored
Normal file
344
model/renderers/testdata/gemma4_e2b_chat_template.jinja2
vendored
Normal file
@@ -0,0 +1,344 @@
|
||||
{%- macro format_parameters(properties, required) -%}
|
||||
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in properties | dictsort -%}
|
||||
{%- set add_comma = false -%}
|
||||
{%- if key not in standard_keys -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{{ key }}:{
|
||||
{%- if value['description'] -%}
|
||||
description:<|"|>{{ value['description'] }}<|"|>
|
||||
{%- set add_comma = true -%}
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'STRING' -%}
|
||||
{%- if value['enum'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
enum:{{ format_argument(value['enum']) }}
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'ARRAY' -%}
|
||||
{%- if value['items'] is mapping and value['items'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
items:{
|
||||
{%- set ns_items = namespace(found_first=false) -%}
|
||||
{%- for item_key, item_value in value['items'] | dictsort -%}
|
||||
{%- if item_value is not none -%}
|
||||
{%- if ns_items.found_first %},{% endif -%}
|
||||
{%- set ns_items.found_first = true -%}
|
||||
{%- if item_key == 'properties' -%}
|
||||
properties:{
|
||||
{%- if item_value is mapping -%}
|
||||
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- elif item_key == 'required' -%}
|
||||
required:[
|
||||
{%- for req_item in item_value -%}
|
||||
<|"|>{{- req_item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- elif item_key == 'type' -%}
|
||||
{%- if item_value is string -%}
|
||||
type:{{ format_argument(item_value | upper) }}
|
||||
{%- else -%}
|
||||
type:{{ format_argument(item_value | map('upper') | list) }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{ item_key }}:{{ format_argument(item_value) }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if value['nullable'] %}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
nullable:true
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'OBJECT' -%}
|
||||
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
properties:{
|
||||
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||
}
|
||||
{%- elif value is mapping -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
properties:{
|
||||
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- if value['required'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
required:[
|
||||
{%- for item in value['required'] | default([]) -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
type:<|"|>{{ value['type'] | upper }}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_function_declaration(tool_data) -%}
|
||||
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
|
||||
{%- set params = tool_data['function']['parameters'] -%}
|
||||
{%- if params -%}
|
||||
,parameters:{
|
||||
{%- if params['properties'] -%}
|
||||
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
|
||||
{%- endif -%}
|
||||
{%- if params['required'] -%}
|
||||
required:[
|
||||
{%- for item in params['required'] -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{{- ',' if not loop.last -}}
|
||||
{%- endfor -%}
|
||||
],
|
||||
{%- endif -%}
|
||||
{%- if params['type'] -%}
|
||||
type:<|"|>{{- params['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if 'response' in tool_data['function'] -%}
|
||||
{%- set response_declaration = tool_data['function']['response'] -%}
|
||||
,response:{
|
||||
{%- if response_declaration['description'] -%}
|
||||
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
|
||||
{%- endif -%}
|
||||
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
|
||||
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_argument(argument, escape_keys=True) -%}
|
||||
{%- if argument is string -%}
|
||||
{{- '<|"|>' + argument + '<|"|>' -}}
|
||||
{%- elif argument is boolean -%}
|
||||
{{- 'true' if argument else 'false' -}}
|
||||
{%- elif argument is mapping -%}
|
||||
{{- '{' -}}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in argument | dictsort -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{%- if escape_keys -%}
|
||||
{{- '<|"|>' + key + '<|"|>' -}}
|
||||
{%- else -%}
|
||||
{{- key -}}
|
||||
{%- endif -%}
|
||||
:{{- format_argument(value, escape_keys=escape_keys) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- elif argument is sequence -%}
|
||||
{{- '[' -}}
|
||||
{%- for item in argument -%}
|
||||
{{- format_argument(item, escape_keys=escape_keys) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ']' -}}
|
||||
{%- else -%}
|
||||
{{- argument -}}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro strip_thinking(text) -%}
|
||||
{%- set ns = namespace(result='') -%}
|
||||
{%- for part in text.split('<channel|>') -%}
|
||||
{%- if '<|channel>' in part -%}
|
||||
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
|
||||
{%- else -%}
|
||||
{%- set ns.result = ns.result + part -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ns.result | trim -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- macro format_tool_response_block(tool_name, response) -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- if response is mapping -%}
|
||||
{{- 'response:' + tool_name + '{' -}}
|
||||
{%- for key, value in response | dictsort -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- else -%}
|
||||
{{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}}
|
||||
{%- endif -%}
|
||||
{{- '<tool_response|>' -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- set ns = namespace(prev_message_type=None) -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{{- bos_token -}}
|
||||
{#- Handle System/Tool Definitions Block -#}
|
||||
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- '<|turn>system\n' -}}
|
||||
|
||||
{#- Inject Thinking token at the very top of the FIRST system turn -#}
|
||||
{%- if enable_thinking is defined and enable_thinking -%}
|
||||
{{- '<|think|>\n' -}}
|
||||
{%- set ns.prev_message_type = 'think' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- messages[0]['content'] | trim -}}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if tools -%}
|
||||
{%- for tool in tools %}
|
||||
{{- '<|tool>' -}}
|
||||
{{- format_function_declaration(tool) | trim -}}
|
||||
{{- '<tool|>' -}}
|
||||
{%- endfor %}
|
||||
{%- set ns.prev_message_type = 'tool' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif %}
|
||||
|
||||
{#- Pre-scan: find last user message index for reasoning guard -#}
|
||||
{%- set ns_turn = namespace(last_user_idx=-1) -%}
|
||||
{%- for i in range(loop_messages | length) -%}
|
||||
{%- if loop_messages[i]['role'] == 'user' -%}
|
||||
{%- set ns_turn.last_user_idx = i -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{#- Loop through messages -#}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if message['role'] != 'tool' -%}
|
||||
{%- set ns.prev_message_type = None -%}
|
||||
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
|
||||
{#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#}
|
||||
{%- set prev_nt = namespace(role=None, found=false) -%}
|
||||
{%- if loop.index0 > 0 -%}
|
||||
{%- for j in range(loop.index0 - 1, -1, -1) -%}
|
||||
{%- if not prev_nt.found -%}
|
||||
{%- if loop_messages[j]['role'] != 'tool' -%}
|
||||
{%- set prev_nt.role = loop_messages[j]['role'] -%}
|
||||
{%- set prev_nt.found = true -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%}
|
||||
{%- if not continue_same_model_turn -%}
|
||||
{{- '<|turn>' + role + '\n' }}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Render reasoning/reasoning_content as thinking channel -#}
|
||||
{%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%}
|
||||
{%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%}
|
||||
{{- '<|channel>thought\n' + thinking_text + '\n<channel|>' -}}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['tool_calls'] -%}
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{%- set function = tool_call['function'] -%}
|
||||
{{- '<|tool_call>call:' + function['name'] + '{' -}}
|
||||
{%- if function['arguments'] is mapping -%}
|
||||
{%- set ns_args = namespace(found_first=false) -%}
|
||||
{%- for key, value in function['arguments'] | dictsort -%}
|
||||
{%- if ns_args.found_first %},{% endif -%}
|
||||
{%- set ns_args.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{%- elif function['arguments'] is string -%}
|
||||
{{- function['arguments'] -}}
|
||||
{%- endif -%}
|
||||
{{- '}<tool_call|>' -}}
|
||||
{%- endfor -%}
|
||||
{%- set ns.prev_message_type = 'tool_call' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- set ns_tr_out = namespace(flag=false) -%}
|
||||
{%- if message.get('tool_responses') -%}
|
||||
{#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#}
|
||||
{%- for tool_response in message['tool_responses'] -%}
|
||||
{{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}}
|
||||
{%- set ns_tr_out.flag = true -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endfor -%}
|
||||
{%- elif message.get('tool_calls') -%}
|
||||
{#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#}
|
||||
{%- set ns_tool_scan = namespace(stopped=false) -%}
|
||||
{%- for k in range(loop.index0 + 1, loop_messages | length) -%}
|
||||
{%- if ns_tool_scan.stopped -%}
|
||||
{%- elif loop_messages[k]['role'] != 'tool' -%}
|
||||
{%- set ns_tool_scan.stopped = true -%}
|
||||
{%- else -%}
|
||||
{%- set follow = loop_messages[k] -%}
|
||||
{#- Resolve tool_call_id to function name -#}
|
||||
{%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%}
|
||||
{%- for tc in message['tool_calls'] -%}
|
||||
{%- if tc.get('id') == follow.get('tool_call_id') -%}
|
||||
{%- set ns_tname.name = tc['function']['name'] -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{#- Handle content as string or content-parts array -#}
|
||||
{%- set tool_body = follow.get('content') -%}
|
||||
{%- if tool_body is string -%}
|
||||
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||
{%- elif tool_body is sequence and tool_body is not string -%}
|
||||
{%- set ns_txt = namespace(s='') -%}
|
||||
{%- for part in tool_body -%}
|
||||
{%- if part.get('type') == 'text' -%}
|
||||
{%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- format_tool_response_block(ns_tname.name, ns_txt.s) -}}
|
||||
{%- else -%}
|
||||
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||
{%- endif -%}
|
||||
{%- set ns_tr_out.flag = true -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['content'] is string -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(message['content']) -}}
|
||||
{%- else -%}
|
||||
{{- message['content'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif message['content'] is sequence -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'text' -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(item['text']) -}}
|
||||
{%- else -%}
|
||||
{{- item['text'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif item['type'] == 'image' -%}
|
||||
{{- '<|image|>' -}}
|
||||
{%- set ns.prev_message_type = 'image' -%}
|
||||
{%- elif item['type'] == 'audio' -%}
|
||||
{{- '<|audio|>' -}}
|
||||
{%- set ns.prev_message_type = 'audio' -%}
|
||||
{%- elif item['type'] == 'video' -%}
|
||||
{{- '<|video|>' -}}
|
||||
{%- set ns.prev_message_type = 'video' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- elif not (ns_tr_out.flag and not message.get('content')) -%}
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%}
|
||||
{{- '<|turn>model\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
@@ -523,7 +523,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
arch := layer.GGML.KV().Architecture()
|
||||
switch arch {
|
||||
case "gemma4":
|
||||
config.Renderer = cmp.Or(config.Renderer, "gemma4")
|
||||
config.Renderer = cmp.Or(config.Renderer, gemma4RendererLegacy)
|
||||
config.Parser = cmp.Or(config.Parser, "gemma4")
|
||||
if _, ok := r.Parameters["stop"]; !ok {
|
||||
if r.Parameters == nil {
|
||||
|
||||
78
server/gemma4_test.go
Normal file
78
server/gemma4_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package server
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestResolveGemma4Renderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model *Model
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil model falls back to legacy alias",
|
||||
model: nil,
|
||||
want: gemma4RendererLegacy,
|
||||
},
|
||||
{
|
||||
name: "explicit small passes through",
|
||||
model: &Model{
|
||||
Config: testConfigWithRenderer(gemma4RendererSmall),
|
||||
},
|
||||
want: gemma4RendererSmall,
|
||||
},
|
||||
{
|
||||
name: "explicit large passes through",
|
||||
model: &Model{
|
||||
Config: testConfigWithRenderer(gemma4RendererLarge),
|
||||
},
|
||||
want: gemma4RendererLarge,
|
||||
},
|
||||
{
|
||||
name: "legacy e4b tag resolves small",
|
||||
model: &Model{
|
||||
Name: "gemma4:e4b",
|
||||
ShortName: "gemma4:e4b",
|
||||
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||
},
|
||||
want: gemma4RendererSmall,
|
||||
},
|
||||
{
|
||||
name: "legacy 31b tag resolves large",
|
||||
model: &Model{
|
||||
Name: "gemma4:31b-cloud",
|
||||
ShortName: "gemma4:31b-cloud",
|
||||
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||
},
|
||||
want: gemma4RendererLarge,
|
||||
},
|
||||
{
|
||||
name: "legacy model type resolves small",
|
||||
model: &Model{
|
||||
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "4.3B"),
|
||||
},
|
||||
want: gemma4RendererSmall,
|
||||
},
|
||||
{
|
||||
name: "legacy model type resolves large",
|
||||
model: &Model{
|
||||
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
|
||||
},
|
||||
want: gemma4RendererLarge,
|
||||
},
|
||||
{
|
||||
name: "legacy unknown defaults small",
|
||||
model: &Model{
|
||||
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||
},
|
||||
want: gemma4RendererSmall,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveGemma4Renderer(tt.model); got != tt.want {
|
||||
t.Fatalf("resolveGemma4Renderer() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
@@ -33,6 +34,10 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/transfer"
|
||||
)
|
||||
|
||||
// Blobs newer than this may belong to another process that has not written its
|
||||
// manifest yet. They become eligible for the normal mark-and-sweep pass later.
|
||||
const layerPruneGracePeriod = time.Hour
|
||||
|
||||
var (
|
||||
errCapabilities = errors.New("does not support")
|
||||
errCapabilityCompletion = errors.New("completion")
|
||||
@@ -156,7 +161,7 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
|
||||
// Temporary workaround — suppress vision/audio for gemma4 MLX models
|
||||
// until multimodal runtime pipeline lands. Remove when imageproc.go is wired up.
|
||||
if m.Config.ModelFormat == "safetensors" && m.Config.Renderer == "gemma4" {
|
||||
if m.Config.ModelFormat == "safetensors" && isGemma4Renderer(m.Config.Renderer) {
|
||||
capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool {
|
||||
return c == model.CapabilityVision || c == "audio"
|
||||
})
|
||||
@@ -478,10 +483,23 @@ func PruneLayers() error {
|
||||
}
|
||||
|
||||
for _, blob := range blobs {
|
||||
if blob.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
info, err := blob.Info()
|
||||
if err != nil {
|
||||
slog.Error("couldn't stat blob", "blob", blob.Name(), "error", err)
|
||||
continue
|
||||
}
|
||||
if time.Since(info.ModTime()) < layerPruneGracePeriod {
|
||||
continue
|
||||
}
|
||||
|
||||
name := blob.Name()
|
||||
name = strings.ReplaceAll(name, "-", ":")
|
||||
|
||||
_, err := manifest.BlobsPath(name)
|
||||
_, err = manifest.BlobsPath(name)
|
||||
if err != nil {
|
||||
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
|
||||
// remove invalid blobs (e.g. partial downloads)
|
||||
|
||||
@@ -5,14 +5,58 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestPruneLayersSkipsRecentOrphans(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
recentDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000001"
|
||||
oldDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000002"
|
||||
|
||||
for _, digest := range []string{recentDigest, oldDigest} {
|
||||
p, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(p, nil, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
oldPath, err := manifest.BlobsPath(oldDigest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
oldTime := time.Now().Add(-layerPruneGracePeriod - time.Hour)
|
||||
if err := os.Chtimes(oldPath, oldTime, oldTime); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := PruneLayers(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
recentPath, err := manifest.BlobsPath(recentDigest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(recentPath); err != nil {
|
||||
t.Fatalf("recent orphan was pruned: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(oldPath); !os.IsNotExist(err) {
|
||||
t.Fatalf("old orphan still exists: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelCapabilities(t *testing.T) {
|
||||
// Create completion model (llama architecture without vision)
|
||||
completionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
@@ -118,6 +162,39 @@ func TestModelCapabilities(t *testing.T) {
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityEmbedding},
|
||||
},
|
||||
{
|
||||
name: "gemma4 small safetensors suppresses vision and audio",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Renderer: gemma4RendererSmall,
|
||||
Capabilities: []string{"vision", "audio"},
|
||||
},
|
||||
Template: chatTemplate,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "gemma4 large safetensors suppresses vision and audio",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Renderer: gemma4RendererLarge,
|
||||
Capabilities: []string{"vision", "audio"},
|
||||
},
|
||||
Template: chatTemplate,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "legacy gemma4 safetensors suppresses vision and audio",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Renderer: gemma4RendererLegacy,
|
||||
Capabilities: []string{"vision", "audio"},
|
||||
},
|
||||
Template: chatTemplate,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// compare two slices of model.Capability regardless of order
|
||||
|
||||
@@ -115,7 +115,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
|
||||
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
||||
if m.Config.Renderer != "" {
|
||||
rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
|
||||
rendererName := resolveRendererName(m)
|
||||
rendered, err := renderers.RenderWithRenderer(rendererName, msgs, tools, think)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -13,6 +13,14 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func testConfigWithRenderer(renderer string) model.ConfigV2 {
|
||||
return model.ConfigV2{Renderer: renderer}
|
||||
}
|
||||
|
||||
func testConfigWithRendererAndType(renderer, modelType string) model.ConfigV2 {
|
||||
return model.ConfigV2{Renderer: renderer, ModelType: modelType}
|
||||
}
|
||||
|
||||
func TestChatPrompt(t *testing.T) {
|
||||
type expect struct {
|
||||
prompt string
|
||||
@@ -397,3 +405,43 @@ func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
|
||||
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) {
|
||||
msgs := []api.Message{{Role: "user", Content: "Hello"}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model Model
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "small from name",
|
||||
model: Model{
|
||||
Name: "gemma4:e4b",
|
||||
ShortName: "gemma4:e4b",
|
||||
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||
},
|
||||
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "large from model type",
|
||||
model: Model{
|
||||
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
|
||||
},
|
||||
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := renderPrompt(&tt.model, msgs, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||
t.Fatalf("rendered prompt mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
110
server/renderer_resolution.go
Normal file
110
server/renderer_resolution.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
const (
|
||||
gemma4RendererLegacy = "gemma4"
|
||||
gemma4RendererSmall = "gemma4-small"
|
||||
gemma4RendererLarge = "gemma4-large"
|
||||
|
||||
// Gemma 4 small templates cover the e2b/e4b family, while 26b/31b use the
|
||||
// large template. Default to the small prompt unless the model is clearly in
|
||||
// the large range.
|
||||
gemma4LargeMinParameterCount = 16_000_000_000
|
||||
)
|
||||
|
||||
func resolveRendererName(m *Model) string {
|
||||
if m == nil || m.Config.Renderer == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch m.Config.Renderer {
|
||||
case gemma4RendererLegacy:
|
||||
return resolveGemma4Renderer(m)
|
||||
default:
|
||||
return m.Config.Renderer
|
||||
}
|
||||
}
|
||||
|
||||
func resolveGemma4Renderer(m *Model) string {
|
||||
if m == nil || m.Config.Renderer != gemma4RendererLegacy {
|
||||
if m == nil {
|
||||
return gemma4RendererLegacy
|
||||
}
|
||||
return m.Config.Renderer
|
||||
}
|
||||
|
||||
if renderer, ok := gemma4RendererFromName(m.ShortName); ok {
|
||||
return renderer
|
||||
}
|
||||
|
||||
if renderer, ok := gemma4RendererFromName(m.Name); ok {
|
||||
return renderer
|
||||
}
|
||||
|
||||
if parameterCount, ok := parseHumanParameterCount(m.Config.ModelType); ok {
|
||||
return gemma4RendererForParameterCount(parameterCount)
|
||||
}
|
||||
|
||||
return gemma4RendererSmall
|
||||
}
|
||||
|
||||
func gemma4RendererForParameterCount(parameterCount uint64) string {
|
||||
if parameterCount >= gemma4LargeMinParameterCount {
|
||||
return gemma4RendererLarge
|
||||
}
|
||||
|
||||
return gemma4RendererSmall
|
||||
}
|
||||
|
||||
func gemma4RendererFromName(name string) (string, bool) {
|
||||
lower := strings.ToLower(name)
|
||||
switch {
|
||||
case strings.Contains(lower, "e2b"), strings.Contains(lower, "e4b"):
|
||||
return gemma4RendererSmall, true
|
||||
case strings.Contains(lower, "26b"), strings.Contains(lower, "31b"):
|
||||
return gemma4RendererLarge, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func parseHumanParameterCount(s string) (uint64, bool) {
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
unit := strings.ToUpper(s[len(s)-1:])
|
||||
var multiplier float64
|
||||
switch unit {
|
||||
case "B":
|
||||
multiplier = float64(format.Billion)
|
||||
case "M":
|
||||
multiplier = float64(format.Million)
|
||||
case "K":
|
||||
multiplier = float64(format.Thousand)
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
|
||||
value, err := strconv.ParseFloat(s[:len(s)-1], 64)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return uint64(value * multiplier), true
|
||||
}
|
||||
|
||||
func isGemma4Renderer(renderer string) bool {
|
||||
switch renderer {
|
||||
case gemma4RendererLegacy, gemma4RendererSmall, gemma4RendererLarge:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -928,6 +928,59 @@ func TestCreateDetectTemplate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateGemma4KeepsDynamicRendererAlias(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
var s Server
|
||||
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "gemma4",
|
||||
"general.parameter_count": uint64(25_200_000_000),
|
||||
}, nil)
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
mf, err := manifest.ParseNamedManifest(model.ParseName("test"))
|
||||
if err != nil {
|
||||
t.Fatalf("parse manifest: %v", err)
|
||||
}
|
||||
if mf.Config.Digest == "" {
|
||||
t.Fatalf("unexpected empty config digest for manifest")
|
||||
}
|
||||
|
||||
configPath, err := manifest.BlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
t.Fatalf("config blob path: %v", err)
|
||||
}
|
||||
|
||||
cfgFile, err := os.Open(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("open config blob: %v", err)
|
||||
}
|
||||
defer cfgFile.Close()
|
||||
|
||||
var cfg model.ConfigV2
|
||||
if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
|
||||
t.Fatalf("decode config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Renderer != gemma4RendererLegacy {
|
||||
t.Fatalf("expected renderer %q, got %q", gemma4RendererLegacy, cfg.Renderer)
|
||||
}
|
||||
if cfg.Parser != "gemma4" {
|
||||
t.Fatalf("expected parser %q, got %q", "gemma4", cfg.Parser)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectModelTypeFromFiles(t *testing.T) {
|
||||
t.Run("gguf file", func(t *testing.T) {
|
||||
_, digest := createBinFile(t, nil, nil)
|
||||
|
||||
@@ -191,6 +191,10 @@ func inferSafetensorsCapabilities(modelDir string) []string {
|
||||
capabilities = append(capabilities, "vision")
|
||||
}
|
||||
|
||||
if supportsAudio(modelDir) {
|
||||
capabilities = append(capabilities, "audio")
|
||||
}
|
||||
|
||||
if supportsThinking(modelDir) {
|
||||
capabilities = append(capabilities, "thinking")
|
||||
}
|
||||
@@ -496,32 +500,38 @@ func supportsThinking(modelDir string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// supportsVision checks if the model supports image input based on its architecture.
|
||||
// Qwen3.5 multimodal checkpoints are published as ConditionalGeneration architectures.
|
||||
// supportsVision checks if the model has a vision encoder by looking for
|
||||
// vision_config in config.json.
|
||||
func supportsVision(modelDir string) bool {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
VisionConfig *map[string]any `json:"vision_config"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "qwen3") && strings.Contains(archLower, "conditionalgeneration") {
|
||||
return true
|
||||
}
|
||||
return cfg.VisionConfig != nil
|
||||
}
|
||||
|
||||
func supportsAudio(modelDir string) bool {
|
||||
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
return strings.Contains(typeLower, "qwen3") && strings.Contains(typeLower, "conditionalgeneration")
|
||||
var cfg struct {
|
||||
AudioConfig *map[string]any `json:"audio_config"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return cfg.AudioConfig != nil
|
||||
}
|
||||
|
||||
// getParserName returns the parser name for a model based on its architecture.
|
||||
@@ -550,6 +560,9 @@ func getParserName(modelDir string) string {
|
||||
if strings.Contains(archLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3"
|
||||
}
|
||||
@@ -564,6 +577,9 @@ func getParserName(modelDir string) string {
|
||||
if strings.Contains(typeLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3"
|
||||
}
|
||||
@@ -592,6 +608,9 @@ func getRendererName(modelDir string) string {
|
||||
// Check architectures for known renderers
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
@@ -606,6 +625,9 @@ func getRendererName(modelDir string) string {
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
if strings.Contains(typeLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
|
||||
@@ -311,10 +311,30 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
|
||||
name: "qwen3.5 multimodal model",
|
||||
configJSON: `{
|
||||
"architectures": ["Qwen3_5ForConditionalGeneration"],
|
||||
"model_type": "qwen3"
|
||||
"model_type": "qwen3",
|
||||
"vision_config": {"hidden_size": 1024}
|
||||
}`,
|
||||
want: []string{"completion", "vision", "thinking"},
|
||||
},
|
||||
{
|
||||
name: "model with audio config",
|
||||
configJSON: `{
|
||||
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||
"model_type": "gemma4",
|
||||
"vision_config": {"hidden_size": 1024},
|
||||
"audio_config": {"num_mel_bins": 128}
|
||||
}`,
|
||||
want: []string{"completion", "vision", "audio"},
|
||||
},
|
||||
{
|
||||
name: "model with audio but no vision",
|
||||
configJSON: `{
|
||||
"architectures": ["SomeAudioModel"],
|
||||
"model_type": "other",
|
||||
"audio_config": {"num_mel_bins": 128}
|
||||
}`,
|
||||
want: []string{"completion", "audio"},
|
||||
},
|
||||
{
|
||||
name: "non-qwen conditional generation model",
|
||||
configJSON: `{
|
||||
@@ -339,6 +359,74 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePerExpertInputs(t *testing.T) {
|
||||
makeInput := func(name, quantize string) create.PackedTensorInput {
|
||||
return create.PackedTensorInput{Name: name, Quantize: quantize}
|
||||
}
|
||||
|
||||
t.Run("uniform quant across projections", func(t *testing.T) {
|
||||
inputs := []create.PackedTensorInput{
|
||||
makeInput("layer.moe.experts.0.gate_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.1.gate_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.0.down_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.1.down_proj.weight", "int4"),
|
||||
}
|
||||
groups, projQ := parsePerExpertInputs("layer.moe.experts", inputs)
|
||||
if groups == nil {
|
||||
t.Fatal("expected non-nil groups")
|
||||
}
|
||||
if len(groups) != 2 {
|
||||
t.Fatalf("expected 2 projection groups, got %d", len(groups))
|
||||
}
|
||||
if projQ["gate_proj.weight"] != "int4" {
|
||||
t.Errorf("gate_proj quant = %q, want int4", projQ["gate_proj.weight"])
|
||||
}
|
||||
if projQ["down_proj.weight"] != "int4" {
|
||||
t.Errorf("down_proj quant = %q, want int4", projQ["down_proj.weight"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mixed quant across projections", func(t *testing.T) {
|
||||
inputs := []create.PackedTensorInput{
|
||||
makeInput("layer.moe.experts.0.gate_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.1.gate_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.0.down_proj.weight", "int8"),
|
||||
makeInput("layer.moe.experts.1.down_proj.weight", "int8"),
|
||||
}
|
||||
groups, projQ := parsePerExpertInputs("layer.moe.experts", inputs)
|
||||
if groups == nil {
|
||||
t.Fatal("expected non-nil groups for mixed cross-projection quant")
|
||||
}
|
||||
if projQ["gate_proj.weight"] != "int4" {
|
||||
t.Errorf("gate_proj quant = %q, want int4", projQ["gate_proj.weight"])
|
||||
}
|
||||
if projQ["down_proj.weight"] != "int8" {
|
||||
t.Errorf("down_proj quant = %q, want int8", projQ["down_proj.weight"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mixed quant within same projection rejected", func(t *testing.T) {
|
||||
inputs := []create.PackedTensorInput{
|
||||
makeInput("layer.moe.experts.0.down_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.1.down_proj.weight", "int8"),
|
||||
}
|
||||
groups, _ := parsePerExpertInputs("layer.moe.experts", inputs)
|
||||
if groups != nil {
|
||||
t.Fatal("expected nil for mixed quant within same projection")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-experts group rejected", func(t *testing.T) {
|
||||
inputs := []create.PackedTensorInput{
|
||||
makeInput("layer.mlp.gate_proj.weight", "int4"),
|
||||
}
|
||||
groups, _ := parsePerExpertInputs("layer.mlp", inputs)
|
||||
if groups != nil {
|
||||
t.Fatal("expected nil for non-experts group")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestQuantizeSupported(t *testing.T) {
|
||||
// This just verifies the function exists and returns a boolean
|
||||
// The actual value depends on build tags (mlx vs non-mlx)
|
||||
|
||||
@@ -97,6 +97,20 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
|
||||
groupSize, bits, mode := model.QuantizationParams(quantize)
|
||||
qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
|
||||
|
||||
// Validate quantization produced non-empty output. MLX quantize may return
|
||||
// empty arrays for unsupported mode/bits combinations without raising an error.
|
||||
mlx.Eval(qweight, scales)
|
||||
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
|
||||
name, quantize, groupSize, bits, mode)
|
||||
}
|
||||
if len(scales.Dims()) == 0 || scales.Dims()[0] == 0 {
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty scales for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
|
||||
name, quantize, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
qweight = mlx.Contiguous(qweight, false)
|
||||
scales = mlx.Contiguous(scales, false)
|
||||
arrays[name] = qweight
|
||||
@@ -174,8 +188,8 @@ func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quanti
|
||||
// Returns the blob bytes.
|
||||
func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
|
||||
// Check if inputs are per-expert tensors that should be stacked into 3D
|
||||
if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
|
||||
return stackAndQuantizeExpertGroup(groupName, projGroups, quantize)
|
||||
if projGroups, projQuantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
|
||||
return stackAndQuantizeExpertGroup(groupName, projGroups, projQuantize)
|
||||
}
|
||||
|
||||
allArrays := make(map[string]*mlx.Array)
|
||||
@@ -224,6 +238,17 @@ func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([
|
||||
mlx.Pin(finalArrays...)
|
||||
pinned = append(pinned, finalArrays...)
|
||||
|
||||
// Record per-tensor quant type so the model can resolve params at load time.
|
||||
if input.Quantize != "" {
|
||||
if groupSize, _, _ := model.QuantizationParams(input.Quantize); groupSize > 0 {
|
||||
if metadata == nil {
|
||||
metadata = make(map[string]string)
|
||||
}
|
||||
metadata[input.Name+".quant_type"] = input.Quantize
|
||||
metadata[input.Name+".group_size"] = strconv.Itoa(groupSize)
|
||||
}
|
||||
}
|
||||
|
||||
if st != nil {
|
||||
st.Free()
|
||||
}
|
||||
@@ -279,57 +304,60 @@ type expertTensorInfo struct {
|
||||
}
|
||||
|
||||
// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
|
||||
// and returns the uniform quantization type shared by all inputs.
|
||||
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D)
|
||||
// or if the inputs have mixed quantization types.
|
||||
// and returns per-projection quantization types. Different projections may use
|
||||
// different quant types (e.g., gate_up=int4, down=int8) but all experts within
|
||||
// a projection must share the same type.
|
||||
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D).
|
||||
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
|
||||
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) {
|
||||
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, map[string]string) {
|
||||
if !strings.HasSuffix(groupName, ".experts") {
|
||||
return nil, ""
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
quantize := inputs[0].Quantize
|
||||
groups := make(map[string][]expertTensorInfo)
|
||||
projQuantize := make(map[string]string) // projection -> quant type
|
||||
for _, input := range inputs {
|
||||
if input.Quantize != quantize {
|
||||
return nil, "" // mixed quantization types
|
||||
}
|
||||
suffix := strings.TrimPrefix(input.Name, groupName)
|
||||
m := perExpertSuffix.FindStringSubmatch(suffix)
|
||||
if m == nil {
|
||||
return nil, "" // not a per-expert pattern
|
||||
return nil, nil // not a per-expert pattern
|
||||
}
|
||||
index, err := strconv.Atoi(m[1])
|
||||
if err != nil {
|
||||
return nil, ""
|
||||
return nil, nil
|
||||
}
|
||||
groups[m[2]] = append(groups[m[2]], expertTensorInfo{
|
||||
proj := m[2]
|
||||
if existing, ok := projQuantize[proj]; ok {
|
||||
if input.Quantize != existing {
|
||||
return nil, nil // mixed quant within same projection
|
||||
}
|
||||
} else {
|
||||
projQuantize[proj] = input.Quantize
|
||||
}
|
||||
groups[proj] = append(groups[proj], expertTensorInfo{
|
||||
index: index,
|
||||
proj: m[2],
|
||||
proj: proj,
|
||||
input: input,
|
||||
})
|
||||
}
|
||||
if len(groups) == 0 {
|
||||
return nil, ""
|
||||
return nil, nil
|
||||
}
|
||||
return groups, quantize
|
||||
return groups, projQuantize
|
||||
}
|
||||
|
||||
// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
|
||||
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
|
||||
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) {
|
||||
// projQuantize maps projection name to its quantization type (may differ per projection).
|
||||
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, projQuantize map[string]string) ([]byte, error) {
|
||||
groupBase := strings.TrimSuffix(groupName, ".experts")
|
||||
|
||||
allArrays := make(map[string]*mlx.Array)
|
||||
var pinned []*mlx.Array
|
||||
|
||||
var metadata map[string]string
|
||||
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" {
|
||||
metadata = map[string]string{
|
||||
"quant_type": quantize,
|
||||
"group_size": strconv.Itoa(groupSize),
|
||||
}
|
||||
}
|
||||
// Build metadata: if all projections use the same quant type, set global metadata.
|
||||
// Otherwise record per-tensor quant info.
|
||||
metadata := make(map[string]string)
|
||||
|
||||
// Sort projection names for deterministic output
|
||||
projNames := make([]string, 0, len(projGroups))
|
||||
@@ -339,7 +367,11 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
||||
sort.Strings(projNames)
|
||||
|
||||
cleanup := func() {
|
||||
mlx.Unpin(pinned...)
|
||||
for _, p := range pinned {
|
||||
if p != nil {
|
||||
mlx.Unpin(p)
|
||||
}
|
||||
}
|
||||
mlx.Sweep()
|
||||
}
|
||||
|
||||
@@ -382,11 +414,27 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
||||
mlx.Pin(stacked)
|
||||
pinned = append(pinned, stacked)
|
||||
|
||||
// Free individual decoded arrays
|
||||
// Free individual decoded arrays (remove from pinned to avoid double-unpin in cleanup)
|
||||
for i, p := range pinned {
|
||||
for _, d := range decoded {
|
||||
if p == d {
|
||||
pinned[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
mlx.Unpin(decoded...)
|
||||
mlx.Sweep()
|
||||
|
||||
stackedName := groupBase + ".switch_mlp." + proj
|
||||
quantize := projQuantize[proj]
|
||||
|
||||
// Record per-tensor quant metadata so the model can resolve params at load time.
|
||||
if quantize != "" {
|
||||
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 {
|
||||
metadata[stackedName+".quant_type"] = quantize
|
||||
metadata[stackedName+".group_size"] = strconv.Itoa(groupSize)
|
||||
}
|
||||
}
|
||||
|
||||
// Quantize the stacked tensor
|
||||
if quantize != "" {
|
||||
@@ -394,6 +442,14 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
||||
|
||||
qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)
|
||||
|
||||
// Validate quantization produced non-empty output.
|
||||
mlx.Eval(qweight, scales)
|
||||
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
|
||||
cleanup()
|
||||
return nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
|
||||
stackedName, quantize, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
qweight = mlx.Contiguous(qweight, false)
|
||||
scales = mlx.Contiguous(scales, false)
|
||||
allArrays[stackedName] = qweight
|
||||
@@ -409,12 +465,19 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
||||
mlx.Pin(toEval...)
|
||||
pinned = append(pinned, toEval...)
|
||||
|
||||
// Free stacked source array
|
||||
// Free stacked source array (remove from pinned to avoid double-unpin in cleanup)
|
||||
for i, p := range pinned {
|
||||
if p == stacked {
|
||||
pinned[i] = nil
|
||||
}
|
||||
}
|
||||
mlx.Unpin(stacked)
|
||||
mlx.Sweep()
|
||||
} else {
|
||||
stacked = mlx.Contiguous(stacked, false)
|
||||
mlx.Eval(stacked)
|
||||
mlx.Pin(stacked)
|
||||
pinned = append(pinned, stacked)
|
||||
allArrays[stackedName] = stacked
|
||||
}
|
||||
}
|
||||
@@ -529,7 +592,7 @@ func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
|
||||
padBottom := blockRows*scaleShape[0] - rows
|
||||
padSide := blockCols*scaleShape[1] - cols
|
||||
if padBottom > 0 || padSide > 0 {
|
||||
decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
|
||||
decoded = mlx.PadConstant(decoded, []int{0, 1}, []int{0, 0}, []int{padBottom, padSide})
|
||||
}
|
||||
|
||||
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
|
||||
|
||||
@@ -246,6 +246,11 @@ func ShouldQuantize(name, component string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip audio encoder tensors (highly sensitive to quantization)
|
||||
if strings.Contains(name, "audio_tower") || strings.Contains(name, "embed_audio") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip embeddings
|
||||
if strings.Contains(name, "embed") {
|
||||
return false
|
||||
@@ -291,6 +296,22 @@ func normalizeQuantType(quantize string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// isAligned checks if a tensor's last dimension is divisible by the
|
||||
// group size required for the given quantization type.
|
||||
func isAligned(shape []int32, quantType string) bool {
|
||||
if len(shape) == 0 {
|
||||
return false
|
||||
}
|
||||
groupSize := int32(32)
|
||||
switch normalizeQuantType(quantType) {
|
||||
case "nvfp4":
|
||||
groupSize = 16
|
||||
case "int4", "int8":
|
||||
groupSize = 64
|
||||
}
|
||||
return shape[len(shape)-1]%groupSize == 0
|
||||
}
|
||||
|
||||
func isStackedExpertWeight(name string) bool {
|
||||
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
|
||||
// or "...proj" (pre-stacked packed tensor).
|
||||
@@ -300,16 +321,16 @@ func isStackedExpertWeight(name string) bool {
|
||||
|
||||
return strings.Contains(name, ".mlp.switch_mlp.") ||
|
||||
strings.Contains(name, ".mlp.experts.") ||
|
||||
strings.Contains(name, ".mlp.shared_experts.")
|
||||
strings.Contains(name, ".mlp.shared_experts.") ||
|
||||
strings.Contains(name, ".moe.experts.")
|
||||
}
|
||||
|
||||
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
||||
// Returns "" if the tensor should not be quantized.
|
||||
// This implements mixed-precision quantization:
|
||||
// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
|
||||
// - Output projection, gate/up weights: int4 (less sensitive)
|
||||
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||
// - v_proj, k_proj, down_proj: promoted to INT8 when base is INT4
|
||||
// - Norms, embeddings, biases, routing gates: no quantization
|
||||
// - All other eligible weights: use requested quantization type
|
||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
stackedExpert := isStackedExpertWeight(name)
|
||||
|
||||
@@ -336,60 +357,35 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
// Normalize quantization type to canonical form
|
||||
quantNorm := normalizeQuantType(quantize)
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size
|
||||
// nvfp4: 16, mxfp4/mxfp8: 32, int4/int8: 64
|
||||
groupSize := int32(32)
|
||||
switch quantNorm {
|
||||
case "nvfp4":
|
||||
groupSize = 16
|
||||
case "int4", "int8":
|
||||
groupSize = 64
|
||||
}
|
||||
if shape[len(shape)-1]%groupSize != 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Skip routing gate weights (should stay high precision)
|
||||
// In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
|
||||
if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
|
||||
return ""
|
||||
}
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size.
|
||||
if !isAligned(shape, quantNorm) {
|
||||
return ""
|
||||
}
|
||||
|
||||
// For non-affine modes, use the same quantization for all eligible tensors.
|
||||
if quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8" {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// Attention MLA weights - keep unquantized (bf16)
|
||||
// These are highly sensitive: errors accumulate in the KV cache over time
|
||||
// q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj
|
||||
if strings.Contains(name, "q_a_proj") ||
|
||||
strings.Contains(name, "q_b_proj") ||
|
||||
strings.Contains(name, "kv_a_proj") ||
|
||||
strings.Contains(name, "kv_b_proj") {
|
||||
return "" // No quantization - keep bf16
|
||||
// Value projection weights directly determine attention output quality.
|
||||
// Down projection weights feed directly into the residual stream where
|
||||
// errors accumulate across layers. Both benefit from higher precision.
|
||||
// Promote to INT8 when base is INT4 (same affine mode, compatible with
|
||||
// GatherQMM for MoE expert tensors).
|
||||
if quantNorm == "int4" {
|
||||
if strings.Contains(name, ".v_proj") || strings.Contains(name, ".k_proj") || strings.Contains(name, "down_proj") {
|
||||
if isAligned(shape, "int8") {
|
||||
return "int8"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Down projection weights - use INT8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
|
||||
// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
|
||||
if strings.Contains(name, "down_proj") {
|
||||
return "int8"
|
||||
}
|
||||
|
||||
// Output projection, gate/up weights - use requested quantization (INT4)
|
||||
// o_proj, gate_proj, up_proj
|
||||
if strings.Contains(name, "o_proj") ||
|
||||
strings.Contains(name, "gate_proj") ||
|
||||
strings.Contains(name, "up_proj") {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// LM head - use requested quantization
|
||||
if strings.Contains(name, "lm_head") {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// Default to requested quantization for other weights
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
@@ -411,6 +407,7 @@ func ExpertGroupPrefix(tensorName string) string {
|
||||
".mlp.experts.",
|
||||
".mlp.shared_experts.",
|
||||
".mlp.switch_mlp.",
|
||||
".moe.experts.",
|
||||
} {
|
||||
idx := strings.Index(tensorName, marker)
|
||||
if idx == -1 {
|
||||
@@ -637,6 +634,8 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{
|
||||
"Qwen3_5MoeForConditionalGeneration": newQwen35ImportTransform,
|
||||
"Qwen3NextMoeForCausalLM": newQwen35ImportTransform,
|
||||
"Qwen3NextMoeForConditionalGeneration": newQwen35ImportTransform,
|
||||
"Gemma4ForCausalLM": newGemma4ImportTransform,
|
||||
"Gemma4ForConditionalGeneration": newGemma4ImportTransform,
|
||||
}
|
||||
|
||||
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
|
||||
|
||||
@@ -1169,6 +1169,11 @@ func TestShouldQuantize(t *testing.T) {
|
||||
{"ln prefix", "ln_1.weight", "", false},
|
||||
{"layernorm in name", "input_layernorm.weight", "", false},
|
||||
|
||||
// Audio encoder tensors should not be quantized
|
||||
{"audio tower weight", "model.audio_tower.layers.0.weight", "", false},
|
||||
{"audio tower norm", "model.audio_tower.norm.weight", "", false},
|
||||
{"embed audio weight", "embed_audio.weight", "", false},
|
||||
|
||||
// Biases should not be quantized
|
||||
{"bias tensor", "attention.bias", "", false},
|
||||
{"proj bias", "o_proj.bias", "", false},
|
||||
@@ -1262,6 +1267,11 @@ func TestExpertGroupPrefix(t *testing.T) {
|
||||
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
|
||||
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
|
||||
|
||||
// MoE expert tensors (Gemma-style .moe.experts.)
|
||||
{"model.layers.0.moe.experts.0.gate_proj.weight", "model.layers.0.moe.experts"},
|
||||
{"model.layers.1.moe.experts.42.down_proj.weight", "model.layers.1.moe.experts"},
|
||||
{"language_model.model.layers.2.moe.experts.127.up_proj.weight", "language_model.model.layers.2.moe.experts"},
|
||||
|
||||
// Expert tensors with language_model prefix should also match
|
||||
{"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"},
|
||||
{"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"},
|
||||
@@ -1369,6 +1379,94 @@ func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAligned(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
shape []int32
|
||||
quantType string
|
||||
want bool
|
||||
}{
|
||||
// int4/int8: group_size=64
|
||||
{"int4 aligned", []int32{1024, 4096}, "int4", true},
|
||||
{"int4 unaligned", []int32{1024, 48}, "int4", false},
|
||||
{"int8 aligned", []int32{1024, 128}, "int8", true},
|
||||
{"int8 unaligned", []int32{1024, 32}, "int8", false},
|
||||
|
||||
// nvfp4: group_size=16
|
||||
{"nvfp4 aligned", []int32{1024, 48}, "nvfp4", true},
|
||||
{"nvfp4 unaligned", []int32{1024, 24}, "nvfp4", false},
|
||||
{"nvfp4 aligned 16", []int32{1024, 16}, "nvfp4", true},
|
||||
|
||||
// mxfp4/mxfp8: group_size=32
|
||||
{"mxfp4 aligned", []int32{1024, 64}, "mxfp4", true},
|
||||
{"mxfp4 unaligned", []int32{1024, 48}, "mxfp4", false},
|
||||
{"mxfp8 aligned", []int32{1024, 32}, "mxfp8", true},
|
||||
{"mxfp8 unaligned", []int32{1024, 24}, "mxfp8", false},
|
||||
|
||||
// Edge cases
|
||||
{"empty shape", []int32{}, "int4", false},
|
||||
{"1D tensor", []int32{4096}, "int4", true},
|
||||
{"3D stacked expert", []int32{128, 4096, 2816}, "int4", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isAligned(tt.shape, tt.quantType)
|
||||
if got != tt.want {
|
||||
t.Errorf("isAligned(%v, %q) = %v, want %v", tt.shape, tt.quantType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetTensorQuantization_MixedPrecisionPromotion verifies the uniform
// mixed-precision rules in GetTensorQuantization: int4 is promoted to int8
// for quantization-sensitive tensors (v_proj, k_proj, down_proj), the
// nvfp4/mxfp4/mxfp8 families stay uniform, int8 is never promoted further,
// and unaligned shapes fall back to source precision (empty string = bf16).
func TestGetTensorQuantization_MixedPrecisionPromotion(t *testing.T) {
	aligned := []int32{4096, 4096} // divisible by 64

	tests := []struct {
		name     string
		tensor   string
		shape    []int32
		quantize string
		want     string
	}{
		// int4 → int8 promotion for sensitive tensors
		{"v_proj int4 promoted", "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"},
		{"k_proj int4 promoted", "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
		{"down_proj int4 promoted", "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"},

		// Non-sensitive int4 tensors stay int4
		{"q_proj int4 stays", "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"},
		{"o_proj int4 stays", "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"},
		{"gate_proj int4 stays", "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"},
		{"up_proj int4 stays", "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"},

		// nvfp4/mxfp4/mxfp8: no promotion (uniform quantization)
		{"v_proj nvfp4 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"},
		{"down_proj mxfp4 uniform", "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"},
		{"v_proj mxfp8 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"},

		// int8: already 8-bit, no promotion
		{"v_proj int8 stays", "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"},

		// Expert tensors: down_proj also promoted for int4
		{"expert down_proj int4", "model.layers.0.mlp.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"},
		{"moe expert down_proj int4", "model.layers.0.moe.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"},

		// Unaligned: falls back to bf16 (empty string)
		{"v_proj int4 unaligned", "model.layers.0.self_attn.v_proj.weight", []int32{1024, 48}, "int4", ""},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := GetTensorQuantization(tt.tensor, tt.shape, tt.quantize)
			if got != tt.want {
				t.Errorf("GetTensorQuantization(%q, %v, %q) = %q, want %q",
					tt.tensor, tt.shape, tt.quantize, got, tt.want)
			}
		})
	}
}
|
||||
|
||||
func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
|
||||
277
x/create/gemma4.go
Normal file
277
x/create/gemma4.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/safetensors"
|
||||
)
|
||||
|
||||
// gemma4ImportTransform applies Gemma 4 specific tensor transforms and
// per-tensor quantization decisions during safetensors import.
type gemma4ImportTransform struct {
	numLayers  int // transformer layer count; 0 disables the layer-position heuristic
	numExperts int // MoE expert count; 8-expert models always promote k_proj/v_proj
}
|
||||
|
||||
// gemma4Config is a minimal subset of the Gemma 4 config.json used for quant decisions.
type gemma4Config struct {
	NumHiddenLayers int `json:"num_hidden_layers"` // flat (text-only) config layout
	NumExperts      int `json:"num_experts"`       // flat (text-only) config layout
	// TextConfig mirrors the same fields for multimodal configs that nest
	// the language-model settings under "text_config".
	TextConfig struct {
		NumHiddenLayers int `json:"num_hidden_layers"`
		NumExperts      int `json:"num_experts"`
	} `json:"text_config"`
}
|
||||
|
||||
func newGemma4ImportTransform(modelDir string, _ sourceModelConfig) (tensorImportTransform, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
|
||||
if err != nil {
|
||||
return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic
|
||||
}
|
||||
var cfg gemma4Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic
|
||||
}
|
||||
|
||||
numLayers := cfg.NumHiddenLayers
|
||||
if numLayers == 0 {
|
||||
numLayers = cfg.TextConfig.NumHiddenLayers
|
||||
}
|
||||
numExperts := cfg.NumExperts
|
||||
if numExperts == 0 {
|
||||
numExperts = cfg.TextConfig.NumExperts
|
||||
}
|
||||
|
||||
return gemma4ImportTransform{numLayers: numLayers, numExperts: numExperts}, nil
|
||||
}
|
||||
|
||||
func (t gemma4ImportTransform) skipTensor(name string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// layerIndexRe extracts the layer index from tensor names like
// "model.language_model.layers.5.self_attn.v_proj.weight" or
// "model.language_model.layers.5.moe.experts.42.down_proj.weight".
// Compiled once at package scope so quantizationType can reuse it per tensor.
var layerIndexRe = regexp.MustCompile(`\.layers\.(\d+)\.`)
|
||||
|
||||
// useMoreBits reports whether quantization-sensitive tensors at layerIdx
// should use higher precision: true for the first and last 1/8 of layers
// (which handle input grounding and final output refinement), plus every
// 3rd layer in between to limit error accumulation through the residual
// stream. Integer division is intentional; boundaries match 7*n/8 exactly.
func useMoreBits(layerIdx, numLayers int) bool {
	eighth := numLayers / 8
	switch {
	case layerIdx < eighth:
		return true // first eighth
	case layerIdx >= 7*numLayers/8:
		return true // last eighth
	default:
		return (layerIdx-eighth)%3 == 2 // periodic middle layers
	}
}
|
||||
|
||||
// quantizationType selects the quantization type for one tensor.
// name is the HF tensor name, shape its dimensions, and quantize the
// user-requested base type (e.g. "int4", "nvfp4"). Returning "" keeps the
// tensor in source precision (bf16). Branch order matters: embedding and
// router rules fire before the layer-position heuristic, which in turn
// bypasses GetTensorQuantization's uniform promotion when it applies.
func (t gemma4ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
	quantNorm := normalizeQuantType(quantize)

	// Embedding: quantize to 8-bit variant for bandwidth efficiency.
	// The embedding serves double duty: lookup (via QuantizedEmbedding) and
	// lm_head projection (via AsLinear). Using 8-bit matches GGUF Q6_K quality
	// (strictly higher at 8 bpw vs 6.5 bpw) while saving ~2.8 GB on 31B vs bf16.
	if isEmbedTokensWeight(name) {
		switch quantNorm {
		case "int4", "int8":
			if isAligned(shape, "int8") {
				return "int8"
			}
		case "mxfp4", "nvfp4", "mxfp8":
			if isAligned(shape, "mxfp8") {
				return "mxfp8"
			}
		}
		// 8-bit alignment failed: fall back to the requested type if that
		// aligns, otherwise keep source precision.
		if isAligned(shape, quantNorm) {
			return quantNorm
		}
		return ""
	}

	// MoE router logits choose the top-k expert set. Quantization noise here
	// can flip expert selection, after which downstream activations diverge
	// sharply. The tensor is small, so leave it in source precision.
	if isGemma4RouterProjection(name) {
		return ""
	}

	// Mixed-precision quantization: sensitive tensors get higher precision.
	//
	// Value projections (v_proj) directly determine attention output quality.
	// Down projections (down_proj) are the final MLP output and errors there
	// propagate directly to the residual stream. Both benefit from higher
	// precision at early layers, late layers, and periodically in between
	// (the "useMoreBits" heuristic).
	//
	// For int4: promote → int8 (same affine family, GatherQMM compatible).
	// For mxfp4/nvfp4: promote → mxfp8. MLX quantized_matmul handles mixed
	// nvfp4+mxfp8 modes within the same model — each tensor carries its own
	// quant metadata and the kernel dispatches per-tensor.
	if t.numLayers > 0 {
		// layerIdx stays -1 when the name carries no ".layers.N." segment,
		// which disables the layer-position heuristic below.
		layerIdx := -1
		if m := layerIndexRe.FindStringSubmatch(name); m != nil {
			if idx, err := strconv.Atoi(m[1]); err == nil {
				layerIdx = idx
			}
		}

		// Determine promotion target for sensitive tensors.
		//   "int8"  = int4 base → int8 (affine family)
		//   "mxfp8" = mxfp4/nvfp4 base → mxfp8
		//   ""      = no promotion (int8/mxfp8, already 8-bit)
		promote := ""
		switch quantNorm {
		case "int4":
			promote = "int8"
		case "mxfp4", "nvfp4":
			promote = "mxfp8"
		}

		// Only apply to language model tensors — audio/vision tower tensors
		// should pass through to GetTensorQuantization which skips them.
		isModelTensor := !strings.Contains(name, "audio_tower") &&
			!strings.Contains(name, "vision_tower")
		isSensitive := isModelTensor &&
			(strings.Contains(name, ".v_proj") || strings.Contains(name, "down_proj"))
		isSensitiveK := isModelTensor && strings.Contains(name, "k_proj")

		if promote != "" && (isSensitive || isSensitiveK) {
			shouldPromote := false

			// 8-expert models: v_proj and k_proj share very few KV heads,
			// so quantization errors are amplified. Always promote.
			if t.numExperts == 8 && (strings.Contains(name, ".v_proj") || isSensitiveK) {
				shouldPromote = true
			}

			// Layer-position heuristic for v_proj and down_proj.
			if isSensitive && layerIdx >= 0 && useMoreBits(layerIdx, t.numLayers) {
				shouldPromote = true
			}

			if shouldPromote && isAligned(shape, promote) {
				return promote
			}

			// Sensitive tensor at a non-promoted layer: use base quant type.
			// Return directly to bypass GetTensorQuantization's uniform
			// promotion — the layer-position heuristic is authoritative here.
			if !isAligned(shape, quantNorm) {
				return ""
			}
			return quantNorm
		}
	}

	return GetTensorQuantization(name, shape, quantize)
}
|
||||
|
||||
// isEmbedTokensWeight reports whether name is the main token embedding
// weight, excluding Gemma's per-layer embedding variants.
func isEmbedTokensWeight(name string) bool {
	if strings.Contains(name, "per_layer") {
		return false
	}
	return strings.HasSuffix(name, "embed_tokens.weight")
}
|
||||
|
||||
// isGemma4RouterProjection reports whether name is a language-model MoE
// router projection weight; audio/vision tower routers are excluded.
func isGemma4RouterProjection(name string) bool {
	if !strings.HasSuffix(name, ".router.proj.weight") {
		return false
	}
	return !strings.Contains(name, "audio_tower") &&
		!strings.Contains(name, "vision_tower")
}
|
||||
|
||||
func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
|
||||
if td == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Split pre-stacked MoE expert tensors [N, out, in] into per-expert
|
||||
// [out, in] tensors so they go through the standard expert packing and
|
||||
// quantization flow (ExpertGroupPrefix matching, per-expert quantize).
|
||||
if isGemma4StackedMoETensor(td.Name, td.Shape) {
|
||||
return splitStackedMoETensor(td)
|
||||
}
|
||||
|
||||
return []*safetensors.TensorData{td}, nil
|
||||
}
|
||||
|
||||
// isGemma4StackedMoETensor reports whether name/shape describe a pre-stacked
// MoE expert weight. Gemma 4 HF weights come in two layouts depending on
// the model version:
//   - Older: model.language_model.layers.N.moe.{gate,up,down}_proj [experts, dim1, dim2]
//   - Newer: model.language_model.layers.N.experts.{gate_up,down}_proj [experts, dim1, dim2]
//
// The newer layout has gate+up already fused. We keep it fused (no splitting)
// so the tensors flow through the standard expert packing and quantization path.
func isGemma4StackedMoETensor(name string, shape []int32) bool {
	if len(shape) != 3 {
		return false // only stacked [experts, dim1, dim2] tensors qualify
	}
	inExpertGroup := strings.Contains(name, ".moe.") || strings.Contains(name, ".experts.")
	if !inExpertGroup {
		return false
	}
	return strings.HasSuffix(name, "_proj") || strings.HasSuffix(name, "_proj.weight")
}
|
||||
|
||||
// splitStackedMoETensor splits a [N, out, in] stacked expert tensor into
|
||||
// N individual [out, in] tensors named with the per-expert convention that
|
||||
// ExpertGroupPrefix expects: prefix.moe.experts.{E}.{proj}.weight
|
||||
func splitStackedMoETensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
|
||||
raw, err := io.ReadAll(td.Reader())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err)
|
||||
}
|
||||
|
||||
numExperts := int(td.Shape[0])
|
||||
rows := int(td.Shape[1]) // out_features in HF layout
|
||||
cols := int(td.Shape[2]) // in_features in HF layout
|
||||
|
||||
elemSize, err := DTypeSize(td.Dtype)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dtype size for %s: %w", td.Dtype, err)
|
||||
}
|
||||
|
||||
perExpertBytes := rows * cols * elemSize
|
||||
if len(raw) != numExperts*perExpertBytes {
|
||||
return nil, fmt.Errorf("tensor %s: raw byte length %d does not match shape %v and dtype %s",
|
||||
td.Name, len(raw), td.Shape, td.Dtype)
|
||||
}
|
||||
|
||||
// Determine the per-expert name pattern.
|
||||
// Two source layouts:
|
||||
// Old: model.language_model.layers.N.moe.gate_proj
|
||||
// -> model.language_model.layers.N.moe.experts.E.gate_proj.weight
|
||||
// New: model.language_model.layers.N.experts.gate_up_proj
|
||||
// -> model.language_model.layers.N.moe.experts.E.gate_up_proj.weight
|
||||
baseName := td.Name
|
||||
baseName = strings.TrimSuffix(baseName, ".weight")
|
||||
lastDot := strings.LastIndex(baseName, ".")
|
||||
if lastDot < 0 {
|
||||
return nil, fmt.Errorf("tensor %s: unexpected name format", td.Name)
|
||||
}
|
||||
parentPrefix := baseName[:lastDot] // "...layers.N.moe" or "...layers.N.experts"
|
||||
projName := baseName[lastDot+1:] // "gate_proj" or "gate_up_proj"
|
||||
|
||||
// Normalize: if parent already ends with ".experts", use the grandparent + ".moe"
|
||||
// so we get a consistent "layers.N.moe.experts.E" pattern.
|
||||
var moePrefix string
|
||||
if cut, ok := strings.CutSuffix(parentPrefix, ".experts"); ok {
|
||||
moePrefix = cut + ".moe"
|
||||
} else {
|
||||
moePrefix = parentPrefix
|
||||
}
|
||||
|
||||
transposedShape := []int32{td.Shape[1], td.Shape[2]}
|
||||
|
||||
results := make([]*safetensors.TensorData, numExperts)
|
||||
for e := range numExperts {
|
||||
expertName := fmt.Sprintf("%s.experts.%d.%s.weight", moePrefix, e, projName)
|
||||
start := e * perExpertBytes
|
||||
end := start + perExpertBytes
|
||||
results[e] = safetensors.NewTensorDataFromBytes(expertName, td.Dtype, transposedShape, raw[start:end])
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
196
x/create/gemma4_test.go
Normal file
196
x/create/gemma4_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGemma4QuantizationType pins the full per-tensor quantization decision
// table for gemma4ImportTransform: embedding 8-bit promotion, the v_proj /
// down_proj layer-position heuristic, the 8-expert k_proj/v_proj rule,
// router pass-through, and audio/vision tower handling. want == "" means
// the tensor stays in source precision (bf16).
func TestGemma4QuantizationType(t *testing.T) {
	// 26B MoE: 30 layers, 128 experts
	transform26B := gemma4ImportTransform{numLayers: 30, numExperts: 128}
	// 8-expert model (hypothetical)
	transform8E := gemma4ImportTransform{numLayers: 30, numExperts: 8}

	aligned := []int32{2816, 2816} // divisible by 64 (int4/int8 group size) and 16 (nvfp4)

	tests := []struct {
		name      string
		transform gemma4ImportTransform
		tensor    string
		shape     []int32
		quantize  string
		want      string
	}{
		// === embed_tokens: quantize to 8-bit variant (serves as both embed and lm_head) ===
		{"embed_tokens int4", transform26B, "model.embed_tokens.weight", aligned, "int4", "int8"},
		{"embed_tokens nvfp4", transform26B, "model.embed_tokens.weight", aligned, "nvfp4", "mxfp8"},
		{"embed_tokens mxfp4", transform26B, "model.embed_tokens.weight", aligned, "mxfp4", "mxfp8"},
		{"embed_tokens int8", transform26B, "model.embed_tokens.weight", aligned, "int8", "int8"},
		{"embed_tokens mxfp8", transform26B, "model.embed_tokens.weight", aligned, "mxfp8", "mxfp8"},

		// === v_proj: layer-position heuristic for int4/nvfp4 ===
		// Layer 0 is in first 1/8 (30/8=3) → promoted
		{"v_proj int4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"},
		// Layer 4 is NOT in useMoreBits → base quant
		{"v_proj int4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "int4", "int4"},
		// Layer 29 is in last 1/8 → promoted
		{"v_proj int4 last layer promoted", transform26B, "model.layers.29.self_attn.v_proj.weight", aligned, "int4", "int8"},
		// nvfp4: promote to mxfp8 (cross-family, validated by MLX quantized_matmul)
		{"v_proj nvfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"v_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"},
		// mxfp4: promoted to mxfp8 at promoted layers (same mxfp family)
		{"v_proj mxfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp8"},
		{"v_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp4"},
		// int8/mxfp8: no promotion (already 8-bit)
		{"v_proj int8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"},
		{"v_proj mxfp8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"},

		// === down_proj (dense MLP): same heuristic as v_proj ===
		{"dense down_proj int4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"},
		{"dense down_proj int4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "int4", "int4"},
		{"dense down_proj nvfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"dense down_proj nvfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "nvfp4", "nvfp4"},
		{"dense down_proj mxfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp8"},
		{"dense down_proj mxfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"},

		// === Expert down_proj: int4→int8, nvfp4→nvfp8 at promoted layers ===
		{"expert down_proj int4 promoted", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "int4", "int8"},
		{"expert down_proj int4 non-promoted", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "int4", "int4"},
		// nvfp4 experts: promote to mxfp8 (all experts at a layer get same treatment,
		// so GatherQMM sees uniform quant per projection per layer)
		{"expert down_proj nvfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"expert down_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "nvfp4", "nvfp4"},
		// mxfp4 experts: promote to mxfp8 (same mxfp family, GatherQMM compatible)
		{"expert down_proj mxfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp8"},
		{"expert down_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp4"},

		// === Expert gate_up_proj: always base quant (not a sensitive tensor) ===
		{"expert gate_up int4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "int4", "int4"},
		{"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"},
		{"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"},

		// === Router projection: expert selection is sensitive; keep source precision ===
		{"router proj int4", transform26B, "model.layers.0.router.proj.weight", aligned, "int4", ""},
		{"router proj nvfp4", transform26B, "model.layers.0.router.proj.weight", aligned, "nvfp4", ""},
		{"router proj mxfp4", transform26B, "model.layers.0.router.proj.weight", aligned, "mxfp4", ""},

		// === k_proj: promoted only for 8-expert models ===
		{"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"},
		{"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
		{"k_proj 8 experts nvfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"k_proj 8 experts mxfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "mxfp4", "mxfp8"},

		// === q_proj, o_proj, gate_proj, up_proj: always base quant ===
		{"q_proj int4", transform26B, "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"},
		{"o_proj int4", transform26B, "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"},
		{"gate_proj int4", transform26B, "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"},
		{"up_proj int4", transform26B, "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"},

		// === Non-quantizable tensors: always bf16 ===
		{"embed_tokens per_layer skip", transform26B, "model.embed_tokens_per_layer.weight", aligned, "int4", ""},
		{"norm", transform26B, "model.layers.0.input_layernorm.weight", []int32{2816}, "int4", ""},
		{"router scale", transform26B, "model.layers.0.router.scale", []int32{2816}, "int4", ""},

		// === Audio/vision tower tensors: must pass through unquantized for all quant types ===
		// These contain .v_proj and down_proj but should NOT be intercepted by
		// the sensitive-tensor promotion logic.
		{"audio norm int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int4", ""},
		{"audio norm nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "nvfp4", ""},
		{"audio norm int8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int8", ""},
		{"audio norm mxfp8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "mxfp8", ""},
		{"audio conv int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "int4", ""},
		{"audio conv nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "nvfp4", ""},
		{"audio linear int4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "int4", ""},
		{"audio linear nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "nvfp4", ""},
		// Audio tower v_proj — must NOT be promoted despite containing .v_proj
		{"audio v_proj int4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", ""},
		{"audio v_proj nvfp4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", ""},
		// Vision tower v_proj — vision tower IS quantized (unlike audio tower),
		// but not intercepted by gemma4's layer-position heuristic.
		// Falls through to GetTensorQuantization which applies uniform promotion.
		{"vision v_proj int4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", "int8"},
		{"vision v_proj nvfp4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", "nvfp4"},
		// Audio tower down_proj
		{"audio down_proj int4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "int4", ""},
		{"audio down_proj nvfp4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "nvfp4", ""},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.transform.quantizationType(tt.tensor, tt.shape, tt.quantize)
			if got != tt.want {
				t.Errorf("quantizationType(%q, %v, %q) = %q, want %q",
					tt.tensor, tt.shape, tt.quantize, got, tt.want)
			}
		})
	}
}
|
||||
|
||||
func TestUseMoreBits(t *testing.T) {
|
||||
// 30 layers: first 1/8 = layers 0-2, last 1/8 = layers 27-29
|
||||
// In between: every 3rd from offset (i - n/8) % 3 == 2
|
||||
n := 30
|
||||
promoted := map[int]bool{}
|
||||
for i := range n {
|
||||
if useMoreBits(i, n) {
|
||||
promoted[i] = true
|
||||
}
|
||||
}
|
||||
|
||||
// First 1/8 (30/8 = 3): layers 0, 1, 2
|
||||
for _, i := range []int{0, 1, 2} {
|
||||
if !promoted[i] {
|
||||
t.Errorf("layer %d should be promoted (first 1/8)", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Last 1/8: layers 26, 27, 28, 29 (>= 7*30/8 = 26)
|
||||
for _, i := range []int{26, 27, 28, 29} {
|
||||
if !promoted[i] {
|
||||
t.Errorf("layer %d should be promoted (last 1/8)", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Some middle layers should NOT be promoted
|
||||
for _, i := range []int{3, 4, 6, 7} {
|
||||
if promoted[i] {
|
||||
t.Errorf("layer %d should NOT be promoted", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Layer 5 should be promoted: (5 - 3) % 3 == 2
|
||||
if !promoted[5] {
|
||||
t.Errorf("layer 5 should be promoted (periodic)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsGemma4StackedMoETensor covers both HF MoE layouts (old ".moe." and
// new ".experts.") plus the non-stacked (2D), non-expert, and
// non-projection negative cases.
func TestIsGemma4StackedMoETensor(t *testing.T) {
	tests := []struct {
		label      string
		tensorName string
		shape      []int32
		want       bool
	}{
		// New-style: .experts.gate_up_proj
		{"experts gate_up_proj 3D", "model.layers.0.experts.gate_up_proj", []int32{128, 1408, 2816}, true},
		{"experts down_proj 3D", "model.layers.0.experts.down_proj", []int32{128, 2816, 704}, true},
		// Old-style: .moe.gate_proj
		{"moe gate_proj 3D", "model.layers.0.moe.gate_proj", []int32{128, 2112, 2816}, true},
		{"moe down_proj 3D", "model.layers.0.moe.down_proj.weight", []int32{128, 2816, 2112}, true},
		// Not stacked: 2D
		{"2D weight", "model.layers.0.experts.gate_up_proj", []int32{1408, 2816}, false},
		// Not expert
		{"non-expert 3D", "model.layers.0.mlp.gate_proj", []int32{3, 2816, 2816}, false},
		// Not a projection
		{"expert non-proj", "model.layers.0.experts.scale", []int32{128, 1, 1}, false},
	}

	for _, tt := range tests {
		t.Run(tt.label, func(t *testing.T) {
			got := isGemma4StackedMoETensor(tt.tensorName, tt.shape)
			if got != tt.want {
				t.Errorf("isGemma4StackedMoETensor(%q, %v) = %v, want %v",
					tt.tensorName, tt.shape, got, tt.want)
			}
		})
	}
}
|
||||
@@ -115,36 +115,7 @@ func (s *Server) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
|
||||
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", s.modelName, "--port", strconv.Itoa(port))
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
|
||||
if runtime.GOOS == "linux" {
|
||||
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
|
||||
libraryPaths := []string{ml.LibOllamaPath}
|
||||
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
|
||||
libraryPaths = append(libraryPaths, mlxDirs...)
|
||||
}
|
||||
|
||||
// Append existing LD_LIBRARY_PATH if set
|
||||
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
|
||||
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
|
||||
}
|
||||
|
||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||
|
||||
// Update or add LD_LIBRARY_PATH in cmd.Env
|
||||
found := false
|
||||
for i := range cmd.Env {
|
||||
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
|
||||
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
|
||||
}
|
||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||
}
|
||||
configureMLXSubprocessEnv(cmd, ml.LibraryPaths(gpus))
|
||||
|
||||
s.cmd = cmd
|
||||
|
||||
@@ -200,6 +171,53 @@ func (s *Server) Ping(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// mlxLibraryPathEnv returns the OS-specific environment variable the
// dynamic loader consults for shared libraries: PATH on Windows,
// DYLD_LIBRARY_PATH on macOS, LD_LIBRARY_PATH elsewhere.
func mlxLibraryPathEnv() string {
	if runtime.GOOS == "windows" {
		return "PATH"
	}
	if runtime.GOOS == "darwin" {
		return "DYLD_LIBRARY_PATH"
	}
	return "LD_LIBRARY_PATH"
}
|
||||
|
||||
func configureMLXSubprocessEnv(cmd *exec.Cmd, libraryPaths []string) {
|
||||
if len(libraryPaths) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Search order for the imagegen runner is:
|
||||
// 1. bundled lib/ollama root
|
||||
// 2. backend-specific library dirs selected during GPU discovery
|
||||
// 3. any existing caller-provided library path values
|
||||
pathEnv := mlxLibraryPathEnv()
|
||||
pathEnvPaths := append([]string{}, libraryPaths...)
|
||||
if existingPath, ok := os.LookupEnv(pathEnv); ok {
|
||||
pathEnvPaths = append(pathEnvPaths, filepath.SplitList(existingPath)...)
|
||||
}
|
||||
setSubprocessEnv(cmd, pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
|
||||
slog.Debug("mlx subprocess library path", pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
|
||||
|
||||
ollamaLibraryPaths := append([]string{}, libraryPaths...)
|
||||
if existingPath, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
|
||||
ollamaLibraryPaths = append(ollamaLibraryPaths, filepath.SplitList(existingPath)...)
|
||||
}
|
||||
setSubprocessEnv(cmd, "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
|
||||
slog.Debug("mlx subprocess library path", "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
|
||||
}
|
||||
|
||||
func setSubprocessEnv(cmd *exec.Cmd, key, value string) {
|
||||
for i := range cmd.Env {
|
||||
name, _, ok := strings.Cut(cmd.Env[i], "=")
|
||||
if ok && strings.EqualFold(name, key) {
|
||||
cmd.Env[i] = key + "=" + value
|
||||
return
|
||||
}
|
||||
}
|
||||
cmd.Env = append(cmd.Env, key+"="+value)
|
||||
}
|
||||
|
||||
// getLastErr returns the last stderr line.
|
||||
func (s *Server) getLastErr() string {
|
||||
s.lastErrLock.Lock()
|
||||
|
||||
24
x/mlxrunner/cache/cache.go
vendored
24
x/mlxrunner/cache/cache.go
vendored
@@ -254,8 +254,23 @@ func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV
|
||||
mlx.Pin(c.keys, c.values)
|
||||
} else {
|
||||
if c.idx < c.keys.Dim(2) {
|
||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||
if c.offset <= c.maxSize {
|
||||
// Not yet wrapped: slots [c.idx, Dim) are grow padding
|
||||
// or stale post-rewind data, not live window content.
|
||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||
} else {
|
||||
// Wrapped: logical order is slots[idx..Dim) then slots[0..idx).
|
||||
// Linearize so the trim + concat below operate on contiguous
|
||||
// positions and preserve the last (maxSize - 1) old tokens.
|
||||
tailK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.keys.Dim(2)), mlx.Slice())
|
||||
tailV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.values.Dim(2)), mlx.Slice())
|
||||
headK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
|
||||
headV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
|
||||
c.keys.Set(tailK.Concatenate(2, headK))
|
||||
c.values.Set(tailV.Concatenate(2, headV))
|
||||
c.idx = c.keys.Dim(2)
|
||||
}
|
||||
}
|
||||
|
||||
// Trim to max_size to maintain sliding window
|
||||
@@ -322,9 +337,10 @@ func (c *RotatingKVCache) State() []*mlx.Array {
|
||||
if c.keys == nil || c.values == nil {
|
||||
return nil
|
||||
}
|
||||
liveLen := min(c.offset, c.keys.Dim(2))
|
||||
return []*mlx.Array{
|
||||
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
338
x/mlxrunner/cache/rotating_multiturn_test.go
vendored
Normal file
338
x/mlxrunner/cache/rotating_multiturn_test.go
vendored
Normal file
@@ -0,0 +1,338 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// singleTokenKV and multiTokenKV fabricate [B=1, H=1, L, D=2] key/value
|
||||
// tensors whose channel value is the token id, so stateIDs can recover
|
||||
// which ids survived in the cache.
|
||||
func singleTokenKV(id float32) (*mlx.Array, *mlx.Array) {
|
||||
k := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
|
||||
v := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
|
||||
return k, v
|
||||
}
|
||||
|
||||
func multiTokenKV(ids []float32) (*mlx.Array, *mlx.Array) {
|
||||
data := make([]float32, 0, 2*len(ids))
|
||||
for _, id := range ids {
|
||||
data = append(data, id, id)
|
||||
}
|
||||
k := mlx.FromValues(data, 1, 1, len(ids), 2)
|
||||
v := mlx.FromValues(data, 1, 1, len(ids), 2)
|
||||
return k, v
|
||||
}
|
||||
|
||||
// stateIDs returns the ids currently in the cache in slot order (logical
|
||||
// after a concat, physical/rotated after a single-token update).
|
||||
func stateIDs(t *testing.T, c *RotatingKVCache) []float32 {
|
||||
t.Helper()
|
||||
state := c.State()
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
mlx.Eval(state[0])
|
||||
flat := state[0].Floats()
|
||||
n := state[0].Dim(2)
|
||||
out := make([]float32, n)
|
||||
for i := range n {
|
||||
out[i] = flat[i*2]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func equalSlice(a, b []float32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func feedMulti(c *RotatingKVCache, startID float32, n int) float32 {
|
||||
ids := make([]float32, n)
|
||||
for i := range ids {
|
||||
ids[i] = startID + float32(i)
|
||||
}
|
||||
k, v := multiTokenKV(ids)
|
||||
c.Update(k, v)
|
||||
return startID + float32(n)
|
||||
}
|
||||
|
||||
func feedSingle(c *RotatingKVCache, id float32) {
|
||||
k, v := singleTokenKV(id)
|
||||
c.Update(k, v)
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatMidRotationPreservesContext: after the buffer
|
||||
// has wrapped, a multi-token concat must keep the (maxSize-1) most recent
|
||||
// pre-existing tokens in logical order so the first Q of the new batch
|
||||
// has a full sliding window.
|
||||
func TestRotatingKVCacheConcatMidRotationPreservesContext(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
nextID := feedMulti(c, 1, 3)
|
||||
for range 6 {
|
||||
feedSingle(c, nextID)
|
||||
nextID++
|
||||
}
|
||||
if c.Offset() != 9 {
|
||||
t.Fatalf("setup: offset=%d want 9", c.Offset())
|
||||
}
|
||||
if c.idx >= c.maxSize {
|
||||
t.Fatalf("setup: expected mid-rotation idx (<%d), got %d", c.maxSize, c.idx)
|
||||
}
|
||||
|
||||
feedMulti(c, 10, 2)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{7, 8, 9, 10, 11}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("post-concat window=%v want %v", got, want)
|
||||
}
|
||||
if c.Offset() != 11 {
|
||||
t.Fatalf("offset=%d want 11", c.Offset())
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatAlignedInvariant: with an aligned buffer
|
||||
// (c.idx == Dim), an L>1 concat keeps the last (maxSize-1) pre-existing
|
||||
// tokens plus the full new batch. This is the chunked-prefill contract
|
||||
// x/mlxrunner/pipeline.go relies on.
|
||||
func TestRotatingKVCacheConcatAlignedInvariant(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
// Chunk 1 fills past maxSize, leaving Dim == maxSize aligned.
|
||||
feedMulti(c, 1, 6)
|
||||
// Chunk 2: the buffer is intentionally oversized to (maxSize-1) + L
|
||||
// so the first new Q has its full window in scope for this forward.
|
||||
feedMulti(c, 7, 3)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{4, 5, 6, 7, 8, 9}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("post-chunk-2 buffer=%v want %v", got, want)
|
||||
}
|
||||
|
||||
// The next decode trims oversize back to maxSize; order may be
|
||||
// physical (rotated), so check as a set.
|
||||
feedSingle(c, 10)
|
||||
got = stateIDs(t, c)
|
||||
if len(got) != window {
|
||||
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
|
||||
}
|
||||
seen := map[float32]bool{}
|
||||
for _, v := range got {
|
||||
seen[v] = true
|
||||
}
|
||||
for _, w := range []float32{7, 8, 9, 10} {
|
||||
if !seen[w] {
|
||||
t.Fatalf("post-decode window missing %v (got %v)", w, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatAfterDecodeGrowsBuffer: update() grows the
|
||||
// underlying buffer by `step` slots via mlx.Zeros before writing, so
|
||||
// after one decode on a short prefill c.idx < Dim even though the cache
|
||||
// has not wrapped. Those trailing slots are zero padding and must not
|
||||
// be pulled back into the live window on the next concat.
|
||||
func TestRotatingKVCacheConcatAfterDecodeGrowsBuffer(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 512
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
feedMulti(c, 1, 3)
|
||||
feedSingle(c, 4)
|
||||
feedMulti(c, 5, 3)
|
||||
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{1, 2, 3, 4, 5, 6, 7}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("growing-buffer concat=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatAfterLiveRewind: x/mlxrunner/cache.go calls
|
||||
// Restore(nil, target) between conversation turns to rewind the cache to
|
||||
// the matched prefix. Restore moves c.offset/c.idx without trimming the
|
||||
// underlying buffer, so slots [c.idx, Dim) still hold stale pre-rewind
|
||||
// tokens. A subsequent concat must drop those, not treat them as wrapped
|
||||
// window content.
|
||||
func TestRotatingKVCacheConcatAfterLiveRewind(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 8
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
// Grow the buffer to exactly maxSize without wrapping.
|
||||
feedMulti(c, 1, 2)
|
||||
for id := float32(3); id <= 8; id++ {
|
||||
feedSingle(c, id)
|
||||
}
|
||||
if c.Offset() != window {
|
||||
t.Fatalf("setup: offset=%d want %d", c.Offset(), window)
|
||||
}
|
||||
|
||||
if !c.Restore(nil, 2) {
|
||||
t.Fatalf("live rewind to 2 failed")
|
||||
}
|
||||
if c.Offset() != 2 {
|
||||
t.Fatalf("post-rewind offset=%d want 2", c.Offset())
|
||||
}
|
||||
|
||||
feedMulti(c, 9, 3)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{1, 2, 9, 10, 11}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("post-rewind concat=%v want %v", got, want)
|
||||
}
|
||||
if c.Offset() != 5 {
|
||||
t.Fatalf("offset=%d want 5", c.Offset())
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatGrowingBuffer: when oldLen < maxSize the trim
|
||||
// formula drops to non-positive and all pre-existing tokens are kept.
|
||||
func TestRotatingKVCacheConcatGrowingBuffer(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
feedMulti(c, 1, 2)
|
||||
feedMulti(c, 3, 2)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{1, 2, 3, 4}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("growing buffer=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheRunnerChunkedPrefill mirrors the
|
||||
// x/mlxrunner/pipeline.go prefill loop: a long prompt fed through
|
||||
// repeated L>1 Update() calls on a single cache. Scaled-down proxy for
|
||||
// the Gemma 4 26B case (sliding_window=1024, prefillChunkSize=2048).
|
||||
func TestRotatingKVCacheRunnerChunkedPrefill(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
feedMulti(c, 1, 8)
|
||||
if c.Offset() != 8 {
|
||||
t.Fatalf("chunk 1: offset=%d want 8", c.Offset())
|
||||
}
|
||||
|
||||
feedMulti(c, 9, 8)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("chunk 2: buffer=%v want %v", got, want)
|
||||
}
|
||||
|
||||
feedMulti(c, 17, 4)
|
||||
got = stateIDs(t, c)
|
||||
want = []float32{14, 15, 16, 17, 18, 19, 20}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("chunk 3: buffer=%v want %v", got, want)
|
||||
}
|
||||
|
||||
// Decode trims oversize back to maxSize; order may be physical.
|
||||
feedSingle(c, 21)
|
||||
got = stateIDs(t, c)
|
||||
if len(got) != window {
|
||||
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
|
||||
}
|
||||
seen := map[float32]bool{}
|
||||
for _, v := range got {
|
||||
seen[v] = true
|
||||
}
|
||||
for _, w := range []float32{18, 19, 20, 21} {
|
||||
if !seen[w] {
|
||||
t.Fatalf("post-decode window missing %v (got %v)", w, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheMultiTurnChatSimulation walks a prefill → decode →
|
||||
// prefill sequence and checks that each new prefill retains the last
|
||||
// (maxSize-1) pre-existing tokens in logical order.
|
||||
func TestRotatingKVCacheMultiTurnChatSimulation(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
nextID := feedMulti(c, 1, 2)
|
||||
for range 5 {
|
||||
feedSingle(c, nextID)
|
||||
nextID++
|
||||
}
|
||||
if c.Offset() != 7 {
|
||||
t.Fatalf("turn 1: offset=%d want 7", c.Offset())
|
||||
}
|
||||
|
||||
feedMulti(c, nextID, 3)
|
||||
nextID += 3
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{5, 6, 7, 8, 9, 10}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("turn 2 prefill buffer=%v want %v", got, want)
|
||||
}
|
||||
|
||||
for range 4 {
|
||||
feedSingle(c, nextID)
|
||||
nextID++
|
||||
}
|
||||
if c.Offset() != 14 {
|
||||
t.Fatalf("turn 2 decode: offset=%d want 14", c.Offset())
|
||||
}
|
||||
|
||||
feedMulti(c, nextID, 2)
|
||||
got = stateIDs(t, c)
|
||||
want = []float32{12, 13, 14, 15, 16}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("turn 3 prefill buffer=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheOffsetTracking: Offset() is the monotonic logical
|
||||
// token count through any mix of Update() calls — Gemma 4 uses
|
||||
// donorEntry.Offset - L for the consumer's RoPE offset.
|
||||
func TestRotatingKVCacheOffsetTracking(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
c := NewRotatingKVCache(4)
|
||||
nextID := feedMulti(c, 1, 3)
|
||||
if c.Offset() != 3 {
|
||||
t.Fatalf("after prefill 3: offset=%d want 3", c.Offset())
|
||||
}
|
||||
for i := range 5 {
|
||||
feedSingle(c, nextID)
|
||||
nextID++
|
||||
if c.Offset() != 3+i+1 {
|
||||
t.Fatalf("after decode %d: offset=%d want %d", i, c.Offset(), 3+i+1)
|
||||
}
|
||||
}
|
||||
nextID = feedMulti(c, nextID, 2)
|
||||
if c.Offset() != 10 {
|
||||
t.Fatalf("after turn-2 prefill: offset=%d want 10", c.Offset())
|
||||
}
|
||||
// L > maxSize concat.
|
||||
feedMulti(c, nextID, 7)
|
||||
if c.Offset() != 17 {
|
||||
t.Fatalf("after large prefill: offset=%d want 17", c.Offset())
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
_ "github.com/ollama/ollama/x/models/gemma3"
|
||||
_ "github.com/ollama/ollama/x/models/gemma4"
|
||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
_ "github.com/ollama/ollama/x/models/llama"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||
|
||||
@@ -1,21 +1,64 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
import "math"
|
||||
|
||||
func GELUApprox(t *Array) *Array {
|
||||
return t.Multiply(
|
||||
FromValue[float32](0.5),
|
||||
).Multiply(
|
||||
t.Add(
|
||||
t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
|
||||
).Multiply(
|
||||
FromValue(float32(math.Sqrt(2 / math.Pi))),
|
||||
).Tanh().Add(FromValue[float32](1.0)),
|
||||
).AsType(t.DType())
|
||||
}
|
||||
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
|
||||
|
||||
func SILU(t *Array) *Array {
|
||||
return t.Multiply(t.Sigmoid()).AsType(t.DType())
|
||||
}
|
||||
// GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
// as a fused kernel.
|
||||
var GELUApprox = Compile1(
|
||||
"GELUApprox",
|
||||
func(x *Array) *Array {
|
||||
// Dtype-matched scalars avoid implicit upcasts on bf16 inputs.
|
||||
dt := x.DType()
|
||||
half := FromValue[float32](0.5).AsType(dt)
|
||||
coeff := FromValue(geluCoeff).AsType(dt)
|
||||
c := FromValue[float32](0.044715).AsType(dt)
|
||||
one := FromValue[float32](1.0).AsType(dt)
|
||||
|
||||
// x^3 via x*x*x (avoids general Power which is slower).
|
||||
x3 := x.Multiply(x).Multiply(x)
|
||||
inner := x.Add(c.Multiply(x3))
|
||||
tanh := coeff.Multiply(inner).Tanh()
|
||||
return half.Multiply(x).Multiply(one.Add(tanh))
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SiLU returns a * sigmoid(a) as a fused kernel.
|
||||
var SiLU = Compile1(
|
||||
"SiLU",
|
||||
func(a *Array) *Array {
|
||||
return a.Multiply(a.Sigmoid())
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SwiGLU returns silu(gate) * up as a fused kernel.
|
||||
var SwiGLU = Compile2(
|
||||
"SwiGLU",
|
||||
func(gate, up *Array) *Array {
|
||||
return SiLU(gate).Multiply(up)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
|
||||
// geglu, used by Gemma-family MLP and MoE paths.
|
||||
var GeGLU = Compile2(
|
||||
"GeGLU",
|
||||
func(gate, up *Array) *Array {
|
||||
return GELUApprox(gate).Multiply(up)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// LogitSoftcap returns tanh(x / cap) * cap as a fused kernel. Matches
|
||||
// mlx_lm's logit_softcap. cap must have the same dtype as x.
|
||||
var LogitSoftcap = Compile2(
|
||||
"LogitSoftcap",
|
||||
func(x, cap *Array) *Array {
|
||||
return x.Divide(cap).Tanh().Multiply(cap)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
@@ -27,7 +27,11 @@ var arrays []*Array
|
||||
|
||||
func New(name string) *Array {
|
||||
t := &Array{name: name}
|
||||
arrays = append(arrays, t)
|
||||
if tracing {
|
||||
traceScratch = append(traceScratch, t)
|
||||
} else {
|
||||
arrays = append(arrays, t)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
|
||||
192
x/mlxrunner/mlx/compile.go
Normal file
192
x/mlxrunner/mlx/compile.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package mlx
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "generated.h"
|
||||
//
|
||||
// extern int closureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
|
||||
// extern void closureDestructor(void* payload);
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"runtime/cgo"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// CompileFunc is the signature of a function that can be compiled.
|
||||
type CompileFunc func(inputs ...*Array) []*Array
|
||||
|
||||
// CompileOption configures Compile behavior.
|
||||
type CompileOption func(*compileConfig)
|
||||
|
||||
type compileConfig struct {
|
||||
shapeless bool
|
||||
}
|
||||
|
||||
// Shapeless traces the function once against symbolic shapes so the compiled
|
||||
// graph accepts any input shape afterwards. Without this option, MLX re-traces
|
||||
// on each new (shape, dtype) combination and caches each specialization.
|
||||
func Shapeless() CompileOption {
|
||||
return func(c *compileConfig) { c.shapeless = true }
|
||||
}
|
||||
|
||||
// Compile returns a compiled version of fn. When called during another
|
||||
// compile's trace, fn is inlined directly so outer compiles can fuse through
|
||||
// inner ones.
|
||||
//
|
||||
// Compiled functions must not have side effects outside of the function. Do
|
||||
// not access data other than the arguments passed in (either Go data or MLX
|
||||
// arrays) unless it is a constant.
|
||||
func Compile(name string, fn CompileFunc, opts ...CompileOption) CompileFunc {
|
||||
var cfg compileConfig
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
|
||||
var closure C.mlx_closure
|
||||
var once sync.Once
|
||||
|
||||
return func(inputs ...*Array) []*Array {
|
||||
if tracing {
|
||||
return fn(inputs...)
|
||||
}
|
||||
|
||||
once.Do(func() {
|
||||
payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
|
||||
*payload = cgo.NewHandle(fn)
|
||||
src := C.mlx_closure_new_func_payload(
|
||||
(*[0]byte)(C.closureCallback),
|
||||
unsafe.Pointer(payload),
|
||||
(*[0]byte)(C.closureDestructor),
|
||||
)
|
||||
defer C.mlx_closure_free(src)
|
||||
|
||||
closure = C.mlx_closure_new()
|
||||
mlxCheck(name+": compile failed", func() C.int {
|
||||
return C.mlx_compile(&closure, src, C.bool(cfg.shapeless))
|
||||
})
|
||||
})
|
||||
|
||||
inVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(inVec)
|
||||
for _, in := range inputs {
|
||||
C.mlx_vector_array_append_value(inVec, in.ctx)
|
||||
}
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
mlxCheck(name+": closure apply failed", func() C.int {
|
||||
return C.mlx_closure_apply(&outVec, closure, inVec)
|
||||
})
|
||||
|
||||
n := int(C.mlx_vector_array_size(outVec))
|
||||
outputs := make([]*Array, n)
|
||||
for i := range n {
|
||||
outputs[i] = New(name)
|
||||
C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
|
||||
}
|
||||
return outputs
|
||||
}
|
||||
}
|
||||
|
||||
// Compile1 compiles a unary function. See Compile.
|
||||
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
|
||||
cf := Compile(name, func(in ...*Array) []*Array {
|
||||
return []*Array{fn(in[0])}
|
||||
}, opts...)
|
||||
return func(a *Array) *Array {
|
||||
return cf(a)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Compile2 compiles a binary function. See Compile.
|
||||
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
|
||||
cf := Compile(name, func(in ...*Array) []*Array {
|
||||
return []*Array{fn(in[0], in[1])}
|
||||
}, opts...)
|
||||
return func(a, b *Array) *Array {
|
||||
return cf(a, b)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Compile3 compiles a ternary function. See Compile.
|
||||
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
|
||||
cf := Compile(name, func(in ...*Array) []*Array {
|
||||
return []*Array{fn(in[0], in[1], in[2])}
|
||||
}, opts...)
|
||||
return func(a, b, c *Array) *Array {
|
||||
return cf(a, b, c)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// tracing is true while a compile callback is running. Since MLX is
|
||||
// single-threaded at this level a plain Go bool suffices.
|
||||
var tracing bool
|
||||
|
||||
// traceScratch collects arrays created during a compile trace so they can be
|
||||
// freed as a group when the callback returns.
|
||||
var traceScratch []*Array
|
||||
|
||||
//export closureCallback
|
||||
func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("mlx closure callback panicked", "panic", r)
|
||||
rc = 1
|
||||
}
|
||||
}()
|
||||
|
||||
handle := *(*cgo.Handle)(payload)
|
||||
fn := handle.Value().(CompileFunc)
|
||||
|
||||
// When tracing, we track all of the intermediates that are created and free them separately at the end of
|
||||
// the process. This will give the effect of a single op - inputs are owned by the original caller (via
|
||||
// the MLX layer) and outputs are transferred back to MLX to create a new Go side tensor.
|
||||
if tracing {
|
||||
panic("mlx: nested compile trace")
|
||||
}
|
||||
tracing = true
|
||||
traceScratch = nil
|
||||
defer func() {
|
||||
for _, a := range traceScratch {
|
||||
if a.pinned > 0 {
|
||||
panic("mlx: traced array was pinned during compilation")
|
||||
}
|
||||
if a.Valid() {
|
||||
C.mlx_array_free(a.ctx)
|
||||
a.ctx.ctx = nil
|
||||
}
|
||||
}
|
||||
tracing = false
|
||||
traceScratch = nil
|
||||
}()
|
||||
|
||||
n := int(C.mlx_vector_array_size(input))
|
||||
inputs := make([]*Array, n)
|
||||
for i := range n {
|
||||
a := New("")
|
||||
C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
|
||||
inputs[i] = a
|
||||
}
|
||||
|
||||
outputs := fn(inputs...)
|
||||
|
||||
var arrPtr *C.mlx_array
|
||||
if len(outputs) > 0 {
|
||||
handles := make([]C.mlx_array, len(outputs))
|
||||
for i, out := range outputs {
|
||||
handles[i] = out.ctx
|
||||
}
|
||||
arrPtr = &handles[0]
|
||||
}
|
||||
C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
|
||||
return 0
|
||||
}
|
||||
|
||||
//export closureDestructor
|
||||
func closureDestructor(payload unsafe.Pointer) {
|
||||
handle := *(*cgo.Handle)(payload)
|
||||
handle.Delete()
|
||||
C.free(payload)
|
||||
}
|
||||
147
x/mlxrunner/mlx/compile_test.go
Normal file
147
x/mlxrunner/mlx/compile_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompileFusion(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Compile fuses the ops inside a function body into a single kernel,
|
||||
// eliminating intermediate buffers. Use a diamond-shaped graph where
|
||||
// two branches must be materialized simultaneously without fusion,
|
||||
// then compare peak memory against the compiled version which fuses
|
||||
// everything into one kernel with no intermediates.
|
||||
const n = 1024 * 1024 // 4MB per float32 array
|
||||
data := make([]float32, n)
|
||||
for i := range data {
|
||||
data[i] = float32(i + 1)
|
||||
}
|
||||
|
||||
// Diamond: both a*b and a+b must be live for the final multiply.
|
||||
// Without fusion: peak includes both intermediates (~8MB extra).
|
||||
// With fusion: single kernel, no intermediates.
|
||||
body := func(a, b *Array) *Array {
|
||||
return a.Multiply(b).Multiply(a.Add(b))
|
||||
}
|
||||
|
||||
a := FromValues(data, n)
|
||||
b := FromValues(data, n)
|
||||
Pin(a, b)
|
||||
defer Unpin(a, b)
|
||||
|
||||
// Compiled: ops fused into a single kernel.
|
||||
EnableCompile()
|
||||
fn := Compile2("diamond", body, Shapeless())
|
||||
warm := fn(a, b)
|
||||
Eval(warm)
|
||||
Sweep()
|
||||
ClearCache()
|
||||
ResetPeakMemory()
|
||||
y := fn(a, b)
|
||||
Eval(y)
|
||||
compiledPeak := PeakMemory()
|
||||
Sweep()
|
||||
|
||||
// Uncompiled: ops evaluated individually, intermediates materialized.
|
||||
ClearCache()
|
||||
ResetPeakMemory()
|
||||
z := body(a, b)
|
||||
Eval(z)
|
||||
uncompiledPeak := PeakMemory()
|
||||
Sweep()
|
||||
|
||||
if compiledPeak == 0 && uncompiledPeak == 0 {
|
||||
t.Skip("peak memory tracking not available")
|
||||
}
|
||||
|
||||
t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||
|
||||
if compiledPeak >= uncompiledPeak {
|
||||
t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileNested(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// A compiled function that calls another compiled function should
|
||||
// produce correct results. The inner function inlines via isTracing()
|
||||
// during the outer's trace.
|
||||
inner := Compile1("silu", func(a *Array) *Array {
|
||||
return a.Multiply(a.Sigmoid())
|
||||
}, Shapeless())
|
||||
|
||||
outer := Compile2("swiglu", func(gate, up *Array) *Array {
|
||||
return inner(gate).Multiply(up)
|
||||
}, Shapeless())
|
||||
|
||||
gate := FromValues([]float32{0, 1, 2}, 3)
|
||||
up := FromValues([]float32{1, 1, 1}, 3)
|
||||
Pin(gate, up)
|
||||
defer Unpin(gate, up)
|
||||
|
||||
y := outer(gate, up)
|
||||
Eval(y)
|
||||
|
||||
// silu(x) = x * sigmoid(x); for x=0 → 0, x=1 → ~0.7311, x=2 → ~1.7616
|
||||
got := y.Floats()
|
||||
want := []float32{0, 0.7310586, 1.7615942}
|
||||
for i, v := range got {
|
||||
if v-want[i] > 1e-4 || want[i]-v > 1e-4 {
|
||||
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileCallbackPanicRecovers(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
boom := Compile1("boom", func(a *Array) *Array {
|
||||
panic("intentional test panic")
|
||||
})
|
||||
|
||||
x := FromValues([]float32{1}, 1)
|
||||
Pin(x)
|
||||
defer Unpin(x)
|
||||
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
t.Fatal("expected panic from Call, got none")
|
||||
}
|
||||
if _, ok := r.(string); !ok {
|
||||
t.Fatalf("expected string panic, got %T: %v", r, r)
|
||||
}
|
||||
}()
|
||||
boom(x)
|
||||
}
|
||||
|
||||
func TestCompileNoTrackingGrowth(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Repeated invocations of a compiled kernel should not grow the
|
||||
// tracked-arrays list — the callback's traceScratch collects
|
||||
// intermediates during tracing and frees them when the callback returns.
|
||||
fn := Compile2("mul_add", func(a, b *Array) *Array {
|
||||
return a.Multiply(b).Add(b)
|
||||
})
|
||||
|
||||
a := FromValues([]float32{1, 2}, 2)
|
||||
b := FromValues([]float32{3, 4}, 2)
|
||||
Pin(a, b)
|
||||
defer Unpin(a, b)
|
||||
|
||||
Sweep()
|
||||
before := len(arrays)
|
||||
|
||||
for range 100 {
|
||||
_ = fn(a, b)
|
||||
Sweep()
|
||||
}
|
||||
|
||||
after := len(arrays)
|
||||
if after > before+2 {
|
||||
t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", before, after)
|
||||
}
|
||||
}
|
||||
@@ -9,8 +9,8 @@ package mlx
|
||||
// #include "generated.h"
|
||||
// #include <string.h>
|
||||
//
|
||||
// static char _mlx_last_error_msg[1024] = {0};
|
||||
// static int _mlx_last_error_flag = 0;
|
||||
// static __thread char _mlx_last_error_msg[1024] = {0};
|
||||
// static __thread int _mlx_last_error_flag = 0;
|
||||
//
|
||||
// static void _mlx_capture_error_handler(const char* msg, void* data) {
|
||||
// (void)data;
|
||||
@@ -30,15 +30,13 @@ package mlx
|
||||
// _mlx_last_error_msg[0] = '\0';
|
||||
// }
|
||||
//
|
||||
// static int mlx_had_last_error(void) {
|
||||
// return _mlx_last_error_flag;
|
||||
// }
|
||||
//
|
||||
// static const char* mlx_get_last_error(void) {
|
||||
// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL;
|
||||
// return _mlx_last_error_flag ? _mlx_last_error_msg : "";
|
||||
// }
|
||||
import "C"
|
||||
|
||||
import "runtime"
|
||||
|
||||
func init() {
|
||||
// Replace the default exit(-1) error handler with one that captures
|
||||
// the error message so we can surface it in Go.
|
||||
@@ -53,6 +51,24 @@ func Version() string {
|
||||
return C.GoString(C.mlx_string_data(str))
|
||||
}
|
||||
|
||||
// mlxCheck locks the goroutine to its OS thread, clears the captured error
|
||||
// state, calls fn, and panics with the captured message if fn returns non-zero.
|
||||
// The thread lock ensures the thread-local error state is read from the same
|
||||
// thread that executed the call.
|
||||
func mlxCheck(fallback string, fn func() C.int) {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
C.mlx_clear_last_error()
|
||||
if fn() != 0 {
|
||||
msg := C.GoString(C.mlx_get_last_error())
|
||||
if msg == "" {
|
||||
msg = fallback
|
||||
}
|
||||
panic("mlx: " + msg)
|
||||
}
|
||||
}
|
||||
|
||||
func doEval(outputs []*Array, async bool) {
|
||||
if len(outputs) == 0 {
|
||||
return
|
||||
@@ -67,20 +83,12 @@ func doEval(outputs []*Array, async bool) {
|
||||
}
|
||||
}
|
||||
|
||||
C.mlx_clear_last_error()
|
||||
var rc C.int
|
||||
if async {
|
||||
rc = C.mlx_async_eval(vector)
|
||||
} else {
|
||||
rc = C.mlx_eval(vector)
|
||||
}
|
||||
if rc != 0 {
|
||||
msg := "mlx eval failed"
|
||||
if C.mlx_had_last_error() != 0 {
|
||||
msg = C.GoString(C.mlx_get_last_error())
|
||||
mlxCheck("eval failed", func() C.int {
|
||||
if async {
|
||||
return C.mlx_async_eval(vector)
|
||||
}
|
||||
panic("mlx: " + msg)
|
||||
}
|
||||
return C.mlx_eval(vector)
|
||||
})
|
||||
}
|
||||
|
||||
func AsyncEval(outputs ...*Array) {
|
||||
@@ -90,3 +98,10 @@ func AsyncEval(outputs ...*Array) {
|
||||
func Eval(outputs ...*Array) {
|
||||
doEval(outputs, false)
|
||||
}
|
||||
|
||||
// MetalIsAvailable returns true if a Metal GPU is available.
|
||||
func MetalIsAvailable() bool {
|
||||
var available C._Bool
|
||||
C.mlx_metal_is_available(&available)
|
||||
return bool(available)
|
||||
}
|
||||
|
||||
@@ -149,45 +149,132 @@ func Contiguous(a *Array, allowColMajor bool) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func Pad(a *Array, paddings []int32) *Array {
|
||||
numAxes := len(paddings) / 2
|
||||
axes := make([]C.int, numAxes)
|
||||
lowPad := make([]C.int, numAxes)
|
||||
highPad := make([]C.int, numAxes)
|
||||
for i := range numAxes {
|
||||
axes[i] = C.int(i)
|
||||
lowPad[i] = C.int(paddings[i*2])
|
||||
highPad[i] = C.int(paddings[i*2+1])
|
||||
// Conv2d performs 2D convolution: x [N,H,W,C_in], weight [C_out,kH,kW,C_in].
|
||||
// MLX uses NHWC layout.
|
||||
func Conv2d(x, weight *Array, strideH, strideW, padH, padW, dilationH, dilationW, groups int32) *Array {
|
||||
out := New("CONV2D")
|
||||
C.mlx_conv2d(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
weight.ctx,
|
||||
C.int(strideH), C.int(strideW),
|
||||
C.int(padH), C.int(padW),
|
||||
C.int(dilationH), C.int(dilationW),
|
||||
C.int(groups),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// Pad pads array a along the given axes with specified low/high pad sizes.
|
||||
// mode should be "constant", "edge", or "reflect".
|
||||
func Pad(a *Array, axes []int, lowPad, highPad []int, padValue *Array, mode string) *Array {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
cLow := make([]C.int, len(lowPad))
|
||||
cHigh := make([]C.int, len(highPad))
|
||||
for i := range axes {
|
||||
cAxes[i] = C.int(axes[i])
|
||||
cLow[i] = C.int(lowPad[i])
|
||||
cHigh[i] = C.int(highPad[i])
|
||||
}
|
||||
|
||||
padValue := C.mlx_array_new_float(C.float(0))
|
||||
defer C.mlx_array_free(padValue)
|
||||
|
||||
cMode := C.CString("constant")
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("PAD")
|
||||
C.mlx_pad(
|
||||
&out.ctx,
|
||||
a.ctx,
|
||||
unsafe.SliceData(axes),
|
||||
C.size_t(len(axes)),
|
||||
unsafe.SliceData(lowPad),
|
||||
C.size_t(len(lowPad)),
|
||||
unsafe.SliceData(highPad),
|
||||
C.size_t(len(highPad)),
|
||||
padValue,
|
||||
unsafe.SliceData(cAxes), C.size_t(len(cAxes)),
|
||||
unsafe.SliceData(cLow), C.size_t(len(cLow)),
|
||||
unsafe.SliceData(cHigh), C.size_t(len(cHigh)),
|
||||
padValue.ctx,
|
||||
cMode,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// PadConstant pads with zeros along the given axes.
|
||||
func PadConstant(a *Array, axes []int, lowPad, highPad []int) *Array {
|
||||
zero := NewScalarArray(float32(0))
|
||||
return Pad(a, axes, lowPad, highPad, zero, "constant")
|
||||
}
|
||||
|
||||
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||
groups := int32(x.Dim(x.NumDims() - 1))
|
||||
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||
}
|
||||
|
||||
// Maximum returns element-wise maximum of two arrays.
|
||||
func Maximum(a, b *Array) *Array {
|
||||
out := New("MAXIMUM")
|
||||
C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Minimum returns element-wise minimum of two arrays.
|
||||
func Minimum(a, b *Array) *Array {
|
||||
out := New("MINIMUM")
|
||||
C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Softplus computes log(1 + exp(x)) using logaddexp for numerical stability.
|
||||
func Softplus(a *Array) *Array {
|
||||
return Logaddexp(a, Zeros(a.DType(), a.Dims()...))
|
||||
}
|
||||
|
||||
// ReLU computes max(0, x).
|
||||
func ReLU(a *Array) *Array {
|
||||
return Maximum(a, NewScalarArray(float32(0)))
|
||||
}
|
||||
|
||||
// GLU applies Gated Linear Unit: splits x along last dim into two halves,
|
||||
// returns first * sigmoid(second).
|
||||
func GLU(a *Array) *Array {
|
||||
lastDim := a.NumDims() - 1
|
||||
halfSize := a.Dim(lastDim) / 2
|
||||
first := SliceStartStop(a,
|
||||
make([]int32, lastDim+1), // all zeros for start
|
||||
appendDims(a, lastDim, int32(halfSize)),
|
||||
)
|
||||
second := SliceStartStop(a,
|
||||
appendDimsStart(a, lastDim, int32(halfSize)),
|
||||
appendDims(a, lastDim, int32(a.Dim(lastDim))),
|
||||
)
|
||||
return first.Multiply(second.Sigmoid())
|
||||
}
|
||||
|
||||
// helper: builds stop array for SliceStartStop where the target axis = val
|
||||
func appendDims(a *Array, targetAxis int, val int32) []int32 {
|
||||
n := a.NumDims()
|
||||
out := make([]int32, n)
|
||||
for i := range n {
|
||||
if i == targetAxis {
|
||||
out[i] = val
|
||||
} else {
|
||||
out[i] = int32(a.Dim(i))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// helper: builds start array for SliceStartStop where the target axis = val
|
||||
func appendDimsStart(a *Array, targetAxis int, val int32) []int32 {
|
||||
n := a.NumDims()
|
||||
out := make([]int32, n)
|
||||
for i := range n {
|
||||
if i == targetAxis {
|
||||
out[i] = val
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Clamp clamps array values to [min, max].
|
||||
func Clamp(a *Array, minVal, maxVal float32) *Array {
|
||||
return Minimum(Maximum(a, NewScalarArray(minVal)), NewScalarArray(maxVal))
|
||||
}
|
||||
|
||||
// Convenience wrappers (function-style for the model code)
|
||||
|
||||
func Stack(arrays []*Array, axis int) *Array {
|
||||
@@ -317,26 +404,38 @@ func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *A
|
||||
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
|
||||
}
|
||||
|
||||
func SiLU(a *Array) *Array {
|
||||
sig := a.Sigmoid()
|
||||
return a.Multiply(sig)
|
||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
||||
return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
|
||||
}
|
||||
|
||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
||||
freqs := New("")
|
||||
// RoPEWithFreqs applies RoPE with optional custom frequencies.
|
||||
// When freqs is non-nil, it is used instead of computing from base.
|
||||
// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass
|
||||
// the actual frequencies (base^(2i/dim)), not the inverse frequencies.
|
||||
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offset int, freqs *Array) *Array {
|
||||
var freqsCtx C.mlx_array
|
||||
var optBase C.mlx_optional_float
|
||||
if freqs != nil {
|
||||
freqsCtx = freqs.ctx
|
||||
optBase = C.mlx_optional_float{has_value: C.bool(false)}
|
||||
} else {
|
||||
empty := New("")
|
||||
freqsCtx = empty.ctx
|
||||
optBase = C.mlx_optional_float{
|
||||
value: C.float(base),
|
||||
has_value: C.bool(func() bool { return base != 0 }()),
|
||||
}
|
||||
}
|
||||
out := New("FAST_ROPE")
|
||||
C.mlx_fast_rope(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
C.int(dims),
|
||||
C.bool(traditional),
|
||||
C.mlx_optional_float{
|
||||
value: C.float(base),
|
||||
has_value: C.bool(func() bool { return base != 0 }()),
|
||||
},
|
||||
optBase,
|
||||
C.float(scale),
|
||||
C.int(offset),
|
||||
freqs.ctx,
|
||||
freqsCtx,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
@@ -358,6 +457,24 @@ func Log(a *Array) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func Sin(a *Array) *Array {
|
||||
out := New("SIN")
|
||||
C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Cos(a *Array) *Array {
|
||||
out := New("COS")
|
||||
C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Clip(a, aMin, aMax *Array) *Array {
|
||||
out := New("CLIP")
|
||||
C.mlx_clip(&out.ctx, a.ctx, aMin.ctx, aMax.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Logaddexp(a, b *Array) *Array {
|
||||
out := New("LOGADDEXP")
|
||||
C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
@@ -385,6 +502,20 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
|
||||
return out
|
||||
}
|
||||
|
||||
// ScaledDotProductAttentionMasked runs the fast SDPA kernel with an explicit
|
||||
// additive mask. The mask is broadcast to [B, H, Q, K] and added to scores
|
||||
// before softmax. Pass mode="array" so MLX actually consults mask_arr; the
|
||||
// empty string is "no mask" and silently ignores the array argument.
|
||||
func ScaledDotProductAttentionMasked(q, k, v *Array, scale float32, mask *Array) *Array {
|
||||
sinks := New("")
|
||||
cMode := C.CString("array")
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("FAST_SDPA")
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
|
||||
out := New("FAST_LAYERNORM")
|
||||
var w, b C.mlx_array
|
||||
|
||||
@@ -131,6 +131,12 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string,
|
||||
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
|
||||
globalQuantType = strings.ToUpper(globalQuantType)
|
||||
|
||||
// Parse full metadata for per-tensor quant info
|
||||
var metaMap map[string]string
|
||||
if metaRaw, ok := header["__metadata__"]; ok {
|
||||
json.Unmarshal(metaRaw, &metaMap)
|
||||
}
|
||||
|
||||
mainNames := mainTensorNames(header)
|
||||
infos := make(map[string]*TensorQuantInfo)
|
||||
for _, name := range mainNames {
|
||||
@@ -141,6 +147,18 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string,
|
||||
quantType := globalQuantType
|
||||
groupSize := globalGroupSize
|
||||
|
||||
// Check per-tensor metadata (e.g. from packed expert blobs with mixed precision)
|
||||
if metaMap != nil {
|
||||
if qt, ok := metaMap[name+".quant_type"]; ok && qt != "" {
|
||||
quantType = strings.ToUpper(qt)
|
||||
}
|
||||
if gs, ok := metaMap[name+".group_size"]; ok && gs != "" {
|
||||
if v, err := strconv.Atoi(gs); err == nil {
|
||||
groupSize = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
|
||||
if quantType == "" {
|
||||
quantType = inferredType
|
||||
|
||||
@@ -23,15 +23,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
return errors.New("model not loaded")
|
||||
}
|
||||
|
||||
enableCompile := true
|
||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
||||
enableCompile = modelCompile.EnableCompile()
|
||||
}
|
||||
if enableCompile {
|
||||
mlx.EnableCompile()
|
||||
} else {
|
||||
mlx.DisableCompile()
|
||||
}
|
||||
mlx.ResetPeakMemory()
|
||||
ctx := request.Ctx
|
||||
var (
|
||||
|
||||
@@ -79,6 +79,8 @@ func (r *Runner) Load(modelName string) error {
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
r.contextLength = m.MaxContextLength()
|
||||
|
||||
mlx.EnableCompile()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
1508
x/models/gemma4/gemma4.go
Normal file
1508
x/models/gemma4/gemma4.go
Normal file
File diff suppressed because it is too large
Load Diff
228
x/models/gemma4/gemma4_moe_test.go
Normal file
228
x/models/gemma4/gemma4_moe_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// onesLike creates a tensor of the given shape filled with a small constant.
|
||||
func onesLike(shape ...int) *mlx.Array {
|
||||
return mlx.AddScalar(mlx.Zeros(mlx.DTypeBFloat16, shape...), 0.01)
|
||||
}
|
||||
|
||||
func TestMoEForward(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Small config matching 26b architecture pattern.
|
||||
cfg := &TextConfig{
|
||||
HiddenSize: 16, // tiny for testing
|
||||
NumAttentionHeads: 2,
|
||||
NumKeyValueHeads: 1,
|
||||
NumGlobalKeyValueHeads: 1,
|
||||
HeadDim: 8,
|
||||
GlobalHeadDim: 8,
|
||||
NumExperts: 4,
|
||||
TopKExperts: 2,
|
||||
ExpertIntermediateSize: 8,
|
||||
EnableMoeBlock: true,
|
||||
AttentionKEqV: false,
|
||||
RMSNormEps: 1e-6,
|
||||
SlidingScale: 1.0,
|
||||
FullScale: 1.0,
|
||||
}
|
||||
|
||||
B, L := int32(1), int32(3)
|
||||
x := onesLike(int(B), int(L), int(cfg.HiddenSize))
|
||||
|
||||
// Test Router.Forward.
|
||||
router := &Router{
|
||||
Proj: linearFromWeight(onesLike(int(cfg.NumExperts), int(cfg.HiddenSize))),
|
||||
Scale: onesLike(int(cfg.HiddenSize)),
|
||||
}
|
||||
|
||||
t.Run("Router", func(t *testing.T) {
|
||||
scores, inds := router.Forward(x, cfg)
|
||||
mlx.Eval(scores, inds)
|
||||
|
||||
sDims := scores.Dims()
|
||||
iDims := inds.Dims()
|
||||
t.Logf("scores shape: %v, inds shape: %v", sDims, iDims)
|
||||
|
||||
if len(sDims) != 2 || sDims[0] != int(B*L) || sDims[1] != int(cfg.TopKExperts) {
|
||||
t.Errorf("scores shape = %v, want [%d, %d]", sDims, B*L, cfg.TopKExperts)
|
||||
}
|
||||
if len(iDims) != 2 || iDims[0] != int(B*L) || iDims[1] != int(cfg.TopKExperts) {
|
||||
t.Errorf("inds shape = %v, want [%d, %d]", iDims, B*L, cfg.TopKExperts)
|
||||
}
|
||||
})
|
||||
|
||||
// Test MoEBlock.Forward.
|
||||
moe := &MoEBlock{
|
||||
GateWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
|
||||
UpWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
|
||||
DownWeight: onesLike(int(cfg.NumExperts), int(cfg.ExpertIntermediateSize), int(cfg.HiddenSize)),
|
||||
PerExpertScale: onesLike(int(cfg.NumExperts)),
|
||||
}
|
||||
|
||||
t.Run("MoEBlock", func(t *testing.T) {
|
||||
scores, inds := router.Forward(x, cfg)
|
||||
mlx.Eval(scores, inds)
|
||||
|
||||
out := moe.Forward(x, scores, inds, cfg)
|
||||
mlx.Eval(out)
|
||||
|
||||
outDims := out.Dims()
|
||||
t.Logf("MoE output shape: %v", outDims)
|
||||
|
||||
if len(outDims) != 3 || outDims[0] != int(B) || outDims[1] != int(L) || outDims[2] != int(cfg.HiddenSize) {
|
||||
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, B, L, cfg.HiddenSize)
|
||||
}
|
||||
})
|
||||
|
||||
// Test with larger batch to exercise the sorted GatherMM path (B*L >= 64).
|
||||
t.Run("MoEBlock_sorted", func(t *testing.T) {
|
||||
bigB, bigL := int32(1), int32(128)
|
||||
bigX := onesLike(int(bigB), int(bigL), int(cfg.HiddenSize))
|
||||
|
||||
scores, inds := router.Forward(bigX, cfg)
|
||||
mlx.Eval(scores, inds)
|
||||
|
||||
out := moe.Forward(bigX, scores, inds, cfg)
|
||||
mlx.Eval(out)
|
||||
|
||||
outDims := out.Dims()
|
||||
t.Logf("MoE sorted output shape: %v", outDims)
|
||||
|
||||
if len(outDims) != 3 || outDims[0] != int(bigB) || outDims[1] != int(bigL) || outDims[2] != int(cfg.HiddenSize) {
|
||||
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, bigB, bigL, cfg.HiddenSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestRouterForwardMatchesLegacy verifies the optimized Router.Forward —
|
||||
// which takes the top-k of the raw logits and softmaxes only the selected
|
||||
// values — produces the same indices and (within tolerance) the same
|
||||
// normalized scores as the legacy path that softmaxes over every expert
|
||||
// first, gathers the top-k probabilities, then renormalizes.
|
||||
func TestRouterForwardMatchesLegacy(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
cfg := &TextConfig{
|
||||
HiddenSize: 8,
|
||||
NumExperts: 4,
|
||||
TopKExperts: 2,
|
||||
RMSNormEps: 1e-6,
|
||||
RouterScale: 0.5,
|
||||
}
|
||||
|
||||
// Distinct per-expert weight rows so top-k has a well-defined ordering
|
||||
// (tied scores would let argpartition pick either tied expert and make
|
||||
// the index comparison below flaky).
|
||||
projWeight := mlx.FromValues([]float32{
|
||||
0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, // expert 0
|
||||
0.30, 0.29, 0.28, 0.27, 0.26, 0.25, 0.24, 0.23, // expert 1
|
||||
-0.05, -0.06, -0.07, -0.08, -0.09, -0.10, -0.11, -0.12, // expert 2
|
||||
0.50, 0.48, 0.46, 0.44, 0.42, 0.40, 0.38, 0.36, // expert 3
|
||||
}, int(cfg.NumExperts), int(cfg.HiddenSize))
|
||||
|
||||
scale := mlx.FromValues([]float32{
|
||||
1.0, 0.9, 1.1, 1.0, 1.2, 0.8, 1.0, 1.05,
|
||||
}, int(cfg.HiddenSize))
|
||||
|
||||
r := &Router{
|
||||
Proj: linearFromWeight(projWeight),
|
||||
Scale: scale,
|
||||
}
|
||||
|
||||
// Varied x so different positions potentially hit different top-k.
|
||||
x := mlx.FromValues([]float32{
|
||||
0.2, -0.1, 0.3, 0.0, 0.4, -0.2, 0.1, 0.05,
|
||||
-0.3, 0.2, -0.1, 0.4, -0.05, 0.3, 0.0, 0.2,
|
||||
0.5, 0.4, -0.2, 0.1, -0.3, 0.0, 0.3, -0.1,
|
||||
}, 1, 3, int(cfg.HiddenSize))
|
||||
|
||||
gotScores, gotInds := r.Forward(x, cfg)
|
||||
wantScores, wantInds := legacyRouterForward(r, x, cfg)
|
||||
mlx.Eval(gotScores, gotInds, wantScores, wantInds)
|
||||
|
||||
if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {
|
||||
t.Fatalf("indices mismatch:\n got %v\n want %v", got, want)
|
||||
}
|
||||
if got, want := gotScores.Floats(), wantScores.Floats(); !floatSlicesClose(got, want, 1e-5) {
|
||||
t.Fatalf("scores mismatch:\n got %v\n want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// legacyRouterForward implements the pre-optimization router: full softmax
|
||||
// over every expert, gather the top-k probabilities, then renormalize them
|
||||
// to sum to 1. Algebraically identical to the fused form in Router.Forward.
|
||||
func legacyRouterForward(r *Router, x *mlx.Array, cfg *TextConfig) (*mlx.Array, *mlx.Array) {
|
||||
dims := x.Dims()
|
||||
BL := int32(dims[0]) * int32(dims[1])
|
||||
|
||||
xFlat := mlx.Reshape(x, BL, cfg.HiddenSize)
|
||||
normed := mlx.RMSNormFn(xFlat, nil, cfg.RMSNormEps)
|
||||
normed = mlx.MulScalar(normed, cfg.RouterScale)
|
||||
normed = mlx.Mul(normed, r.Scale)
|
||||
|
||||
expertScores := r.Proj.Forward(normed)
|
||||
probs := mlx.SoftmaxAxis(expertScores, -1, true)
|
||||
|
||||
neg := mlx.Neg(expertScores)
|
||||
inds := mlx.Argpartition(neg, int(cfg.TopKExperts)-1, -1)
|
||||
inds = mlx.SliceStartStop(inds,
|
||||
[]int32{0, 0},
|
||||
[]int32{BL, cfg.TopKExperts},
|
||||
)
|
||||
|
||||
scores := mlx.TakeAlongAxis(probs, inds, -1)
|
||||
sumScores := mlx.Sum(scores, -1, true)
|
||||
scores = mlx.Div(scores, sumScores)
|
||||
return scores, inds
|
||||
}
|
||||
|
||||
func intSlicesEqual(a, b []int) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func floatSlicesClose(a, b []float32, tol float32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
d := a[i] - b[i]
|
||||
if d < 0 {
|
||||
d = -d
|
||||
}
|
||||
if d > tol {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// linearFromWeight creates a simple nn.LinearLayer from a weight tensor (no bias).
|
||||
func linearFromWeight(w *mlx.Array) *simpleLinear {
|
||||
return &simpleLinear{weight: w}
|
||||
}
|
||||
|
||||
type simpleLinear struct {
|
||||
weight *mlx.Array
|
||||
}
|
||||
|
||||
func (l *simpleLinear) Forward(x *mlx.Array) *mlx.Array {
|
||||
return x.Matmul(mlx.Transpose(l.weight, 1, 0))
|
||||
}
|
||||
|
||||
func (l *simpleLinear) OutputDim() int32 {
|
||||
return int32(l.weight.Dims()[0])
|
||||
}
|
||||
503
x/models/gemma4/gemma4_test.go
Normal file
503
x/models/gemma4/gemma4_test.go
Normal file
@@ -0,0 +1,503 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func TestParseTextConfigE2B(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
data := []byte(`{
|
||||
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||
"text_config": {
|
||||
"hidden_size": 1536,
|
||||
"num_hidden_layers": 35,
|
||||
"intermediate_size": 6144,
|
||||
"num_attention_heads": 8,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 256,
|
||||
"global_head_dim": 512,
|
||||
"vocab_size": 262144,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"max_position_embeddings": 131072,
|
||||
"sliding_window": 512,
|
||||
"sliding_window_pattern": 5,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"use_double_wide_mlp": true,
|
||||
"num_kv_shared_layers": 20,
|
||||
"hidden_size_per_layer_input": 256,
|
||||
"vocab_size_per_layer_input": 262144,
|
||||
"attention_k_eq_v": false,
|
||||
"tie_word_embeddings": true,
|
||||
"layer_types": [
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
|
||||
],
|
||||
"rope_parameters": {
|
||||
"full_attention": {
|
||||
"partial_rotary_factor": 0.25,
|
||||
"rope_theta": 1000000.0,
|
||||
"rope_type": "proportional"
|
||||
},
|
||||
"sliding_attention": {
|
||||
"rope_theta": 10000.0,
|
||||
"rope_type": "default"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseTextConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseTextConfig failed: %v", err)
|
||||
}
|
||||
|
||||
// Basic fields.
|
||||
if cfg.HiddenSize != 1536 {
|
||||
t.Errorf("HiddenSize = %d, want 1536", cfg.HiddenSize)
|
||||
}
|
||||
if cfg.NumHiddenLayers != 35 {
|
||||
t.Errorf("NumHiddenLayers = %d, want 35", cfg.NumHiddenLayers)
|
||||
}
|
||||
if cfg.GlobalHeadDim != 512 {
|
||||
t.Errorf("GlobalHeadDim = %d, want 512", cfg.GlobalHeadDim)
|
||||
}
|
||||
if cfg.FinalLogitSoftcapping != 30.0 {
|
||||
t.Errorf("FinalLogitSoftcapping = %f, want 30.0", cfg.FinalLogitSoftcapping)
|
||||
}
|
||||
if cfg.NumKVSharedLayers != 20 {
|
||||
t.Errorf("NumKVSharedLayers = %d, want 20", cfg.NumKVSharedLayers)
|
||||
}
|
||||
if cfg.HiddenSizePerLayer != 256 {
|
||||
t.Errorf("HiddenSizePerLayer = %d, want 256", cfg.HiddenSizePerLayer)
|
||||
}
|
||||
|
||||
// RoPE settings.
|
||||
if cfg.SlidingRopeDims != 256 {
|
||||
t.Errorf("SlidingRopeDims = %d, want 256", cfg.SlidingRopeDims)
|
||||
}
|
||||
if cfg.FullRopeDims != 512 {
|
||||
t.Errorf("FullRopeDims = %d, want 512 (GlobalHeadDim, partial rotation handled via custom freqs)", cfg.FullRopeDims)
|
||||
}
|
||||
if cfg.SlidingRopeBase != 10000 {
|
||||
t.Errorf("SlidingRopeBase = %f, want 10000", cfg.SlidingRopeBase)
|
||||
}
|
||||
if cfg.FullRopeBase != 1000000 {
|
||||
t.Errorf("FullRopeBase = %f, want 1000000", cfg.FullRopeBase)
|
||||
}
|
||||
|
||||
// Attention scale.
|
||||
if cfg.SlidingScale == 0 || cfg.FullScale == 0 {
|
||||
t.Error("attention scales should be non-zero")
|
||||
}
|
||||
|
||||
// KV sharing map.
|
||||
// First shared layer is 35 - 20 = 15.
|
||||
if donor, ok := cfg.KVShareMap[15]; !ok || donor != 13 {
|
||||
t.Errorf("KVShareMap[15] = %d, ok=%v; want 13, true", donor, ok)
|
||||
}
|
||||
if donor, ok := cfg.KVShareMap[19]; !ok || donor != 14 {
|
||||
t.Errorf("KVShareMap[19] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
|
||||
}
|
||||
if donor, ok := cfg.KVShareMap[34]; !ok || donor != 14 {
|
||||
t.Errorf("KVShareMap[34] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
|
||||
}
|
||||
// Layer 14 should not be shared.
|
||||
if _, ok := cfg.KVShareMap[14]; ok {
|
||||
t.Error("layer 14 should not be in KVShareMap (non-shared)")
|
||||
}
|
||||
|
||||
// Donors.
|
||||
if !cfg.KVDonors[13] {
|
||||
t.Error("layer 13 should be a KV donor")
|
||||
}
|
||||
if !cfg.KVDonors[14] {
|
||||
t.Error("layer 14 should be a KV donor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTextConfig26B(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
data := []byte(`{
|
||||
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||
"text_config": {
|
||||
"hidden_size": 2816,
|
||||
"num_hidden_layers": 30,
|
||||
"intermediate_size": 2112,
|
||||
"num_attention_heads": 16,
|
||||
"num_key_value_heads": 8,
|
||||
"num_global_key_value_heads": 2,
|
||||
"head_dim": 256,
|
||||
"global_head_dim": 512,
|
||||
"vocab_size": 262144,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"max_position_embeddings": 131072,
|
||||
"sliding_window": 1024,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"use_double_wide_mlp": false,
|
||||
"num_kv_shared_layers": 0,
|
||||
"hidden_size_per_layer_input": null,
|
||||
"attention_k_eq_v": true,
|
||||
"enable_moe_block": true,
|
||||
"num_experts": 128,
|
||||
"top_k_experts": 8,
|
||||
"moe_intermediate_size": 704,
|
||||
"tie_word_embeddings": true,
|
||||
"layer_types": [
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
|
||||
],
|
||||
"rope_parameters": {
|
||||
"full_attention": {
|
||||
"partial_rotary_factor": 0.25,
|
||||
"rope_theta": 1000000.0,
|
||||
"rope_type": "proportional"
|
||||
},
|
||||
"sliding_attention": {
|
||||
"rope_theta": 10000.0,
|
||||
"rope_type": "default"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseTextConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseTextConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.HiddenSize != 2816 {
|
||||
t.Errorf("HiddenSize = %d, want 2816", cfg.HiddenSize)
|
||||
}
|
||||
if !cfg.AttentionKEqV {
|
||||
t.Error("AttentionKEqV should be true")
|
||||
}
|
||||
if cfg.NumGlobalKeyValueHeads != 2 {
|
||||
t.Errorf("NumGlobalKeyValueHeads = %d, want 2", cfg.NumGlobalKeyValueHeads)
|
||||
}
|
||||
if !cfg.EnableMoeBlock {
|
||||
t.Error("EnableMoeBlock should be true")
|
||||
}
|
||||
if cfg.NumExperts != 128 {
|
||||
t.Errorf("NumExperts = %d, want 128", cfg.NumExperts)
|
||||
}
|
||||
if cfg.TopKExperts != 8 {
|
||||
t.Errorf("TopKExperts = %d, want 8", cfg.TopKExperts)
|
||||
}
|
||||
if cfg.ExpertIntermediateSize != 704 {
|
||||
t.Errorf("ExpertIntermediateSize = %d, want 704", cfg.ExpertIntermediateSize)
|
||||
}
|
||||
if cfg.HiddenSizePerLayer != 0 {
|
||||
t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTextConfig31B(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
data := []byte(`{
|
||||
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||
"text_config": {
|
||||
"hidden_size": 5376,
|
||||
"num_hidden_layers": 60,
|
||||
"intermediate_size": 21504,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 16,
|
||||
"num_global_key_value_heads": 4,
|
||||
"head_dim": 256,
|
||||
"global_head_dim": 512,
|
||||
"vocab_size": 262144,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"max_position_embeddings": 131072,
|
||||
"sliding_window": 1024,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"use_double_wide_mlp": false,
|
||||
"num_kv_shared_layers": 0,
|
||||
"hidden_size_per_layer_input": null,
|
||||
"attention_k_eq_v": true,
|
||||
"tie_word_embeddings": true,
|
||||
"layer_types": [
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
|
||||
],
|
||||
"rope_parameters": {
|
||||
"full_attention": {
|
||||
"partial_rotary_factor": 0.25,
|
||||
"rope_theta": 1000000.0,
|
||||
"rope_type": "proportional"
|
||||
},
|
||||
"sliding_attention": {
|
||||
"rope_theta": 10000.0,
|
||||
"rope_type": "default"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseTextConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseTextConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.HiddenSize != 5376 {
|
||||
t.Errorf("HiddenSize = %d, want 5376", cfg.HiddenSize)
|
||||
}
|
||||
if cfg.NumHiddenLayers != 60 {
|
||||
t.Errorf("NumHiddenLayers = %d, want 60", cfg.NumHiddenLayers)
|
||||
}
|
||||
if !cfg.AttentionKEqV {
|
||||
t.Error("AttentionKEqV should be true")
|
||||
}
|
||||
if cfg.NumGlobalKeyValueHeads != 4 {
|
||||
t.Errorf("NumGlobalKeyValueHeads = %d, want 4", cfg.NumGlobalKeyValueHeads)
|
||||
}
|
||||
if cfg.NumKeyValueHeads != 16 {
|
||||
t.Errorf("NumKeyValueHeads = %d, want 16", cfg.NumKeyValueHeads)
|
||||
}
|
||||
if cfg.NumKVSharedLayers != 0 {
|
||||
t.Errorf("NumKVSharedLayers = %d, want 0", cfg.NumKVSharedLayers)
|
||||
}
|
||||
if cfg.HiddenSizePerLayer != 0 {
|
||||
t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
|
||||
}
|
||||
if cfg.SlidingWindow != 1024 {
|
||||
t.Errorf("SlidingWindow = %d, want 1024", cfg.SlidingWindow)
|
||||
}
|
||||
|
||||
// KV sharing should be empty (no shared layers).
|
||||
if len(cfg.KVShareMap) != 0 {
|
||||
t.Errorf("KVShareMap should be empty, got %d entries", len(cfg.KVShareMap))
|
||||
}
|
||||
|
||||
// Layer types: pattern is 5 sliding + 1 full, repeating 10 times.
|
||||
if !isLayerSliding(0, &cfg) {
|
||||
t.Error("layer 0 should be sliding")
|
||||
}
|
||||
if isLayerSliding(5, &cfg) {
|
||||
t.Error("layer 5 should be full attention")
|
||||
}
|
||||
if !isLayerSliding(6, &cfg) {
|
||||
t.Error("layer 6 should be sliding")
|
||||
}
|
||||
if isLayerSliding(59, &cfg) {
|
||||
t.Error("layer 59 should be full attention")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTextConfigE4B(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
data := []byte(`{
|
||||
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||
"text_config": {
|
||||
"hidden_size": 2560,
|
||||
"num_hidden_layers": 42,
|
||||
"intermediate_size": 10240,
|
||||
"num_attention_heads": 8,
|
||||
"num_key_value_heads": 2,
|
||||
"head_dim": 256,
|
||||
"global_head_dim": 512,
|
||||
"vocab_size": 262144,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"max_position_embeddings": 131072,
|
||||
"sliding_window": 512,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"use_double_wide_mlp": false,
|
||||
"num_kv_shared_layers": 18,
|
||||
"hidden_size_per_layer_input": 256,
|
||||
"vocab_size_per_layer_input": 262144,
|
||||
"attention_k_eq_v": false,
|
||||
"enable_moe_block": false,
|
||||
"tie_word_embeddings": true,
|
||||
"layer_types": [
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
|
||||
],
|
||||
"rope_parameters": {
|
||||
"full_attention": {
|
||||
"partial_rotary_factor": 0.25,
|
||||
"rope_theta": 1000000.0,
|
||||
"rope_type": "proportional"
|
||||
},
|
||||
"sliding_attention": {
|
||||
"rope_theta": 10000.0,
|
||||
"rope_type": "default"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseTextConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseTextConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.HiddenSize != 2560 {
|
||||
t.Errorf("HiddenSize = %d, want 2560", cfg.HiddenSize)
|
||||
}
|
||||
if cfg.NumHiddenLayers != 42 {
|
||||
t.Errorf("NumHiddenLayers = %d, want 42", cfg.NumHiddenLayers)
|
||||
}
|
||||
if cfg.IntermediateSize != 10240 {
|
||||
t.Errorf("IntermediateSize = %d, want 10240", cfg.IntermediateSize)
|
||||
}
|
||||
if cfg.NumKeyValueHeads != 2 {
|
||||
t.Errorf("NumKeyValueHeads = %d, want 2", cfg.NumKeyValueHeads)
|
||||
}
|
||||
if cfg.UseDoubleWideMLP {
|
||||
t.Error("UseDoubleWideMLP should be false")
|
||||
}
|
||||
if cfg.NumKVSharedLayers != 18 {
|
||||
t.Errorf("NumKVSharedLayers = %d, want 18", cfg.NumKVSharedLayers)
|
||||
}
|
||||
if cfg.HiddenSizePerLayer != 256 {
|
||||
t.Errorf("HiddenSizePerLayer = %d, want 256 (has PLE)", cfg.HiddenSizePerLayer)
|
||||
}
|
||||
if cfg.AttentionKEqV {
|
||||
t.Error("AttentionKEqV should be false")
|
||||
}
|
||||
if cfg.EnableMoeBlock {
|
||||
t.Error("EnableMoeBlock should be false")
|
||||
}
|
||||
if cfg.SlidingWindow != 512 {
|
||||
t.Errorf("SlidingWindow = %d, want 512", cfg.SlidingWindow)
|
||||
}
|
||||
|
||||
// Layer types: pattern is 5 sliding + 1 full, repeating 7 times = 42 layers.
|
||||
if !isLayerSliding(0, &cfg) {
|
||||
t.Error("layer 0 should be sliding")
|
||||
}
|
||||
if isLayerSliding(5, &cfg) {
|
||||
t.Error("layer 5 should be full attention")
|
||||
}
|
||||
if !isLayerSliding(6, &cfg) {
|
||||
t.Error("layer 6 should be sliding")
|
||||
}
|
||||
if isLayerSliding(41, &cfg) {
|
||||
t.Error("layer 41 should be full attention")
|
||||
}
|
||||
|
||||
// KV sharing: first shared = 42 - 18 = 24.
|
||||
// Layer 24 is sliding, its donor should be the last non-shared sliding layer.
|
||||
// Non-shared layers: 0-23. Last sliding in 0-23 is layer 22 (23=full).
|
||||
if donor, ok := cfg.KVShareMap[24]; !ok {
|
||||
t.Error("layer 24 should be in KVShareMap")
|
||||
} else {
|
||||
t.Logf("layer 24 donor = %d", donor)
|
||||
}
|
||||
// Layer 29 is full_attention (5th full), donor should be the last non-shared full layer.
|
||||
// Non-shared full layers: 5, 11, 17, 23.
|
||||
if donor, ok := cfg.KVShareMap[29]; !ok || donor != 23 {
|
||||
t.Errorf("KVShareMap[29] = %d, ok=%v; want 23, true (full attn donor)", donor, ok)
|
||||
}
|
||||
// Layer 23 should NOT be shared (it's the last non-shared layer).
|
||||
if _, ok := cfg.KVShareMap[23]; ok {
|
||||
t.Error("layer 23 should not be in KVShareMap (non-shared)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLayerTypeDetection(t *testing.T) {
|
||||
cfg := &TextConfig{
|
||||
LayerTypes: []string{
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||
},
|
||||
}
|
||||
|
||||
if !isLayerSliding(0, cfg) {
|
||||
t.Error("layer 0 should be sliding")
|
||||
}
|
||||
if !isLayerSliding(3, cfg) {
|
||||
t.Error("layer 3 should be sliding")
|
||||
}
|
||||
if isLayerSliding(4, cfg) {
|
||||
t.Error("layer 4 should be full attention")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCachesOmitsSharedKVLayers(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []*DecoderLayer{
|
||||
{IsSliding: true, KVShareDonor: -1},
|
||||
{IsSliding: false, KVShareDonor: -1},
|
||||
{IsSliding: true, KVShareDonor: 0},
|
||||
{IsSliding: false, KVShareDonor: 1},
|
||||
},
|
||||
TextConfig: &TextConfig{SlidingWindow: 512},
|
||||
}
|
||||
|
||||
caches := m.NewCaches()
|
||||
if got, want := len(caches), 2; got != want {
|
||||
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCachesIncludesAllNonSharedLayers(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []*DecoderLayer{
|
||||
{IsSliding: true, KVShareDonor: -1},
|
||||
{IsSliding: false, KVShareDonor: -1},
|
||||
{IsSliding: true, KVShareDonor: -1},
|
||||
},
|
||||
TextConfig: &TextConfig{SlidingWindow: 512},
|
||||
}
|
||||
|
||||
caches := m.NewCaches()
|
||||
if got, want := len(caches), len(m.Layers); got != want {
|
||||
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveWeightPrefix(t *testing.T) {
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
wantPfx string
|
||||
}{
|
||||
{"bare", "embed_tokens.weight", ""},
|
||||
{"language_model", "model.language_model.embed_tokens.weight", "model.language_model."},
|
||||
{"with_model", "model.embed_tokens.weight", "model."},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dummy := mlx.FromValue(float32(1.0))
|
||||
mlx.Eval(dummy)
|
||||
tensors := map[string]*mlx.Array{tt.key: dummy}
|
||||
got := resolveWeightPrefix(tensors)
|
||||
if got != tt.wantPfx {
|
||||
t.Errorf("resolveWeightPrefix(%q) = %q, want %q", tt.key, got, tt.wantPfx)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -148,9 +148,7 @@ type DenseMLP struct {
|
||||
|
||||
// Forward applies the SwiGLU MLP
|
||||
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
||||
up := m.UpProj.Forward(x)
|
||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
||||
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// MoEGate implements the expert gating mechanism
|
||||
@@ -242,7 +240,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
|
||||
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
||||
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
||||
|
||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
||||
hidden = mlx.SwiGLU(gate, up)
|
||||
|
||||
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
||||
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
||||
@@ -250,7 +248,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
|
||||
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
|
||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
||||
hidden = mlx.SwiGLU(gate, up)
|
||||
|
||||
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
}
|
||||
@@ -273,9 +271,7 @@ type SharedExperts struct {
|
||||
|
||||
// Forward applies the shared expert MLP
|
||||
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(s.GateProj.Forward(x))
|
||||
up := s.UpProj.Forward(x)
|
||||
return s.DownProj.Forward(mlx.Mul(gate, up))
|
||||
return s.DownProj.Forward(mlx.SwiGLU(s.GateProj.Forward(x), s.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// MoE implements the full Mixture of Experts layer
|
||||
|
||||
@@ -314,5 +314,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
@@ -333,5 +333,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
@@ -1253,7 +1253,7 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
||||
}
|
||||
|
||||
func (m *DenseMLP) Forward(x *mlx.Array, _ *Config) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
|
||||
@@ -1283,13 +1283,13 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
|
||||
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
|
||||
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
||||
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
||||
hidden = mlx.SwiGLU(gate, up)
|
||||
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
||||
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
||||
} else {
|
||||
gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort)
|
||||
up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort)
|
||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
||||
hidden = mlx.SwiGLU(gate, up)
|
||||
down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user