Merge pull request #42 from ollama/jmorganca/gemma4-ggml-improvements

gemma4: fix MoE fused gate_up split and multiline tool-call arg parsing
This commit is contained in:
Daniel Hiltgen
2026-04-02 07:16:06 -07:00
committed by GitHub
12 changed files with 334 additions and 831 deletions

View File

@@ -568,395 +568,6 @@ func hasListedModelName(models []api.ListModelResponse, name string) bool {
return false
}
// getMaxAudioSeconds extracts the max audio duration from model info metadata.
// Returns 0 if the model doesn't report audio limits.
func getMaxAudioSeconds(info *api.ShowResponse) int {
if info == nil || info.ModelInfo == nil {
return 0
}
// Look for {arch}.max_audio_seconds in ModelInfo.
for k, v := range info.ModelInfo {
if strings.HasSuffix(k, ".max_audio_seconds") {
switch val := v.(type) {
case float64:
return int(val)
case int:
return val
}
}
}
return 0
}
// ANSI escape helpers for transcription display.
const (
)
// TranscribeHandler implements `ollama transcribe MODEL`.
//
// Two modes:
// - Interactive (tty on stdin): spacebar start/stop with >>> prompt,
// slash commands (/set, /show, /load, /bye, /?), word-wrapped output.
// - Non-interactive (pipe/redirect): reads audio from stdin or records
// until Ctrl+C, transcribes, writes word-wrapped text to stdout.
func TranscribeHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
modelName := args[0]
interactive := term.IsTerminal(int(os.Stdin.Fd()))
// Pull model if needed and get model info.
showReq := &api.ShowRequest{Name: modelName}
info, err := client.Show(cmd.Context(), showReq)
if err != nil {
var se api.StatusError
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
if err := PullHandler(cmd, []string{modelName}); err != nil {
return err
}
info, err = client.Show(cmd.Context(), showReq)
if err != nil {
return err
}
} else {
return err
}
}
language, _ := cmd.Flags().GetString("language")
opts := runOptions{
Model: modelName,
WordWrap: true,
Options: map[string]any{"temperature": 0},
Language: language,
}
transcribeAndDisplay := func(wav []byte) {
state := &displayResponseState{}
_, err := transcribeAudio(cmd, opts, wav, func(tok string) {
displayResponse(tok, opts.WordWrap, state)
})
if err != nil {
fmt.Fprintln(os.Stderr, "Transcription error:", err)
}
fmt.Println()
}
// --- Non-interactive mode ---
if !interactive {
audioData, err := io.ReadAll(os.Stdin)
if err != nil {
return fmt.Errorf("read stdin: %w", err)
}
if len(audioData) > 44 {
// Pipe with data (at least WAV header size): transcribe and output.
transcribeAndDisplay(audioData)
return nil
}
// Empty stdin (< /dev/null or echo "" |): record until Ctrl+C.
recorder, err := NewAudioRecorder()
if err != nil {
return fmt.Errorf("audio input unavailable: %w", err)
}
if maxSec := getMaxAudioSeconds(info); maxSec > 0 {
recorder.MaxChunkSeconds = maxSec - 2
}
if err := recorder.Start(); err != nil {
return fmt.Errorf("start recording: %w", err)
}
fmt.Fprintln(os.Stderr, "Recording... Press Ctrl+C to stop.")
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt)
<-sigCh
signal.Stop(sigCh)
recorder.Stop()
fmt.Fprintln(os.Stderr)
if wav := recorder.FlushWAV(); wav != nil {
transcribeAndDisplay(wav)
}
return nil
}
// --- Interactive mode ---
recorder, err := NewAudioRecorder()
if err != nil {
return fmt.Errorf("audio input unavailable: %w", err)
}
if maxSec := getMaxAudioSeconds(info); maxSec > 0 {
recorder.MaxChunkSeconds = maxSec - 2
}
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
Placeholder: "Press Space to record (/? for help)",
})
if err != nil {
return err
}
fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste)
usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load <model> Load a different model")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Press Space to start/stop recording.")
fmt.Fprintln(os.Stderr, "")
}
usageSet := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, "")
}
// doTranscribeRecording is like doAudioRecording but polls TakeChunk()
// during recording to stream transcription of long recordings.
doTranscribeRecording := func() ([]byte, error) {
fmt.Print(">>> \033[90m◉ Press Space to record...\033[0m")
for {
r, err := scanner.ReadRaw()
if err != nil {
return nil, io.EOF
}
if r == 3 { // Ctrl+C
fmt.Print("\r\033[K")
fmt.Println("Use Ctrl + d or /bye to exit.")
return nil, nil
}
if r == 4 { // Ctrl+D
fmt.Println()
return nil, io.EOF
}
if r == ' ' {
fmt.Print("\r\033[K") // clear the prompt line
break
}
if r == '/' || (r >= 32 && r < 127) {
fmt.Print("\r\033[K")
return nil, errFallbackToText{prefill: string(r)}
}
}
if err := recorder.Start(); err != nil {
fmt.Println()
return nil, fmt.Errorf("start recording: %w", err)
}
// Poll for chunks in a background goroutine while recording.
chunkDone := make(chan struct{})
go func() {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-chunkDone:
return
case <-ticker.C:
if wav := recorder.TakeChunk(); wav != nil {
transcribeAndDisplay(wav)
}
}
}
}()
// Wait for Space to stop, Ctrl+C to discard, Ctrl+D to exit.
for {
r, err := scanner.ReadRaw()
if err != nil {
close(chunkDone)
recorder.Stop()
return nil, io.EOF
}
if r == 4 { // Ctrl+D
close(chunkDone)
recorder.Stop()
fmt.Println()
return nil, io.EOF
}
if r == 3 { // Ctrl+C
close(chunkDone)
recorder.Stop()
return nil, nil
}
if r == ' ' { // Space: stop recording
close(chunkDone)
recorder.Stop()
return recorder.FlushWAV(), nil
}
// Ignore other keys while recording.
}
}
for {
wav, err := doTranscribeRecording()
if err != nil {
var fallback errFallbackToText
if errors.As(err, &fallback) {
// User typed text instead of pressing Space.
line := fallback.prefill
if line == "/" {
// Need the rest of the command — read via readline.
scanner.Prefill = "/"
fullLine, err := scanner.Readline()
if errors.Is(err, io.EOF) {
fmt.Println()
return nil
}
if err != nil {
return err
}
line = fullLine
}
line = strings.TrimSpace(line)
switch {
case line == "/?" || line == "/help":
usage()
case strings.HasPrefix(line, "/? "):
arg := strings.TrimSpace(line[3:])
switch arg {
case "set":
usageSet()
default:
usage()
}
case strings.HasPrefix(line, "/set"):
args := strings.Fields(line)
if len(args) == 1 {
usageSet()
continue
}
switch args[1] {
case "wordwrap":
opts.WordWrap = true
fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap":
opts.WordWrap = false
fmt.Println("Set 'nowordwrap' mode.")
case "verbose":
cmd.Flags().Set("verbose", "true")
fmt.Println("Set 'verbose' mode.")
case "quiet":
cmd.Flags().Set("verbose", "false")
fmt.Println("Set 'quiet' mode.")
case "parameter":
if len(args) < 4 {
fmt.Println("Usage: /set parameter <name> <value>")
continue
}
opts.Options[args[2]] = args[3]
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], args[3])
default:
fmt.Printf("Unknown option: %s\n", args[1])
usageSet()
}
case strings.HasPrefix(line, "/show"):
args := strings.Fields(line)
if len(args) == 1 {
args = append(args, "info")
}
showReq := &api.ShowRequest{Name: opts.Model}
resp, err := client.Show(cmd.Context(), showReq)
if err != nil {
fmt.Println("Error:", err)
continue
}
switch args[1] {
case "info":
if err := showInfo(resp, false, os.Stdout); err != nil {
fmt.Println("Error:", err)
}
case "license":
fmt.Println(resp.License)
case "parameters":
fmt.Println(resp.Parameters)
case "system":
fmt.Println(resp.System)
default:
fmt.Printf("Unknown show command: %s\n", args[1])
}
case strings.HasPrefix(line, "/load"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage: /load <modelname>")
continue
}
newModel := args[1]
fmt.Printf("Loading model '%s'\n", newModel)
showReq := &api.ShowRequest{Name: newModel}
newInfo, err := client.Show(cmd.Context(), showReq)
if err != nil {
var se api.StatusError
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
fmt.Printf("error: model '%s' not found\n", newModel)
} else {
fmt.Println("Error:", err)
}
continue
}
// Verify audio capability.
hasAudio := false
for _, cap := range newInfo.Capabilities {
if cap == "audio" {
hasAudio = true
break
}
}
if !hasAudio {
fmt.Printf("error: model '%s' does not support audio input\n", newModel)
continue
}
opts.Model = newModel
if maxSec := getMaxAudioSeconds(newInfo); maxSec > 0 {
recorder.MaxChunkSeconds = maxSec - 2
}
case line == "/exit" || line == "/bye":
fmt.Println()
return nil
case line != "":
fmt.Printf("Unknown command: %s (type /? for help)\n", line)
}
continue
}
if errors.Is(err, io.EOF) {
fmt.Println()
return nil
}
fmt.Fprintf(os.Stderr, "Recording error: %v\n", err)
continue
}
if wav == nil {
// Ctrl+C during recording — discard and retry.
continue
}
// Transcribe the recording.
transcribeAndDisplay(wav)
}
}
func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true
@@ -1084,7 +695,8 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}
-opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
+audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio)
+opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable
// TODO: remove the projector info and vision info checks below,
// these are left in for backwards compatibility with older servers
@@ -1101,19 +713,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts.ParentModel = info.Details.ParentModel
-opts.AudioCapable = slices.Contains(info.Capabilities, model.CapabilityAudio)
-audioin, _ := cmd.Flags().GetBool("audioin")
-if audioin {
-if !opts.AudioCapable {
-fmt.Fprintf(os.Stderr, "Warning: audio input disabled — %s does not support audio\n", opts.Model)
-} else {
-opts.AudioInput = true
-opts.MultiModal = true // audio uses the multimodal pipeline
-opts.MaxAudioSeconds = getMaxAudioSeconds(info)
-}
-}
// Check if this is an embedding model
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
@@ -1837,12 +1436,8 @@ type runOptions struct {
System string
Images []api.ImageData
Options map[string]any
-MultiModal bool
-AudioInput bool
-AudioCapable bool // model supports audio input
-MaxAudioSeconds int // from model metadata; 0 = use default
-Language string // language hint for transcription
-KeepAlive *api.Duration
+MultiModal bool
+KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
ShowConnect bool
@@ -2568,7 +2163,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
runCmd.Flags().Bool("audioin", false, "Enable audio input via microphone (press Space to record)")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
@@ -2576,16 +2170,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("imagegen", false, "Use the imagegen runner for LLM inference")
runCmd.Flags().MarkHidden("imagegen")
-transcribeCmd := &cobra.Command{
-Use: "transcribe MODEL",
-Short: "Transcribe audio to text using microphone",
-Long: "Record audio via microphone and transcribe to text.\nPress Space to start/stop recording. Ctrl+D to exit.",
-Args: cobra.ExactArgs(1),
-PreRunE: checkServerHeartbeat,
-RunE: TranscribeHandler,
-}
-transcribeCmd.Flags().String("language", "", "Language hint (e.g. en, es, fr)")
stopCmd := &cobra.Command{
Use: "stop MODEL",
Short: "Stop a running model",
@@ -2706,7 +2290,6 @@ func NewCLI() *cobra.Command {
createCmd,
showCmd,
runCmd,
-transcribeCmd,
stopCmd,
pullCmd,
pushCmd,
@@ -2750,7 +2333,6 @@ func NewCLI() *cobra.Command {
createCmd,
showCmd,
runCmd,
-transcribeCmd,
stopCmd,
pullCmd,
pushCmd,

View File

@@ -12,7 +12,6 @@ import (
"regexp"
"slices"
"strings"
"time"
"github.com/spf13/cobra"
@@ -24,143 +23,6 @@ import (
"github.com/ollama/ollama/types/model"
)
// errFallbackToText is returned when the user types a non-space key in audio mode,
// indicating we should fall through to the normal text input.
type errFallbackToText struct {
prefill string
}
func (e errFallbackToText) Error() string { return "fallback to text" }
// doAudioRecording handles the spacebar-driven recording flow.
// Returns WAV bytes on success, nil to retry, or an error.
func doAudioRecording(scanner *readline.Instance, recorder *AudioRecorder) ([]byte, error) {
fmt.Print(">>> \033[90m◉ Press Space to record...\033[0m")
// Wait for spacebar to start.
for {
r, err := scanner.ReadRaw()
if err != nil {
return nil, io.EOF
}
if r == 3 { // Ctrl+C
fmt.Print("\r\033[K")
fmt.Println("Use Ctrl + d or /bye to exit.")
return nil, nil
}
if r == 4 { // Ctrl+D
fmt.Println()
return nil, io.EOF
}
if r == ' ' {
break
}
// User typed a regular character — fall back to text input with this char.
if r == '/' || (r >= 32 && r < 127) {
fmt.Print("\r\033[K") // clear the "Press Space" line
return nil, errFallbackToText{prefill: string(r)}
}
}
// Start recording.
if err := recorder.Start(); err != nil {
fmt.Println()
return nil, fmt.Errorf("start recording: %w", err)
}
// Show recording indicator with elapsed time.
done := make(chan struct{})
go func() {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-done:
return
case <-ticker.C:
d := recorder.Duration()
fmt.Printf("\r>>> \033[91m◈ Recording... %.1fs\033[0m ", d.Seconds())
}
}
}()
// Wait for spacebar to stop.
for {
r, err := scanner.ReadRaw()
if err != nil {
close(done)
recorder.Stop()
return nil, io.EOF
}
if r == ' ' || r == 3 { // Space or Ctrl+C
break
}
}
close(done)
dur, _ := recorder.Stop()
fmt.Printf("\r>>> \033[90m◇ Recorded %.1fs\033[0m \n", dur.Seconds())
// Encode to WAV.
wav, err := recorder.WAV()
if err != nil {
return nil, err
}
return wav, nil
}
// tokenCallback is called for each streamed token. Return non-nil error to abort.
type tokenCallback func(token string)
// streamChat sends a chat request and streams tokens to the callback.
// Returns the full accumulated text.
func streamChat(cmd *cobra.Command, model string, messages []api.Message, onToken tokenCallback) (string, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return "", err
}
noThink := &api.ThinkValue{Value: false}
stream := true
req := &api.ChatRequest{
Model: model,
Messages: messages,
Stream: &stream,
Think: noThink,
Options: map[string]any{"temperature": 0},
}
var result strings.Builder
fn := func(response api.ChatResponse) error {
tok := response.Message.Content
result.WriteString(tok)
if onToken != nil {
onToken(tok)
}
return nil
}
if err := client.Chat(cmd.Context(), req, fn); err != nil {
return "", err
}
return strings.TrimSpace(result.String()), nil
}
// transcribeAudio sends audio to the model for transcription.
// onToken is called for each streamed token (may be nil for silent operation).
func transcribeAudio(cmd *cobra.Command, opts runOptions, audioData []byte, onToken tokenCallback) (string, error) {
systemPrompt := "Transcribe the following audio exactly as spoken. Output only the transcription text, nothing else."
if opts.Language != "" {
systemPrompt += " The audio is in " + opts.Language + "."
}
return streamChat(cmd, opts.Model, []api.Message{
{Role: "system", Content: systemPrompt},
{Role: "user", Content: "Transcribe this audio.", Images: []api.ImageData{audioData}},
}, onToken)
}
type MultilineState int
const (
@@ -177,11 +39,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
fmt.Fprintln(os.Stderr, " /clear Clear session context")
-if opts.AudioCapable {
-fmt.Fprintln(os.Stderr, " /audio Toggle voice input mode")
-} else {
-fmt.Fprintln(os.Stderr, " /audio (not supported by current model)")
-}
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
@@ -190,7 +47,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
if opts.MultiModal {
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, .webp images, or .wav audio files.\n", filepath.FromSlash("/path/to/file"))
}
fmt.Fprintln(os.Stderr, "")
@@ -279,66 +136,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
var multiline MultilineState
var thinkExplicitlySet bool = opts.Think != nil
audioMode := opts.AudioInput
var recorder *AudioRecorder
if audioMode {
var err error
recorder, err = NewAudioRecorder()
if err != nil {
fmt.Fprintf(os.Stderr, "Audio input unavailable: %v\n", err)
audioMode = false
} else {
if opts.MaxAudioSeconds > 0 {
recorder.MaxChunkSeconds = opts.MaxAudioSeconds - 2 // 2s headroom
}
fmt.Fprintln(os.Stderr, "Voice input enabled. Press Space to record, Space again to send.")
}
}
for {
// Audio recording mode: wait for spacebar instead of text input.
if audioMode && recorder != nil {
audioData, err := doAudioRecording(scanner, recorder)
if err != nil {
if err == io.EOF {
fmt.Println()
return nil
}
// User typed a regular key — fall through to normal readline.
if fb, ok := err.(errFallbackToText); ok {
scanner.Prefill = fb.prefill
goto textInput
}
fmt.Fprintf(os.Stderr, "Audio error: %v\n", err)
continue
}
if audioData == nil {
continue
}
// Send audio as the user's input — the model hears and responds.
newMessage := api.Message{
Role: "user",
Images: []api.ImageData{audioData},
}
opts.Messages = append(opts.Messages, newMessage)
assistant, err := chat(cmd, opts)
if err != nil {
if strings.Contains(err.Error(), "does not support thinking") ||
strings.Contains(err.Error(), "invalid think value") {
fmt.Printf("error: %v\n", err)
continue
}
return err
}
if assistant != nil {
opts.Messages = append(opts.Messages, *assistant)
}
continue
}
textInput:
line, err := scanner.Readline()
switch {
case errors.Is(err, io.EOF):
@@ -676,29 +474,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} else {
usage()
}
case line == "/audio":
if !opts.AudioCapable {
fmt.Fprintf(os.Stderr, "Audio input not supported by %s\n", opts.Model)
continue
}
if audioMode {
audioMode = false
fmt.Fprintln(os.Stderr, "Voice input disabled.")
} else {
audioMode = true
if recorder == nil {
var recErr error
recorder, recErr = NewAudioRecorder()
if recErr != nil {
fmt.Fprintf(os.Stderr, "Audio input unavailable: %v\n", recErr)
audioMode = false
continue
}
}
opts.MultiModal = true
fmt.Fprintln(os.Stderr, "Voice input enabled. Press Space to record, Space again to send.")
}
continue
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/"):

View File

@@ -84,3 +84,33 @@ func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
assert.Len(t, imgs, 1)
assert.Equal(t, cleaned, "before after")
}
func TestExtractFileDataWAV(t *testing.T) {
dir := t.TempDir()
fp := filepath.Join(dir, "sample.wav")
data := make([]byte, 600)
copy(data[:44], []byte{
'R', 'I', 'F', 'F',
0x50, 0x02, 0x00, 0x00, // file size - 8 (600 - 8 = 592 = 0x0250)
'W', 'A', 'V', 'E',
'f', 'm', 't', ' ',
0x10, 0x00, 0x00, 0x00, // fmt chunk size
0x01, 0x00, // PCM
0x01, 0x00, // mono
0x80, 0x3e, 0x00, 0x00, // 16000 Hz
0x00, 0x7d, 0x00, 0x00, // byte rate
0x02, 0x00, // block align
0x10, 0x00, // 16-bit
'd', 'a', 't', 'a',
0x2c, 0x02, 0x00, 0x00, // data size (600 - 44 = 556 = 0x022C)
})
if err := os.WriteFile(fp, data, 0o600); err != nil {
t.Fatalf("failed to write test audio: %v", err)
}
input := "before " + fp + " after"
cleaned, imgs, err := extractFileData(input)
assert.NoError(t, err)
assert.Len(t, imgs, 1)
assert.Equal(t, "before after", cleaned)
}
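For reference, both size fields in the header above are simple functions of the total file length for a PCM WAV with the canonical 44-byte header. A sketch (hypothetical helper, not part of the test):

// wavSizes returns the RIFF chunk size and data chunk size fields
// for a WAV file of totalLen bytes with the canonical 44-byte header.
func wavSizes(totalLen int) (riffSize, dataSize uint32) {
	return uint32(totalLen - 8), uint32(totalLen - 44)
}

For the 600-byte fixture above: riffSize = 592 (0x0250) and dataSize = 556 (0x022C).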

View File

@@ -9,8 +9,6 @@ import (
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
type gemma4Model struct {
@@ -237,15 +235,6 @@ func (p *gemma4Model) Tensors(ts []Tensor) []*ggml.Tensor {
continue
}
-// Skip vision clamp scalars — packed into v.clamp_data below.
-// Audio clamp scalars are kept as individual tensors (matching published GGUF).
-isVisionClamp := (strings.Contains(name, "input_min") || strings.Contains(name, "input_max") ||
-strings.Contains(name, "output_min") || strings.Contains(name, "output_max")) &&
-(strings.Contains(name, "vision_tower") || strings.Contains(name, "embed_vision"))
-if isVisionClamp {
-continue
-}
// Vision tensor renaming: match published mmproj GGUF names
if strings.HasPrefix(name, "v.blk.") {
name = strings.Replace(name, ".attn_norm.", ".ln1.", 1)
@@ -282,26 +271,6 @@ func (p *gemma4Model) Tensors(ts []Tensor) []*ggml.Tensor {
shape = []uint64{shape[0], shape[2]}
}
// Fused MoE gate_up_proj: split [experts, 2*intermediate, hidden] into separate gate and up.
// No transpose needed — the split shape [experts, intermediate, hidden] already matches
// the GGUF layout after the framework's dimension reversal (ne[0]=hidden matches input).
if strings.Contains(name, "ffn_gate_exps.weight") && len(shape) == 3 {
halfDim := int(shape[1]) / 2
newShape := slices.Clone(shape)
newShape[1] = newShape[1] / 2
for i, ggufName := range []string{"ffn_gate_exps.weight", "ffn_up_exps.weight"} {
tt := t.Clone()
tt.SetRepacker(p.sliceExperts(tensor.S(i*halfDim, (i+1)*halfDim)))
out = append(out, &ggml.Tensor{
Name: strings.ReplaceAll(name, "ffn_gate_exps.weight", ggufName),
Kind: tt.Kind(),
Shape: slices.Clone(newShape),
WriterTo: tt,
})
}
continue
}
// MoE expert weights: no transpose needed. Safetensors stores [experts, out, in]
// which the framework reverses to GGUF ne=[in, out, experts], matching ggml_mul_mat_id.
// (transposeExperts was incorrectly swapping dims — removed)
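Spelling out the dimension reversal the comment references (a worked shape example, not code from this diff):

// Safetensors (row-major, outermost dim first): [experts, out, in]
// GGUF/ggml ne (innermost dim first):           ne = [in, out, experts]
// ggml_mul_mat_id contracts along ne[0] = in, so the stored layout is
// already correct and no transpose is needed.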
@@ -454,30 +423,6 @@ func (*gemma4Model) reshapePatchEmbed(_ string, data []float32, shape []uint64)
return result, nil
}
// sliceExperts returns a repacker that slices dim 1 of a 3D expert tensor.
// Used for splitting fused gate_up_proj into separate gate and up tensors.
func (*gemma4Model) sliceExperts(dim1Slice tensor.Slice) Repacker {
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i, d := range shape {
dims[i] = int(d)
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err := t.Slice(nil, dim1Slice)
if err != nil {
return nil, err
}
t = tensor.Materialize(t)
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
}
}
// softplusRepacker applies softplus (ln(1 + exp(x))) to tensor data.
// Used for per_dim_scale tensors which the published GGUF stores pre-activated.
func softplusRepacker(_ string, data []float32, shape []uint64) ([]float32, error) {
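The hunk cuts off before the body; for orientation, a numerically stable softplus over float64 values can be sketched as follows (illustrative only, assuming the standard math package):

// softplus computes ln(1 + exp(x)). For large x, exp(x) overflows,
// but ln(1+exp(x)) ≈ x there, so fall back to the identity branch.
func softplus(x float64) float64 {
	if x > 30 {
		return x
	}
	return math.Log1p(math.Exp(x))
}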
@@ -501,34 +446,35 @@ func (p *gemma4Model) Replacements() []string {
".linear.bias", ".bias",
// Audio SSCP (Sub-Sample Convolution Projection)
"model.audio_tower.subsample_conv_projection.layer0.conv", "a.conv1d.0",
"model.audio_tower.subsample_conv_projection.layer0.norm", "a.conv1d.0.norm",
"model.audio_tower.subsample_conv_projection.layer1.conv", "a.conv1d.1",
"model.audio_tower.subsample_conv_projection.layer1.norm", "a.conv1d.1.norm",
"model.audio_tower.subsample_conv_projection.conv_0.conv", "a.conv1d.0",
"model.audio_tower.subsample_conv_projection.conv_0.norm", "a.conv1d.0.norm",
"model.audio_tower.subsample_conv_projection.conv_1.conv", "a.conv1d.1",
"model.audio_tower.subsample_conv_projection.conv_1.norm", "a.conv1d.1.norm",
"model.audio_tower.subsample_conv_projection.input_proj_linear", "a.pre_encode.out",
// Audio conformer blocks
"model.audio_tower.layers", "a.blk",
"model.audio_tower.conformer", "a.blk",
// Audio conformer attention
"self_attn.relative_k_proj", "linear_pos",
"self_attn.per_dim_scale", "per_dim_scale",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"norm_post_attn", "ln2",
"norm_pre_attn", "ln1",
"self_attn.post", "attn_out",
"attention.attn.relative_position_embedding.pos_proj", "linear_pos",
"attention.attn.per_dim_key_scale", "per_dim_k_scale",
"attention.attn.per_dim_scale", "per_dim_scale",
"attention.attn.q_proj", "attn_q",
"attention.attn.k_proj", "attn_k",
"attention.attn.v_proj", "attn_v",
"attention.pre_attn_norm", "ln1",
"attention.post_norm", "ln2",
"attention.post", "attn_out",
// Audio conformer feedforward
"feed_forward1.pre_layer_norm", "ffn_norm",
"feed_forward1.post_layer_norm", "ffn_post_norm",
"feed_forward1.ffw_layer_1", "ffn_up",
"feed_forward1.ffw_layer_2", "ffn_down",
"feed_forward2.pre_layer_norm", "ffn_norm_1",
"feed_forward2.post_layer_norm", "ffn_post_norm_1",
"feed_forward2.ffw_layer_1", "ffn_up_1",
"feed_forward2.ffw_layer_2", "ffn_down_1",
"ffw_layer_start.pre_layer_norm", "ffn_norm",
"ffw_layer_start.post_layer_norm", "ffn_post_norm",
"ffw_layer_start.ffw_layer_1", "ffn_up",
"ffw_layer_start.ffw_layer_2", "ffn_down",
"ffw_layer_end.pre_layer_norm", "ffn_norm_1",
"ffw_layer_end.post_layer_norm", "ffn_post_norm_1",
"ffw_layer_end.ffw_layer_1", "ffn_up_1",
"ffw_layer_end.ffw_layer_2", "ffn_down_1",
// Audio conformer lightweight conv1d
"lconv1d.depthwise_conv1d", "conv_dw",
@@ -548,6 +494,8 @@ func (p *gemma4Model) Replacements() []string {
"model.vision_tower.encoder.layers", "v.blk",
"model.vision_tower.patch_embedder.input_proj", "v.patch_embd",
"model.vision_tower.patch_embedder.position_embedding_table", "v.position_embd.weight",
"model.vision_tower.std_bias", "v.std_bias",
"model.vision_tower.std_scale", "v.std_scale",
// Vision multimodal projector
"model.embed_vision.embedding_projection", "mm.input_projection",
@@ -588,9 +536,19 @@ func (p *gemma4Model) Replacements() []string {
// MoE
"router.proj", "ffn_gate_inp",
"router.scale", "ffn_gate_inp.scale",
"router.per_expert_scale", "ffn_gate_inp.per_expert_scale",
"experts.gate_up_proj", "ffn_gate_exps.weight",
"router.per_expert_scale.weight", "ffn_down_exps.scale",
"router.per_expert_scale", "ffn_down_exps.scale",
"experts.gate_up_proj.weight", "ffn_gate_up_exps.weight",
"experts.gate_up_proj", "ffn_gate_up_exps.weight",
"experts.down_proj.weight", "ffn_down_exps.weight",
"experts.down_proj", "ffn_down_exps.weight",
"moe.gate_proj", "ffn_gate_exps.weight",
"moe.up_proj", "ffn_up_exps.weight",
"moe.gate_up_proj.weight", "ffn_gate_up_exps.weight",
"moe.gate_up_proj", "ffn_gate_up_exps.weight",
"moe.down_proj", "ffn_down_exps.weight",
"moe.per_expert_scale.weight", "ffn_down_exps.scale",
"moe.per_expert_scale", "ffn_down_exps.scale",
// Layer scalar
"layer_scalar", "layer_output_scale.weight",

View File

@@ -194,6 +194,16 @@ func TestGemma4AudioReplacements(t *testing.T) {
"model.vision_tower.encoder.layers.0.self_attn.q_proj.linear.weight",
"v.blk.0.attn_q.weight",
},
{
"vision std bias",
"model.vision_tower.std_bias",
"v.std_bias",
},
{
"vision std scale",
"model.vision_tower.std_scale",
"v.std_scale",
},
{
"vision patch embd",
"model.vision_tower.patch_embedder.input_proj.weight",
@@ -216,6 +226,31 @@ func TestGemma4AudioReplacements(t *testing.T) {
"model.language_model.embed_tokens.weight",
"token_embd.weight",
},
{
"text moe gate up fused",
"model.language_model.layers.0.experts.gate_up_proj",
"blk.0.ffn_gate_up_exps.weight",
},
{
"text moe down",
"model.language_model.layers.0.experts.down_proj",
"blk.0.ffn_down_exps.weight",
},
{
"text moe down with weight suffix",
"model.language_model.layers.0.experts.down_proj.weight",
"blk.0.ffn_down_exps.weight",
},
{
"text moe per expert scale",
"model.language_model.layers.0.router.per_expert_scale",
"blk.0.ffn_down_exps.scale",
},
{
"text moe per expert scale with weight suffix",
"model.language_model.layers.0.router.per_expert_scale.weight",
"blk.0.ffn_down_exps.scale",
},
}
for _, tt := range tests {

View File

@@ -47,6 +47,12 @@ type Validator interface {
Validate() error
}
// PostLoader is an optional interface that models can implement to run
// initialization steps after backend weights have been loaded.
type PostLoader interface {
PostLoad() error
}
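Concretely, a model opts in by implementing the method, and the runner type-asserts for it after load (see the runner change at the end of this diff). A minimal sketch of an implementer, with hypothetical names:

// exampleModel needs tensor data that is only readable once weights
// are resident on the backend, so it defers that work to PostLoad.
type exampleModel struct{ /* ... */ }

func (m *exampleModel) PostLoad() error {
	// safe to read loaded tensors here
	return nil
}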
// MultimodalProcessor must be implemented by multimodal models.
type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and

View File

@@ -26,7 +26,7 @@ type Model struct {
*TextModel
*AudioModel `gguf:"a"`
-*MultiModalProjector `gguf:"mm"`
+*MultiModalProjector `gguf:"mm"`
+*AudioMultimodalProjector `gguf:"mm.a"`
ImageProcessor
@@ -97,22 +97,25 @@ func New(c fs.Config) (model.Model, error) {
slog.Info("gemma4: token IDs", "image", imageTokenID, "image_end", imageEndTokenID, "audio", audioTokenID, "audio_end", audioEndTokenID)
m := Model{
-Tokenizer: t,
-TextModel: newTextModel(c),
-VisionModel: newVisionModel(c),
-AudioModel: newAudioModel(c),
-MultiModalProjector: &MultiModalProjector{},
+Tokenizer: t,
+TextModel: newTextModel(c),
+VisionModel: newVisionModel(c),
+AudioModel: newAudioModel(c),
+MultiModalProjector: &MultiModalProjector{},
+AudioMultimodalProjector: &AudioMultimodalProjector{},
-ImageProcessor: newImageProcessor(c),
-imageTokenID: imageTokenID,
-imageEndTokenID: imageEndTokenID,
-audioTokenID: audioTokenID,
-audioEndTokenID: audioEndTokenID,
-audioOpts: newAudioModelOptions(c),
+ImageProcessor: newImageProcessor(c),
+imageTokenID: imageTokenID,
+imageEndTokenID: imageEndTokenID,
+audioTokenID: audioTokenID,
+audioEndTokenID: audioEndTokenID,
+audioOpts: newAudioModelOptions(c),
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
-m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
+m.Cache = kvcache.NewWrapperCache(
+kvcache.NewSWAMemCache(slidingWindowLen, 4096, m.Shift),
+kvcache.NewCausalCache(m.Shift),
+)
return &m, nil
}
@@ -127,9 +130,6 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return nil, model.ErrNoVisionModel
}
-// Initialize clamp values from model tensors (lazy, once, after model is fully loaded)
-m.VisionModel.InitClamp(m.MultiModalProjector)
t0 := time.Now()
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
@@ -152,12 +152,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
slog.Info("vision: patches", "patchesX", numPatchesX, "patchesY", numPatchesY, "total", numPatchesX*numPatchesY, "patchSize", m.ImageProcessor.patchSize)
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, numPatchesX, numPatchesY)
-visionOutputs = visionPoolAndProject(ctx, visionOutputs, numPatchesX, numPatchesY, m.VisionModel.VisionModelOptions, m.MultiModalProjector)
+visionOutputs = visionPoolAndProject(ctx, visionOutputs, numPatchesX, numPatchesY, m.VisionModel.VisionModelOptions, m.MultiModalProjector, m.VisionModel.StdBias, m.VisionModel.StdScale)
slog.Info("vision: encoded", "elapsed", time.Since(t0), "shape", visionOutputs.Shape())
return []input.Multimodal{{Tensor: visionOutputs}}, nil
}
func (m *Model) PostLoad() error {
m.VisionModel.InitClamp(m.MultiModalProjector)
return nil
}
func (m *Model) encodeAudioMultimodal(ctx ml.Context, data []byte) ([]input.Multimodal, error) {
if m.AudioModel == nil || m.audioOpts == nil {
return nil, model.ErrNoVisionModel

View File

@@ -318,9 +318,8 @@ func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
// TextRouter implements the Gemma 4 MoE router.
type TextRouter struct {
-Proj *nn.Linear `gguf:"ffn_gate_inp"`
-Scale ml.Tensor `gguf:"ffn_gate_inp.scale"`
-PerExpertScale ml.Tensor `gguf:"ffn_gate_inp.per_expert_scale"`
+Proj *nn.Linear `gguf:"ffn_gate_inp"`
+Scale ml.Tensor `gguf:"ffn_gate_inp.scale"`
}
func (r *TextRouter) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) (routingWeights, selectedExperts ml.Tensor) {
@@ -341,63 +340,46 @@ func (r *TextRouter) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOp
// TextMoEBlock implements the Gemma 4 sparse MoE.
type TextMoEBlock struct {
-Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
-Up *nn.LinearBatch `gguf:"ffn_up_exps"`
-Down *nn.LinearBatch `gguf:"ffn_down_exps"`
+GateUp *nn.LinearBatch `gguf:"ffn_gate_up_exps"`
+Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
+Up *nn.LinearBatch `gguf:"ffn_up_exps"`
+Down *nn.LinearBatch `gguf:"ffn_down_exps"`
+DownScale ml.Tensor `gguf:"ffn_down_exps.scale,alt:ffn_gate_inp.per_expert_scale"`
}
-func (moe *TextMoEBlock) Forward(ctx ml.Context, hiddenState, routingWeights, selectedExperts ml.Tensor, perExpertScale ml.Tensor, opts *TextOptions) ml.Tensor {
-batchSize := hiddenState.Dim(1)
-rw3d := routingWeights.Reshape(ctx, 1, opts.numExperts, batchSize)
-// Gather per-expert scales for the selected experts with a ratio trick:
-// multiply the scale into the full softmax, Rows-select both the scaled
-// and the unscaled copies, then divide. After Rows the ratio is
-// (softmax[k]*scale[k]) / softmax[k] = scale[k], which is then applied
-// to the renormalized routing weights below.
+func (moe *TextMoEBlock) Forward(ctx ml.Context, hiddenState, routingWeights, selectedExperts ml.Tensor, opts *TextOptions) ml.Tensor {
+// Select routing weights for chosen experts and renormalize
-routingWeights = rw3d.Rows(ctx, selectedExperts)
-routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, batchSize)
+routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
+routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
-// Apply per-expert scale after renormalization
-if perExpertScale != nil {
-// Build [numExperts, batchSize] with scale broadcast, select with Rows
-scaledSoftmax := rw3d.Reshape(ctx, opts.numExperts, batchSize).Mul(ctx, perExpertScale)
-scaledSoftmax = scaledSoftmax.Reshape(ctx, 1, opts.numExperts, batchSize).Rows(ctx, selectedExperts)
-scaledSoftmax = scaledSoftmax.Reshape(ctx, opts.numExpertsUsed, batchSize)
-// Recover unscaled selected weights (before renorm) for the ratio
-unscaled := rw3d.Rows(ctx, selectedExperts)
-unscaled = unscaled.Reshape(ctx, opts.numExpertsUsed, batchSize)
-// scale[k] = scaledSoftmax[k] / unscaled[k]
-expertScales := scaledSoftmax.Div(ctx, unscaled)
-routingWeights = routingWeights.Mul(ctx, expertScales)
-}
-routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, batchSize)
+routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))
hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))
// Expert computation using LinearBatch (MulmatID selecting experts by index)
-gateOut := moe.Gate.Forward(ctx, hiddenState, selectedExperts)
-upOut := moe.Up.Forward(ctx, hiddenState, selectedExperts)
+var gateOut, upOut ml.Tensor
+if moe.GateUp != nil && moe.GateUp.Weight != nil {
+gateUp := moe.GateUp.Forward(ctx, hiddenState, selectedExperts)
+nFF := gateUp.Dim(0) / 2
+gateOut = gateUp.Slice(ctx, 0, 0, nFF, 1)
+upOut = gateUp.Slice(ctx, 0, nFF, gateUp.Dim(0), 1)
+} else {
+gateOut = moe.Gate.Forward(ctx, hiddenState, selectedExperts)
+upOut = moe.Up.Forward(ctx, hiddenState, selectedExperts)
+}
hiddenState = gateOut.GELU(ctx, upOut)
experts := moe.Down.Forward(ctx, hiddenState, selectedExperts)
+// Apply per-expert down projection scale when present.
+if moe.DownScale != nil {
+expertScales := moe.DownScale.Reshape(ctx, opts.numExperts, 1)
+expertScales = expertScales.Repeat(ctx, 1, hiddenState.Dim(2))
+expertScales = expertScales.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(2)).Rows(ctx, selectedExperts)
+expertScales = expertScales.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(2))
+expertScales = expertScales.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(2))
+experts = experts.Mul(ctx, expertScales)
+}
// Apply routing weights
experts = experts.Mul(ctx, routingWeights)
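For intuition, the fused tensor stacks the gate rows first and the up rows second along dim 0, so the runtime split plus gated activation above reduces to the following plain-Go sketch for a single token (using the common tanh approximation of GELU; illustrative, not runner code):

// gatedGELU splits a fused gate_up activation of length 2*nFF into its
// gate half (first nFF values) and up half, returning GELU(gate) * up.
func gatedGELU(gateUp []float32) []float32 {
	nFF := len(gateUp) / 2
	out := make([]float32, nFF)
	for i := 0; i < nFF; i++ {
		g := float64(gateUp[i])
		gelu := 0.5 * g * (1 + math.Tanh(math.Sqrt(2/math.Pi)*(g+0.044715*g*g*g)))
		out[i] = float32(gelu) * gateUp[nFF+i]
	}
	return out
}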
@@ -450,7 +432,9 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positions, p
residual = hiddenState
// MLP (+ optional MoE in parallel)
-if l.Router != nil && l.MoE != nil && l.MoE.Gate != nil && l.MoE.Gate.Weight != nil {
+hasSplitExperts := l.MoE != nil && l.MoE.Gate != nil && l.MoE.Up != nil && l.MoE.Gate.Weight != nil && l.MoE.Up.Weight != nil
+hasFusedExperts := l.MoE != nil && l.MoE.GateUp != nil && l.MoE.GateUp.Weight != nil
+if l.Router != nil && l.MoE != nil && l.MoE.Down != nil && l.MoE.Down.Weight != nil && (hasSplitExperts || hasFusedExperts) {
// MoE layers: run MLP and MoE in parallel, sum results
mlpState := l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
mlpState = l.MLP.Forward(ctx, mlpState)
@@ -458,7 +442,7 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positions, p
routingWeights, selectedExperts := l.Router.Forward(ctx, hiddenState, opts)
moeState := l.MoENorm.Forward(ctx, hiddenState, opts.eps)
-moeState = l.MoE.Forward(ctx, moeState, routingWeights, selectedExperts, l.Router.PerExpertScale, opts)
+moeState = l.MoE.Forward(ctx, moeState, routingWeights, selectedExperts, opts)
moeState = l.PostMoENorm.Forward(ctx, moeState, opts.eps)
// Combine MLP + MoE, apply outer post-FFN norm, then add residual

View File

@@ -1,6 +1,8 @@
package gemma4
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -11,33 +13,116 @@ const batchSize = 1
// ClippableLinear is a linear layer with optional input/output clamping.
// Required by Gemma4 vision encoder for numerical stability with F16 weights.
// Clamp values are populated by VisionModel.InitClamp from the packed v.clamp_data tensor.
type ClippableLinear struct {
Weight ml.Tensor `gguf:"weight"`
InputMin ml.Tensor `gguf:"input_min"`
InputMax ml.Tensor `gguf:"input_max"`
OutputMin ml.Tensor `gguf:"output_min"`
OutputMax ml.Tensor `gguf:"output_max"`
inMin, inMax, outMin, outMax float32
hasClamp bool
clampsLoaded bool
}
func scalarValue(t ml.Tensor) (float32, bool) {
if t == nil {
return 0, false
}
data := t.BackendGet()
if len(data) == 0 {
return 0, false
}
return data[0], true
}
func (l *ClippableLinear) loadClampFromScalars() {
if l.clampsLoaded {
return
}
l.clampsLoaded = true
const (
defaultMin = -math.MaxFloat32
defaultMax = math.MaxFloat32
)
inMin, hasInMin := scalarValue(l.InputMin)
inMax, hasInMax := scalarValue(l.InputMax)
outMin, hasOutMin := scalarValue(l.OutputMin)
outMax, hasOutMax := scalarValue(l.OutputMax)
if !(hasInMin || hasInMax || hasOutMin || hasOutMax) {
return
}
l.hasClamp = true
l.inMin = defaultMin
l.inMax = defaultMax
l.outMin = defaultMin
l.outMax = defaultMax
if hasInMin {
l.inMin = inMin
}
if hasInMax {
l.inMax = inMax
}
if hasOutMin {
l.outMin = outMin
}
if hasOutMax {
l.outMax = outMax
}
}
func (l *ClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
-if l.inMax != 0 {
+if l.hasClamp {
x = x.Clamp(ctx, l.inMin, l.inMax)
}
out := l.Weight.Mulmat(ctx, x)
-if l.outMax != 0 {
+if l.hasClamp {
out = out.Clamp(ctx, l.outMin, l.outMax)
}
return out
}
// InitClamp distributes packed clamp values from v.clamp_data to ClippableLinear structs.
// If scalar clamp tensors (input_min/max, output_min/max) are present, they are used too.
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
// then 4 floats for the projector.
func (m *VisionModel) InitClamp(proj *MultiModalProjector) {
-if m.clampInitDone || m.ClampData == nil {
+if m.clampInitDone {
return
}
m.clampInitDone = true
linears := func(l *VisionEncoderLayer) []*ClippableLinear {
return []*ClippableLinear{
l.SelfAttention.Query, l.SelfAttention.Key, l.SelfAttention.Value,
l.SelfAttention.Output, l.MLP.Gate, l.MLP.Up, l.MLP.Down,
}
}
for i := range m.Layers {
for _, cl := range linears(&m.Layers[i]) {
if cl != nil {
cl.loadClampFromScalars()
}
}
}
if proj != nil && proj.Projection != nil {
proj.Projection.loadClampFromScalars()
}
// Load packed clamp data when present (legacy Ollama format).
if m.ClampData == nil {
return
}
// Read all clamp values from packed F32 tensor
data := m.ClampData.BackendGet()
if len(data) == 0 {
@@ -45,32 +130,31 @@ func (m *VisionModel) InitClamp(proj *MultiModalProjector) {
}
// Distribute to layer linears: 7 per layer × 4 values each
-linears := func(l *VisionEncoderLayer) []*ClippableLinear {
-return []*ClippableLinear{
-l.SelfAttention.Query, l.SelfAttention.Key, l.SelfAttention.Value,
-l.SelfAttention.Output, l.MLP.Gate, l.MLP.Up, l.MLP.Down,
-}
-}
for i := range m.Layers {
for li, cl := range linears(&m.Layers[i]) {
if cl == nil {
continue
}
idx := (i*7 + li) * 4
if idx+3 < len(data) {
cl.inMin = data[idx]
cl.inMax = data[idx+1]
cl.outMin = data[idx+2]
cl.outMax = data[idx+3]
cl.hasClamp = true
}
}
}
// Projector clamp values (last 4 floats)
-if proj != nil {
+if proj != nil && proj.Projection != nil {
projIdx := len(m.Layers) * 7 * 4
if projIdx+3 < len(data) {
proj.Projection.inMin = data[projIdx]
proj.Projection.inMax = data[projIdx+1]
proj.Projection.outMin = data[projIdx+2]
proj.Projection.outMax = data[projIdx+3]
proj.Projection.hasClamp = true
}
}
}
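Given the packed layout (numLayers × 7 linears × 4 floats, then 4 projector floats), the expected element count is easy to sanity-check. A sketch with an illustrative layer count:

// expectedClampLen returns the packed v.clamp_data length in floats.
// For a hypothetical 27-layer encoder: 27*7*4 + 4 = 760.
func expectedClampLen(numLayers int) int {
	return numLayers*7*4 + 4
}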
@@ -199,6 +283,8 @@ type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
PositionEmbedding ml.Tensor `gguf:"position_embd.weight"`
ClampData ml.Tensor `gguf:"clamp_data"`
StdBias ml.Tensor `gguf:"std_bias"`
StdScale ml.Tensor `gguf:"std_scale"`
Layers []VisionEncoderLayer `gguf:"blk"`
@@ -274,7 +360,7 @@ func visionTokenCount(imageWidth, imageHeight, patchSize, nMerge int) int {
return mergedX * mergedY
}
-func visionPoolAndProject(ctx ml.Context, hiddenState ml.Tensor, numPatchesX, numPatchesY int, opts *VisionModelOptions, proj *MultiModalProjector) ml.Tensor {
+func visionPoolAndProject(ctx ml.Context, hiddenState ml.Tensor, numPatchesX, numPatchesY int, opts *VisionModelOptions, proj *MultiModalProjector, stdBias, stdScale ml.Tensor) ml.Tensor {
hiddenSize := opts.hiddenSize
// Reshape from [hiddenSize, numPatches] to spatial layout for pooling
@@ -290,12 +376,14 @@ func visionPoolAndProject(ctx ml.Context, hiddenState ml.Tensor, numPatchesX, nu
hiddenState = hiddenState.Reshape(ctx, mergedX*mergedY, hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Ensure F32 for the projection Mulmat. The Metal mul_mm kernel for F16×F32
// casts F32 activations to F16 in shared memory, so values must stay within
// F16 range (≤65504). The sqrt(hiddenSize) scaling from the HF reference is
// omitted because it's normalized out by the unweighted RMSNorm that follows
// the projection: RMSNorm(√d·x) = x/rms(x) = RMSNorm(x).
hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)
-hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(hiddenSize)))
+// Optional vision standardization before projection.
+if stdBias != nil && stdScale != nil {
+hiddenState = hiddenState.Sub(ctx, stdBias)
+hiddenState = hiddenState.Mul(ctx, stdScale)
+}
// Project to text embedding dimension
hiddenState = proj.Forward(ctx, hiddenState, opts.eps)
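The identity the comment above relies on, that unweighted RMSNorm absorbs any positive scalar factor, follows from rms(c·x) = c·rms(x), so (c·x)/rms(c·x) = x/rms(x). A standalone numeric sketch:

func rmsNorm(x []float64) []float64 {
	var ss float64
	for _, v := range x {
		ss += v * v
	}
	rms := math.Sqrt(ss / float64(len(x)))
	out := make([]float64, len(x))
	for i, v := range x {
		out[i] = v / rms
	}
	return out
}
// rmsNorm(x) equals rmsNorm of x scaled by sqrt(hiddenSize), elementwise,
// which is why dropping the scale before the projection is safe.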

View File

@@ -344,49 +344,62 @@ func parseGemma4ToolCall(content string) (api.ToolCall, error) {
}
// gemma4ArgsToJSON converts Gemma 4's custom argument format to valid JSON.
// The format uses <|"|> for string delimiters and bare identifier keys.
// Example: {location:<|"|>Paris<|"|>,count:42} → {"location":"Paris","count":42}
func gemma4ArgsToJSON(s string) string {
// Step 1: Replace <|"|> with "
s = strings.ReplaceAll(s, `<|"|>`, `"`)
// Step 2: Quote bare keys (identifiers followed by : that aren't inside strings)
var buf strings.Builder
buf.Grow(len(s) + 32)
inString := false
hex := "0123456789abcdef"
i := 0
for i < len(s) {
ch := s[i]
if ch == '"' && !inString {
inString = true
buf.WriteByte(ch)
if ch == '"' {
inString = !inString
buf.WriteByte('"')
i++
-// Write until closing quote
-for i < len(s) {
-buf.WriteByte(s[i])
-if s[i] == '"' {
-inString = false
-i++
-break
-}
-i++
-}
continue
}
+if inString {
+switch ch {
+case '\\':
+buf.WriteString(`\\`)
+case '\n':
+buf.WriteString(`\n`)
+case '\r':
+buf.WriteString(`\r`)
+case '\t':
+buf.WriteString(`\t`)
+case '\b':
+buf.WriteString(`\b`)
+case '\f':
+buf.WriteString(`\f`)
+default:
+if ch < 0x20 {
+buf.WriteString(`\u00`)
+buf.WriteByte(hex[ch>>4])
+buf.WriteByte(hex[ch&0x0f])
+} else {
+buf.WriteByte(ch)
+}
+}
+i++
+continue
+}
if !inString && isIdentStart(ch) {
// Read the full identifier
j := i + 1
for j < len(s) && isIdentPart(s[j]) {
j++
}
word := s[i:j]
if j < len(s) && s[j] == ':' {
// It's an object key — quote it
buf.WriteByte('"')
buf.WriteString(word)
buf.WriteByte('"')
} else {
// It's a bare value (true, false, null, etc.)
buf.WriteString(word)
}
i = j
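A worked example of the multiline fix (a sketch of calling the function above):

in := "{command:<|\"|>date\n<|\"|>}"
fmt.Println(gemma4ArgsToJSON(in))
// {"command":"date\n"}
// The raw newline inside the string becomes the two-character escape \n,
// so the output now parses as valid JSON.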

View File

@@ -133,6 +133,21 @@ func TestGemma4Parser(t *testing.T) {
},
},
},
{
name: "tool_call_with_multiline_string_arg",
input: `<|tool_call>call:bash{command:<|"|>date
<|"|>}<tool_call|>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "bash",
Arguments: testArgs(map[string]any{
"command": "date\n",
}),
},
},
},
},
{
name: "multiple_tool_calls",
input: `<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|><|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`,
@@ -410,6 +425,12 @@ func TestGemma4ArgsToJSON(t *testing.T) {
input: `{value:null}`,
expected: `{"value":null}`,
},
{
name: "multiline_string_value",
input: `{command:<|"|>date
<|"|>}`,
expected: `{"command":"date\n"}`,
},
}
for _, tt := range tests {

View File

@@ -1258,6 +1258,12 @@ func (s *Server) loadModel() {
panic(fmt.Errorf("failed to load model: %v", err))
}
if postLoader, ok := s.model.(model.PostLoader); ok {
if err := postLoader.PostLoad(); err != nil {
panic(fmt.Errorf("failed to finalize model initialization: %v", err))
}
}
s.status = llm.ServerStatusReady
s.ready.Done()
}