mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 09:03:35 -04:00
Merge pull request #42 from ollama/jmorganca/gemma4-ggml-improvements
gemma4: fix MoE fused gate_up split and multiline tool-call arg parsing
This commit is contained in:
426
cmd/cmd.go
426
cmd/cmd.go
@@ -568,395 +568,6 @@ func hasListedModelName(models []api.ListModelResponse, name string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// getMaxAudioSeconds extracts the max audio duration from model info metadata.
|
||||
// Returns 0 if the model doesn't report audio limits.
|
||||
func getMaxAudioSeconds(info *api.ShowResponse) int {
|
||||
if info == nil || info.ModelInfo == nil {
|
||||
return 0
|
||||
}
|
||||
// Look for {arch}.max_audio_seconds in ModelInfo.
|
||||
for k, v := range info.ModelInfo {
|
||||
if strings.HasSuffix(k, ".max_audio_seconds") {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return int(val)
|
||||
case int:
|
||||
return val
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// ANSI escape helpers for transcription display (none are currently defined; the block below is empty).
|
||||
const (
|
||||
)
|
||||
|
||||
// TranscribeHandler implements `ollama transcribe MODEL`.
|
||||
//
|
||||
// Two modes:
|
||||
// - Interactive (tty on stdin): spacebar start/stop with >>> prompt,
|
||||
// slash commands (/set, /show, /load, /bye, /?), word-wrapped output.
|
||||
// - Non-interactive (pipe/redirect): reads audio from stdin or records
|
||||
// until Ctrl+C, transcribes, writes word-wrapped text to stdout.
|
||||
func TranscribeHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelName := args[0]
|
||||
interactive := term.IsTerminal(int(os.Stdin.Fd()))
|
||||
|
||||
// Pull model if needed and get model info.
|
||||
showReq := &api.ShowRequest{Name: modelName}
|
||||
info, err := client.Show(cmd.Context(), showReq)
|
||||
if err != nil {
|
||||
var se api.StatusError
|
||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||
if err := PullHandler(cmd, []string{modelName}); err != nil {
|
||||
return err
|
||||
}
|
||||
info, err = client.Show(cmd.Context(), showReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
language, _ := cmd.Flags().GetString("language")
|
||||
|
||||
opts := runOptions{
|
||||
Model: modelName,
|
||||
WordWrap: true,
|
||||
Options: map[string]any{"temperature": 0},
|
||||
Language: language,
|
||||
}
|
||||
|
||||
transcribeAndDisplay := func(wav []byte) {
|
||||
state := &displayResponseState{}
|
||||
_, err := transcribeAudio(cmd, opts, wav, func(tok string) {
|
||||
displayResponse(tok, opts.WordWrap, state)
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Transcription error:", err)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
// --- Non-interactive mode ---
|
||||
if !interactive {
|
||||
audioData, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read stdin: %w", err)
|
||||
}
|
||||
|
||||
if len(audioData) > 44 {
|
||||
// Pipe with data (at least WAV header size): transcribe and output.
|
||||
transcribeAndDisplay(audioData)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Empty stdin (< /dev/null or echo "" |): record until Ctrl+C.
|
||||
recorder, err := NewAudioRecorder()
|
||||
if err != nil {
|
||||
return fmt.Errorf("audio input unavailable: %w", err)
|
||||
}
|
||||
if maxSec := getMaxAudioSeconds(info); maxSec > 0 {
|
||||
recorder.MaxChunkSeconds = maxSec - 2
|
||||
}
|
||||
|
||||
if err := recorder.Start(); err != nil {
|
||||
return fmt.Errorf("start recording: %w", err)
|
||||
}
|
||||
fmt.Fprintln(os.Stderr, "Recording... Press Ctrl+C to stop.")
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, os.Interrupt)
|
||||
<-sigCh
|
||||
signal.Stop(sigCh)
|
||||
|
||||
recorder.Stop()
|
||||
fmt.Fprintln(os.Stderr)
|
||||
|
||||
if wav := recorder.FlushWAV(); wav != nil {
|
||||
transcribeAndDisplay(wav)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Interactive mode ---
|
||||
recorder, err := NewAudioRecorder()
|
||||
if err != nil {
|
||||
return fmt.Errorf("audio input unavailable: %w", err)
|
||||
}
|
||||
if maxSec := getMaxAudioSeconds(info); maxSec > 0 {
|
||||
recorder.MaxChunkSeconds = maxSec - 2
|
||||
}
|
||||
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
Placeholder: "Press Space to record (/? for help)",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Print(readline.StartBracketedPaste)
|
||||
defer fmt.Printf(readline.EndBracketedPaste)
|
||||
|
||||
usage := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set Set session variables")
|
||||
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||
fmt.Fprintln(os.Stderr, " /load <model> Load a different model")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Press Space to start/stop recording.")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
usageSet := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
|
||||
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
|
||||
fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap")
|
||||
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
|
||||
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
// doTranscribeRecording is like doAudioRecording but polls TakeChunk()
|
||||
// during recording to stream transcription of long recordings.
|
||||
doTranscribeRecording := func() ([]byte, error) {
|
||||
fmt.Print(">>> \033[90m◉ Press Space to record...\033[0m")
|
||||
|
||||
for {
|
||||
r, err := scanner.ReadRaw()
|
||||
if err != nil {
|
||||
return nil, io.EOF
|
||||
}
|
||||
if r == 3 { // Ctrl+C
|
||||
fmt.Print("\r\033[K")
|
||||
fmt.Println("Use Ctrl + d or /bye to exit.")
|
||||
return nil, nil
|
||||
}
|
||||
if r == 4 { // Ctrl+D
|
||||
fmt.Println()
|
||||
return nil, io.EOF
|
||||
}
|
||||
if r == ' ' {
|
||||
fmt.Print("\r\033[K") // clear the prompt line
|
||||
break
|
||||
}
|
||||
if r == '/' || (r >= 32 && r < 127) {
|
||||
fmt.Print("\r\033[K")
|
||||
return nil, errFallbackToText{prefill: string(r)}
|
||||
}
|
||||
}
|
||||
|
||||
if err := recorder.Start(); err != nil {
|
||||
fmt.Println()
|
||||
return nil, fmt.Errorf("start recording: %w", err)
|
||||
}
|
||||
|
||||
// Poll for chunks in a background goroutine while recording.
|
||||
chunkDone := make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-chunkDone:
|
||||
return
|
||||
case <-ticker.C:
|
||||
if wav := recorder.TakeChunk(); wav != nil {
|
||||
transcribeAndDisplay(wav)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for Space to stop, Ctrl+C to discard, Ctrl+D to exit.
|
||||
for {
|
||||
r, err := scanner.ReadRaw()
|
||||
if err != nil {
|
||||
close(chunkDone)
|
||||
recorder.Stop()
|
||||
return nil, io.EOF
|
||||
}
|
||||
if r == 4 { // Ctrl+D
|
||||
close(chunkDone)
|
||||
recorder.Stop()
|
||||
fmt.Println()
|
||||
return nil, io.EOF
|
||||
}
|
||||
if r == 3 { // Ctrl+C
|
||||
close(chunkDone)
|
||||
recorder.Stop()
|
||||
return nil, nil
|
||||
}
|
||||
if r == ' ' { // Space: stop recording
|
||||
close(chunkDone)
|
||||
recorder.Stop()
|
||||
return recorder.FlushWAV(), nil
|
||||
}
|
||||
// Ignore other keys while recording.
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
wav, err := doTranscribeRecording()
|
||||
if err != nil {
|
||||
var fallback errFallbackToText
|
||||
if errors.As(err, &fallback) {
|
||||
// User typed text instead of pressing Space.
|
||||
line := fallback.prefill
|
||||
if line == "/" {
|
||||
// Need the rest of the command — read via readline.
|
||||
scanner.Prefill = "/"
|
||||
fullLine, err := scanner.Readline()
|
||||
if errors.Is(err, io.EOF) {
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
line = fullLine
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
switch {
|
||||
case line == "/?" || line == "/help":
|
||||
usage()
|
||||
case strings.HasPrefix(line, "/? "):
|
||||
arg := strings.TrimSpace(line[3:])
|
||||
switch arg {
|
||||
case "set":
|
||||
usageSet()
|
||||
default:
|
||||
usage()
|
||||
}
|
||||
case strings.HasPrefix(line, "/set"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) == 1 {
|
||||
usageSet()
|
||||
continue
|
||||
}
|
||||
switch args[1] {
|
||||
case "wordwrap":
|
||||
opts.WordWrap = true
|
||||
fmt.Println("Set 'wordwrap' mode.")
|
||||
case "nowordwrap":
|
||||
opts.WordWrap = false
|
||||
fmt.Println("Set 'nowordwrap' mode.")
|
||||
case "verbose":
|
||||
cmd.Flags().Set("verbose", "true")
|
||||
fmt.Println("Set 'verbose' mode.")
|
||||
case "quiet":
|
||||
cmd.Flags().Set("verbose", "false")
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "parameter":
|
||||
if len(args) < 4 {
|
||||
fmt.Println("Usage: /set parameter <name> <value>")
|
||||
continue
|
||||
}
|
||||
opts.Options[args[2]] = args[3]
|
||||
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], args[3])
|
||||
default:
|
||||
fmt.Printf("Unknown option: %s\n", args[1])
|
||||
usageSet()
|
||||
}
|
||||
case strings.HasPrefix(line, "/show"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) == 1 {
|
||||
args = append(args, "info")
|
||||
}
|
||||
showReq := &api.ShowRequest{Name: opts.Model}
|
||||
resp, err := client.Show(cmd.Context(), showReq)
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
continue
|
||||
}
|
||||
switch args[1] {
|
||||
case "info":
|
||||
if err := showInfo(resp, false, os.Stdout); err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
}
|
||||
case "license":
|
||||
fmt.Println(resp.License)
|
||||
case "parameters":
|
||||
fmt.Println(resp.Parameters)
|
||||
case "system":
|
||||
fmt.Println(resp.System)
|
||||
default:
|
||||
fmt.Printf("Unknown show command: %s\n", args[1])
|
||||
}
|
||||
case strings.HasPrefix(line, "/load"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) != 2 {
|
||||
fmt.Println("Usage: /load <modelname>")
|
||||
continue
|
||||
}
|
||||
newModel := args[1]
|
||||
fmt.Printf("Loading model '%s'\n", newModel)
|
||||
showReq := &api.ShowRequest{Name: newModel}
|
||||
newInfo, err := client.Show(cmd.Context(), showReq)
|
||||
if err != nil {
|
||||
var se api.StatusError
|
||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||
fmt.Printf("error: model '%s' not found\n", newModel)
|
||||
} else {
|
||||
fmt.Println("Error:", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Verify audio capability.
|
||||
hasAudio := false
|
||||
for _, cap := range newInfo.Capabilities {
|
||||
if cap == "audio" {
|
||||
hasAudio = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasAudio {
|
||||
fmt.Printf("error: model '%s' does not support audio input\n", newModel)
|
||||
continue
|
||||
}
|
||||
opts.Model = newModel
|
||||
if maxSec := getMaxAudioSeconds(newInfo); maxSec > 0 {
|
||||
recorder.MaxChunkSeconds = maxSec - 2
|
||||
}
|
||||
case line == "/exit" || line == "/bye":
|
||||
fmt.Println()
|
||||
return nil
|
||||
case line != "":
|
||||
fmt.Printf("Unknown command: %s (type /? for help)\n", line)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Recording error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if wav == nil {
|
||||
// Ctrl+C during recording — discard and retry.
|
||||
continue
|
||||
}
|
||||
|
||||
// Transcribe the recording.
|
||||
transcribeAndDisplay(wav)
|
||||
}
|
||||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
@@ -1084,7 +695,8 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
|
||||
audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio)
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable
|
||||
|
||||
// TODO: remove the projector info and vision info checks below,
|
||||
// these are left in for backwards compatibility with older servers
|
||||
@@ -1101,19 +713,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
opts.ParentModel = info.Details.ParentModel
|
||||
|
||||
opts.AudioCapable = slices.Contains(info.Capabilities, model.CapabilityAudio)
|
||||
|
||||
audioin, _ := cmd.Flags().GetBool("audioin")
|
||||
if audioin {
|
||||
if !opts.AudioCapable {
|
||||
fmt.Fprintf(os.Stderr, "Warning: audio input disabled — %s does not support audio\n", opts.Model)
|
||||
} else {
|
||||
opts.AudioInput = true
|
||||
opts.MultiModal = true // audio uses the multimodal pipeline
|
||||
opts.MaxAudioSeconds = getMaxAudioSeconds(info)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is an embedding model
|
||||
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
|
||||
|
||||
@@ -1837,12 +1436,8 @@ type runOptions struct {
|
||||
System string
|
||||
Images []api.ImageData
|
||||
Options map[string]any
|
||||
MultiModal bool
|
||||
AudioInput bool
|
||||
AudioCapable bool // model supports audio input
|
||||
MaxAudioSeconds int // from model metadata; 0 = use default
|
||||
Language string // language hint for transcription
|
||||
KeepAlive *api.Duration
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
Think *api.ThinkValue
|
||||
HideThinking bool
|
||||
ShowConnect bool
|
||||
@@ -2568,7 +2163,6 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
||||
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
|
||||
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
|
||||
runCmd.Flags().Bool("audioin", false, "Enable audio input via microphone (press Space to record)")
|
||||
|
||||
// Image generation flags (width, height, steps, seed, etc.)
|
||||
imagegen.RegisterFlags(runCmd)
|
||||
@@ -2576,16 +2170,6 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("imagegen", false, "Use the imagegen runner for LLM inference")
|
||||
runCmd.Flags().MarkHidden("imagegen")
|
||||
|
||||
transcribeCmd := &cobra.Command{
|
||||
Use: "transcribe MODEL",
|
||||
Short: "Transcribe audio to text using microphone",
|
||||
Long: "Record audio via microphone and transcribe to text.\nPress Space to start/stop recording. Ctrl+D to exit.",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: TranscribeHandler,
|
||||
}
|
||||
transcribeCmd.Flags().String("language", "", "Language hint (e.g. en, es, fr)")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
Short: "Stop a running model",
|
||||
@@ -2706,7 +2290,6 @@ func NewCLI() *cobra.Command {
|
||||
createCmd,
|
||||
showCmd,
|
||||
runCmd,
|
||||
transcribeCmd,
|
||||
stopCmd,
|
||||
pullCmd,
|
||||
pushCmd,
|
||||
@@ -2750,7 +2333,6 @@ func NewCLI() *cobra.Command {
|
||||
createCmd,
|
||||
showCmd,
|
||||
runCmd,
|
||||
transcribeCmd,
|
||||
stopCmd,
|
||||
pullCmd,
|
||||
pushCmd,
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
@@ -24,143 +23,6 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// errFallbackToText signals that, while in audio mode, the user pressed a
// non-space key, so the caller should switch to the normal text input path.
type errFallbackToText struct {
	// prefill holds the text already typed, to seed the text prompt.
	prefill string
}

// Error implements the error interface with a fixed message.
func (errFallbackToText) Error() string {
	return "fallback to text"
}
|
||||
|
||||
// doAudioRecording handles the spacebar-driven recording flow.
|
||||
// Returns WAV bytes on success, nil to retry, or an error.
|
||||
func doAudioRecording(scanner *readline.Instance, recorder *AudioRecorder) ([]byte, error) {
|
||||
fmt.Print(">>> \033[90m◉ Press Space to record...\033[0m")
|
||||
|
||||
// Wait for spacebar to start.
|
||||
for {
|
||||
r, err := scanner.ReadRaw()
|
||||
if err != nil {
|
||||
return nil, io.EOF
|
||||
}
|
||||
if r == 3 { // Ctrl+C
|
||||
fmt.Print("\r\033[K")
|
||||
fmt.Println("Use Ctrl + d or /bye to exit.")
|
||||
return nil, nil
|
||||
}
|
||||
if r == 4 { // Ctrl+D
|
||||
fmt.Println()
|
||||
return nil, io.EOF
|
||||
}
|
||||
if r == ' ' {
|
||||
break
|
||||
}
|
||||
// User typed a regular character — fall back to text input with this char.
|
||||
if r == '/' || (r >= 32 && r < 127) {
|
||||
fmt.Print("\r\033[K") // clear the "Press Space" line
|
||||
return nil, errFallbackToText{prefill: string(r)}
|
||||
}
|
||||
}
|
||||
|
||||
// Start recording.
|
||||
if err := recorder.Start(); err != nil {
|
||||
fmt.Println()
|
||||
return nil, fmt.Errorf("start recording: %w", err)
|
||||
}
|
||||
|
||||
// Show recording indicator with elapsed time.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
d := recorder.Duration()
|
||||
fmt.Printf("\r>>> \033[91m◈ Recording... %.1fs\033[0m ", d.Seconds())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for spacebar to stop.
|
||||
for {
|
||||
r, err := scanner.ReadRaw()
|
||||
if err != nil {
|
||||
close(done)
|
||||
recorder.Stop()
|
||||
return nil, io.EOF
|
||||
}
|
||||
if r == ' ' || r == 3 { // Space or Ctrl+C
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
close(done)
|
||||
dur, _ := recorder.Stop()
|
||||
fmt.Printf("\r>>> \033[90m◇ Recorded %.1fs\033[0m \n", dur.Seconds())
|
||||
|
||||
// Encode to WAV.
|
||||
wav, err := recorder.WAV()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wav, nil
|
||||
}
|
||||
|
||||
// tokenCallback is called for each streamed token. It has no return value, so a callback cannot abort the stream.
|
||||
type tokenCallback func(token string)
|
||||
|
||||
// streamChat sends a chat request and streams tokens to the callback.
|
||||
// Returns the full accumulated text.
|
||||
func streamChat(cmd *cobra.Command, model string, messages []api.Message, onToken tokenCallback) (string, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
noThink := &api.ThinkValue{Value: false}
|
||||
stream := true
|
||||
req := &api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
Stream: &stream,
|
||||
Think: noThink,
|
||||
Options: map[string]any{"temperature": 0},
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
fn := func(response api.ChatResponse) error {
|
||||
tok := response.Message.Content
|
||||
result.WriteString(tok)
|
||||
if onToken != nil {
|
||||
onToken(tok)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := client.Chat(cmd.Context(), req, fn); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return strings.TrimSpace(result.String()), nil
|
||||
}
|
||||
|
||||
// transcribeAudio sends audio to the model for transcription.
|
||||
// onToken is called for each streamed token (may be nil for silent operation).
|
||||
func transcribeAudio(cmd *cobra.Command, opts runOptions, audioData []byte, onToken tokenCallback) (string, error) {
|
||||
systemPrompt := "Transcribe the following audio exactly as spoken. Output only the transcription text, nothing else."
|
||||
if opts.Language != "" {
|
||||
systemPrompt += " The audio is in " + opts.Language + "."
|
||||
}
|
||||
|
||||
return streamChat(cmd, opts.Model, []api.Message{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: "Transcribe this audio.", Images: []api.ImageData{audioData}},
|
||||
}, onToken)
|
||||
}
|
||||
|
||||
type MultilineState int
|
||||
|
||||
const (
|
||||
@@ -177,11 +39,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
|
||||
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
|
||||
fmt.Fprintln(os.Stderr, " /clear Clear session context")
|
||||
if opts.AudioCapable {
|
||||
fmt.Fprintln(os.Stderr, " /audio Toggle voice input mode")
|
||||
} else {
|
||||
fmt.Fprintln(os.Stderr, " /audio (not supported by current model)")
|
||||
}
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||
@@ -190,7 +47,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
if opts.MultiModal {
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, .webp images, or .wav audio files.\n", filepath.FromSlash("/path/to/file"))
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
@@ -279,66 +136,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
var multiline MultilineState
|
||||
var thinkExplicitlySet bool = opts.Think != nil
|
||||
|
||||
audioMode := opts.AudioInput
|
||||
var recorder *AudioRecorder
|
||||
if audioMode {
|
||||
var err error
|
||||
recorder, err = NewAudioRecorder()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Audio input unavailable: %v\n", err)
|
||||
audioMode = false
|
||||
} else {
|
||||
if opts.MaxAudioSeconds > 0 {
|
||||
recorder.MaxChunkSeconds = opts.MaxAudioSeconds - 2 // 2s headroom
|
||||
}
|
||||
fmt.Fprintln(os.Stderr, "Voice input enabled. Press Space to record, Space again to send.")
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
// Audio recording mode: wait for spacebar instead of text input.
|
||||
if audioMode && recorder != nil {
|
||||
audioData, err := doAudioRecording(scanner, recorder)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
// User typed a regular key — fall through to normal readline.
|
||||
if fb, ok := err.(errFallbackToText); ok {
|
||||
scanner.Prefill = fb.prefill
|
||||
goto textInput
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Audio error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if audioData == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send audio as the user's input — the model hears and responds.
|
||||
newMessage := api.Message{
|
||||
Role: "user",
|
||||
Images: []api.ImageData{audioData},
|
||||
}
|
||||
opts.Messages = append(opts.Messages, newMessage)
|
||||
|
||||
assistant, err := chat(cmd, opts)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "does not support thinking") ||
|
||||
strings.Contains(err.Error(), "invalid think value") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
if assistant != nil {
|
||||
opts.Messages = append(opts.Messages, *assistant)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
textInput:
|
||||
line, err := scanner.Readline()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
@@ -676,29 +474,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
} else {
|
||||
usage()
|
||||
}
|
||||
case line == "/audio":
|
||||
if !opts.AudioCapable {
|
||||
fmt.Fprintf(os.Stderr, "Audio input not supported by %s\n", opts.Model)
|
||||
continue
|
||||
}
|
||||
if audioMode {
|
||||
audioMode = false
|
||||
fmt.Fprintln(os.Stderr, "Voice input disabled.")
|
||||
} else {
|
||||
audioMode = true
|
||||
if recorder == nil {
|
||||
var recErr error
|
||||
recorder, recErr = NewAudioRecorder()
|
||||
if recErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "Audio input unavailable: %v\n", recErr)
|
||||
audioMode = false
|
||||
continue
|
||||
}
|
||||
}
|
||||
opts.MultiModal = true
|
||||
fmt.Fprintln(os.Stderr, "Voice input enabled. Press Space to record, Space again to send.")
|
||||
}
|
||||
continue
|
||||
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/"):
|
||||
|
||||
@@ -84,3 +84,33 @@ func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
|
||||
assert.Len(t, imgs, 1)
|
||||
assert.Equal(t, cleaned, "before after")
|
||||
}
|
||||
|
||||
func TestExtractFileDataWAV(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fp := filepath.Join(dir, "sample.wav")
|
||||
data := make([]byte, 600)
|
||||
copy(data[:44], []byte{
|
||||
'R', 'I', 'F', 'F',
|
||||
0x58, 0x02, 0x00, 0x00, // file size - 8
|
||||
'W', 'A', 'V', 'E',
|
||||
'f', 'm', 't', ' ',
|
||||
0x10, 0x00, 0x00, 0x00, // fmt chunk size
|
||||
0x01, 0x00, // PCM
|
||||
0x01, 0x00, // mono
|
||||
0x80, 0x3e, 0x00, 0x00, // 16000 Hz
|
||||
0x00, 0x7d, 0x00, 0x00, // byte rate
|
||||
0x02, 0x00, // block align
|
||||
0x10, 0x00, // 16-bit
|
||||
'd', 'a', 't', 'a',
|
||||
0x34, 0x02, 0x00, 0x00, // data size
|
||||
})
|
||||
if err := os.WriteFile(fp, data, 0o600); err != nil {
|
||||
t.Fatalf("failed to write test audio: %v", err)
|
||||
}
|
||||
|
||||
input := "before " + fp + " after"
|
||||
cleaned, imgs, err := extractFileData(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, imgs, 1)
|
||||
assert.Equal(t, "before after", cleaned)
|
||||
}
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
)
|
||||
|
||||
type gemma4Model struct {
|
||||
@@ -237,15 +235,6 @@ func (p *gemma4Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip vision clamp scalars — packed into v.clamp_data below.
|
||||
// Audio clamp scalars are kept as individual tensors (matching published GGUF).
|
||||
isVisionClamp := (strings.Contains(name, "input_min") || strings.Contains(name, "input_max") ||
|
||||
strings.Contains(name, "output_min") || strings.Contains(name, "output_max")) &&
|
||||
(strings.Contains(name, "vision_tower") || strings.Contains(name, "embed_vision"))
|
||||
if isVisionClamp {
|
||||
continue
|
||||
}
|
||||
|
||||
// Vision tensor renaming: match published mmproj GGUF names
|
||||
if strings.HasPrefix(name, "v.blk.") {
|
||||
name = strings.Replace(name, ".attn_norm.", ".ln1.", 1)
|
||||
@@ -282,26 +271,6 @@ func (p *gemma4Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
shape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
|
||||
// Fused MoE gate_up_proj: split [experts, 2*intermediate, hidden] into separate gate and up.
|
||||
// No transpose needed — the split shape [experts, intermediate, hidden] already matches
|
||||
// the GGUF layout after the framework's dimension reversal (ne[0]=hidden matches input).
|
||||
if strings.Contains(name, "ffn_gate_exps.weight") && len(shape) == 3 {
|
||||
halfDim := int(shape[1]) / 2
|
||||
newShape := slices.Clone(shape)
|
||||
newShape[1] = newShape[1] / 2
|
||||
for i, ggufName := range []string{"ffn_gate_exps.weight", "ffn_up_exps.weight"} {
|
||||
tt := t.Clone()
|
||||
tt.SetRepacker(p.sliceExperts(tensor.S(i*halfDim, (i+1)*halfDim)))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.ReplaceAll(name, "ffn_gate_exps.weight", ggufName),
|
||||
Kind: tt.Kind(),
|
||||
Shape: slices.Clone(newShape),
|
||||
WriterTo: tt,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// MoE expert weights: no transpose needed. Safetensors stores [experts, out, in]
|
||||
// which the framework reverses to GGUF ne=[in, out, experts], matching ggml_mul_mat_id.
|
||||
// (transposeExperts was incorrectly swapping dims — removed)
|
||||
@@ -454,30 +423,6 @@ func (*gemma4Model) reshapePatchEmbed(_ string, data []float32, shape []uint64)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sliceExperts returns a repacker that slices dim 1 of a 3D expert tensor.
|
||||
// Used for splitting fused gate_up_proj into separate gate and up tensors.
|
||||
func (*gemma4Model) sliceExperts(dim1Slice tensor.Slice) Repacker {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i, d := range shape {
|
||||
dims[i] = int(d)
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
t, err := t.Slice(nil, dim1Slice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t = tensor.Materialize(t)
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
}
|
||||
}
|
||||
|
||||
// softplusRepacker applies softplus (ln(1 + exp(x))) to tensor data.
|
||||
// Used for per_dim_scale tensors which the published GGUF stores pre-activated.
|
||||
func softplusRepacker(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
@@ -501,34 +446,35 @@ func (p *gemma4Model) Replacements() []string {
|
||||
".linear.bias", ".bias",
|
||||
|
||||
// Audio SSCP (Sub-Sample Convolution Projection)
|
||||
"model.audio_tower.subsample_conv_projection.layer0.conv", "a.conv1d.0",
|
||||
"model.audio_tower.subsample_conv_projection.layer0.norm", "a.conv1d.0.norm",
|
||||
"model.audio_tower.subsample_conv_projection.layer1.conv", "a.conv1d.1",
|
||||
"model.audio_tower.subsample_conv_projection.layer1.norm", "a.conv1d.1.norm",
|
||||
"model.audio_tower.subsample_conv_projection.conv_0.conv", "a.conv1d.0",
|
||||
"model.audio_tower.subsample_conv_projection.conv_0.norm", "a.conv1d.0.norm",
|
||||
"model.audio_tower.subsample_conv_projection.conv_1.conv", "a.conv1d.1",
|
||||
"model.audio_tower.subsample_conv_projection.conv_1.norm", "a.conv1d.1.norm",
|
||||
"model.audio_tower.subsample_conv_projection.input_proj_linear", "a.pre_encode.out",
|
||||
|
||||
// Audio conformer blocks
|
||||
"model.audio_tower.layers", "a.blk",
|
||||
"model.audio_tower.conformer", "a.blk",
|
||||
|
||||
// Audio conformer attention
|
||||
"self_attn.relative_k_proj", "linear_pos",
|
||||
"self_attn.per_dim_scale", "per_dim_scale",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"norm_post_attn", "ln2",
|
||||
"norm_pre_attn", "ln1",
|
||||
"self_attn.post", "attn_out",
|
||||
"attention.attn.relative_position_embedding.pos_proj", "linear_pos",
|
||||
"attention.attn.per_dim_key_scale", "per_dim_k_scale",
|
||||
"attention.attn.per_dim_scale", "per_dim_scale",
|
||||
"attention.attn.q_proj", "attn_q",
|
||||
"attention.attn.k_proj", "attn_k",
|
||||
"attention.attn.v_proj", "attn_v",
|
||||
"attention.pre_attn_norm", "ln1",
|
||||
"attention.post_norm", "ln2",
|
||||
"attention.post", "attn_out",
|
||||
|
||||
// Audio conformer feedforward
|
||||
"feed_forward1.pre_layer_norm", "ffn_norm",
|
||||
"feed_forward1.post_layer_norm", "ffn_post_norm",
|
||||
"feed_forward1.ffw_layer_1", "ffn_up",
|
||||
"feed_forward1.ffw_layer_2", "ffn_down",
|
||||
"feed_forward2.pre_layer_norm", "ffn_norm_1",
|
||||
"feed_forward2.post_layer_norm", "ffn_post_norm_1",
|
||||
"feed_forward2.ffw_layer_1", "ffn_up_1",
|
||||
"feed_forward2.ffw_layer_2", "ffn_down_1",
|
||||
"ffw_layer_start.pre_layer_norm", "ffn_norm",
|
||||
"ffw_layer_start.post_layer_norm", "ffn_post_norm",
|
||||
"ffw_layer_start.ffw_layer_1", "ffn_up",
|
||||
"ffw_layer_start.ffw_layer_2", "ffn_down",
|
||||
"ffw_layer_end.pre_layer_norm", "ffn_norm_1",
|
||||
"ffw_layer_end.post_layer_norm", "ffn_post_norm_1",
|
||||
"ffw_layer_end.ffw_layer_1", "ffn_up_1",
|
||||
"ffw_layer_end.ffw_layer_2", "ffn_down_1",
|
||||
|
||||
// Audio conformer lightweight conv1d
|
||||
"lconv1d.depthwise_conv1d", "conv_dw",
|
||||
@@ -548,6 +494,8 @@ func (p *gemma4Model) Replacements() []string {
|
||||
"model.vision_tower.encoder.layers", "v.blk",
|
||||
"model.vision_tower.patch_embedder.input_proj", "v.patch_embd",
|
||||
"model.vision_tower.patch_embedder.position_embedding_table", "v.position_embd.weight",
|
||||
"model.vision_tower.std_bias", "v.std_bias",
|
||||
"model.vision_tower.std_scale", "v.std_scale",
|
||||
|
||||
// Vision multimodal projector
|
||||
"model.embed_vision.embedding_projection", "mm.input_projection",
|
||||
@@ -588,9 +536,19 @@ func (p *gemma4Model) Replacements() []string {
|
||||
// MoE
|
||||
"router.proj", "ffn_gate_inp",
|
||||
"router.scale", "ffn_gate_inp.scale",
|
||||
"router.per_expert_scale", "ffn_gate_inp.per_expert_scale",
|
||||
"experts.gate_up_proj", "ffn_gate_exps.weight",
|
||||
"router.per_expert_scale.weight", "ffn_down_exps.scale",
|
||||
"router.per_expert_scale", "ffn_down_exps.scale",
|
||||
"experts.gate_up_proj.weight", "ffn_gate_up_exps.weight",
|
||||
"experts.gate_up_proj", "ffn_gate_up_exps.weight",
|
||||
"experts.down_proj.weight", "ffn_down_exps.weight",
|
||||
"experts.down_proj", "ffn_down_exps.weight",
|
||||
"moe.gate_proj", "ffn_gate_exps.weight",
|
||||
"moe.up_proj", "ffn_up_exps.weight",
|
||||
"moe.gate_up_proj.weight", "ffn_gate_up_exps.weight",
|
||||
"moe.gate_up_proj", "ffn_gate_up_exps.weight",
|
||||
"moe.down_proj", "ffn_down_exps.weight",
|
||||
"moe.per_expert_scale.weight", "ffn_down_exps.scale",
|
||||
"moe.per_expert_scale", "ffn_down_exps.scale",
|
||||
|
||||
// Layer scalar
|
||||
"layer_scalar", "layer_output_scale.weight",
|
||||
|
||||
@@ -194,6 +194,16 @@ func TestGemma4AudioReplacements(t *testing.T) {
|
||||
"model.vision_tower.encoder.layers.0.self_attn.q_proj.linear.weight",
|
||||
"v.blk.0.attn_q.weight",
|
||||
},
|
||||
{
|
||||
"vision std bias",
|
||||
"model.vision_tower.std_bias",
|
||||
"v.std_bias",
|
||||
},
|
||||
{
|
||||
"vision std scale",
|
||||
"model.vision_tower.std_scale",
|
||||
"v.std_scale",
|
||||
},
|
||||
{
|
||||
"vision patch embd",
|
||||
"model.vision_tower.patch_embedder.input_proj.weight",
|
||||
@@ -216,6 +226,31 @@ func TestGemma4AudioReplacements(t *testing.T) {
|
||||
"model.language_model.embed_tokens.weight",
|
||||
"token_embd.weight",
|
||||
},
|
||||
{
|
||||
"text moe gate up fused",
|
||||
"model.language_model.layers.0.experts.gate_up_proj",
|
||||
"blk.0.ffn_gate_up_exps.weight",
|
||||
},
|
||||
{
|
||||
"text moe down",
|
||||
"model.language_model.layers.0.experts.down_proj",
|
||||
"blk.0.ffn_down_exps.weight",
|
||||
},
|
||||
{
|
||||
"text moe down with weight suffix",
|
||||
"model.language_model.layers.0.experts.down_proj.weight",
|
||||
"blk.0.ffn_down_exps.weight",
|
||||
},
|
||||
{
|
||||
"text moe per expert scale",
|
||||
"model.language_model.layers.0.router.per_expert_scale",
|
||||
"blk.0.ffn_down_exps.scale",
|
||||
},
|
||||
{
|
||||
"text moe per expert scale with weight suffix",
|
||||
"model.language_model.layers.0.router.per_expert_scale.weight",
|
||||
"blk.0.ffn_down_exps.scale",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -47,6 +47,12 @@ type Validator interface {
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// PostLoader is an optional hook for models that need a finalization step
// once the backend has finished loading weights.
type PostLoader interface {
	// PostLoad performs the model's post-load initialization and reports
	// any failure to the caller.
	PostLoad() error
}
|
||||
|
||||
// MultimodalProcessor must be implemented by multimodal models.
|
||||
type MultimodalProcessor interface {
|
||||
// EncodeMultimodal processes a single input (such as an image) and
|
||||
|
||||
@@ -26,7 +26,7 @@ type Model struct {
|
||||
*TextModel
|
||||
*AudioModel `gguf:"a"`
|
||||
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
*AudioMultimodalProjector `gguf:"mm.a"`
|
||||
|
||||
ImageProcessor
|
||||
@@ -97,22 +97,25 @@ func New(c fs.Config) (model.Model, error) {
|
||||
slog.Info("gemma4: token IDs", "image", imageTokenID, "image_end", imageEndTokenID, "audio", audioTokenID, "audio_end", audioEndTokenID)
|
||||
|
||||
m := Model{
|
||||
Tokenizer: t,
|
||||
TextModel: newTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
AudioModel: newAudioModel(c),
|
||||
MultiModalProjector: &MultiModalProjector{},
|
||||
Tokenizer: t,
|
||||
TextModel: newTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
AudioModel: newAudioModel(c),
|
||||
MultiModalProjector: &MultiModalProjector{},
|
||||
AudioMultimodalProjector: &AudioMultimodalProjector{},
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
imageTokenID: imageTokenID,
|
||||
imageEndTokenID: imageEndTokenID,
|
||||
audioTokenID: audioTokenID,
|
||||
audioEndTokenID: audioEndTokenID,
|
||||
audioOpts: newAudioModelOptions(c),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
imageTokenID: imageTokenID,
|
||||
imageEndTokenID: imageEndTokenID,
|
||||
audioTokenID: audioTokenID,
|
||||
audioEndTokenID: audioEndTokenID,
|
||||
audioOpts: newAudioModelOptions(c),
|
||||
}
|
||||
|
||||
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
|
||||
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewSWAMemCache(slidingWindowLen, 4096, m.Shift),
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
)
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
@@ -127,9 +130,6 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
// Initialize clamp values from model tensors (lazy, once, after model is fully loaded)
|
||||
m.VisionModel.InitClamp(m.MultiModalProjector)
|
||||
|
||||
t0 := time.Now()
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
@@ -152,12 +152,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
slog.Info("vision: patches", "patchesX", numPatchesX, "patchesY", numPatchesY, "total", numPatchesX*numPatchesY, "patchSize", m.ImageProcessor.patchSize)
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, numPatchesX, numPatchesY)
|
||||
visionOutputs = visionPoolAndProject(ctx, visionOutputs, numPatchesX, numPatchesY, m.VisionModel.VisionModelOptions, m.MultiModalProjector)
|
||||
visionOutputs = visionPoolAndProject(ctx, visionOutputs, numPatchesX, numPatchesY, m.VisionModel.VisionModelOptions, m.MultiModalProjector, m.VisionModel.StdBias, m.VisionModel.StdScale)
|
||||
slog.Info("vision: encoded", "elapsed", time.Since(t0), "shape", visionOutputs.Shape())
|
||||
|
||||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostLoad() error {
|
||||
m.VisionModel.InitClamp(m.MultiModalProjector)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) encodeAudioMultimodal(ctx ml.Context, data []byte) ([]input.Multimodal, error) {
|
||||
if m.AudioModel == nil || m.audioOpts == nil {
|
||||
return nil, model.ErrNoVisionModel
|
||||
|
||||
@@ -318,9 +318,8 @@ func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
|
||||
// TextRouter implements the Gemma 4 MoE router.
|
||||
type TextRouter struct {
|
||||
Proj *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Scale ml.Tensor `gguf:"ffn_gate_inp.scale"`
|
||||
PerExpertScale ml.Tensor `gguf:"ffn_gate_inp.per_expert_scale"`
|
||||
Proj *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Scale ml.Tensor `gguf:"ffn_gate_inp.scale"`
|
||||
}
|
||||
|
||||
func (r *TextRouter) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) (routingWeights, selectedExperts ml.Tensor) {
|
||||
@@ -341,63 +340,46 @@ func (r *TextRouter) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOp
|
||||
|
||||
// TextMoEBlock implements the Gemma 4 sparse MoE.
|
||||
type TextMoEBlock struct {
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
GateUp *nn.LinearBatch `gguf:"ffn_gate_up_exps"`
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
DownScale ml.Tensor `gguf:"ffn_down_exps.scale,alt:ffn_gate_inp.per_expert_scale"`
|
||||
}
|
||||
|
||||
func (moe *TextMoEBlock) Forward(ctx ml.Context, hiddenState, routingWeights, selectedExperts ml.Tensor, perExpertScale ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
rw3d := routingWeights.Reshape(ctx, 1, opts.numExperts, batchSize)
|
||||
|
||||
// Gather per-expert scales for the selected experts (before we consume routingWeights).
|
||||
// Multiply scale [numExperts] into the full softmax [numExperts, batchSize],
|
||||
// then use Rows to select — gives us [numExpertsUsed, batchSize] of (softmax*scale).
|
||||
// After we independently renorm the unscaled weights, we divide to recover just the scales.
|
||||
// But that's fragile. Instead: just Mul scale into softmax, Rows both, renorm unscaled,
|
||||
// then Mul the ratio.
|
||||
//
|
||||
// Simpler: Mul per_expert_scale onto the softmax ONCE to create a scaled copy.
|
||||
// Rows both. Renorm the unscaled. Divide scaled/unscaled_pre_renorm to get scale[k].
|
||||
// Mul scale[k] * renormed.
|
||||
//
|
||||
// Actually simplest: just select, renorm, then Mul the scale factors.
|
||||
// To get scale factors for selected experts, Rows on a [1, numExperts, batchSize]
|
||||
// tensor where every column is the same scale. Build it via Mul: softmax * scale
|
||||
// gives us the scaled version; softmax gives unscaled. Their ratio = scale[k].
|
||||
// After Rows, ratio = (softmax[k]*scale[k]) / softmax[k] = scale[k].
|
||||
|
||||
func (moe *TextMoEBlock) Forward(ctx ml.Context, hiddenState, routingWeights, selectedExperts ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
// Select routing weights for chosen experts and renormalize
|
||||
routingWeights = rw3d.Rows(ctx, selectedExperts)
|
||||
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, batchSize)
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
|
||||
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
|
||||
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
|
||||
|
||||
// Apply per-expert scale after renormalization
|
||||
if perExpertScale != nil {
|
||||
// Build [numExperts, batchSize] with scale broadcast, select with Rows
|
||||
scaledSoftmax := rw3d.Reshape(ctx, opts.numExperts, batchSize).Mul(ctx, perExpertScale)
|
||||
scaledSoftmax = scaledSoftmax.Reshape(ctx, 1, opts.numExperts, batchSize).Rows(ctx, selectedExperts)
|
||||
scaledSoftmax = scaledSoftmax.Reshape(ctx, opts.numExpertsUsed, batchSize)
|
||||
|
||||
// Recover unscaled selected weights (before renorm) for the ratio
|
||||
unscaled := rw3d.Rows(ctx, selectedExperts)
|
||||
unscaled = unscaled.Reshape(ctx, opts.numExpertsUsed, batchSize)
|
||||
|
||||
// scale[k] = scaledSoftmax[k] / unscaled[k]
|
||||
expertScales := scaledSoftmax.Div(ctx, unscaled)
|
||||
routingWeights = routingWeights.Mul(ctx, expertScales)
|
||||
}
|
||||
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, batchSize)
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))
|
||||
|
||||
hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))
|
||||
|
||||
// Expert computation using LinearBatch (MulmatID selecting experts by index)
|
||||
gateOut := moe.Gate.Forward(ctx, hiddenState, selectedExperts)
|
||||
upOut := moe.Up.Forward(ctx, hiddenState, selectedExperts)
|
||||
var gateOut, upOut ml.Tensor
|
||||
if moe.GateUp != nil && moe.GateUp.Weight != nil {
|
||||
gateUp := moe.GateUp.Forward(ctx, hiddenState, selectedExperts)
|
||||
nFF := gateUp.Dim(0) / 2
|
||||
gateOut = gateUp.Slice(ctx, 0, 0, nFF, 1)
|
||||
upOut = gateUp.Slice(ctx, 0, nFF, gateUp.Dim(0), 1)
|
||||
} else {
|
||||
gateOut = moe.Gate.Forward(ctx, hiddenState, selectedExperts)
|
||||
upOut = moe.Up.Forward(ctx, hiddenState, selectedExperts)
|
||||
}
|
||||
hiddenState = gateOut.GELU(ctx, upOut)
|
||||
experts := moe.Down.Forward(ctx, hiddenState, selectedExperts)
|
||||
|
||||
// Apply per-expert down projection scale when present.
|
||||
if moe.DownScale != nil {
|
||||
expertScales := moe.DownScale.Reshape(ctx, opts.numExperts, 1)
|
||||
expertScales = expertScales.Repeat(ctx, 1, hiddenState.Dim(2))
|
||||
expertScales = expertScales.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(2)).Rows(ctx, selectedExperts)
|
||||
expertScales = expertScales.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(2))
|
||||
expertScales = expertScales.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(2))
|
||||
experts = experts.Mul(ctx, expertScales)
|
||||
}
|
||||
|
||||
// Apply routing weights
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
@@ -450,7 +432,9 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positions, p
|
||||
residual = hiddenState
|
||||
|
||||
// MLP (+ optional MoE in parallel)
|
||||
if l.Router != nil && l.MoE != nil && l.MoE.Gate != nil && l.MoE.Gate.Weight != nil {
|
||||
hasSplitExperts := l.MoE != nil && l.MoE.Gate != nil && l.MoE.Up != nil && l.MoE.Gate.Weight != nil && l.MoE.Up.Weight != nil
|
||||
hasFusedExperts := l.MoE != nil && l.MoE.GateUp != nil && l.MoE.GateUp.Weight != nil
|
||||
if l.Router != nil && l.MoE != nil && l.MoE.Down != nil && l.MoE.Down.Weight != nil && (hasSplitExperts || hasFusedExperts) {
|
||||
// MoE layers: run MLP and MoE in parallel, sum results
|
||||
mlpState := l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
mlpState = l.MLP.Forward(ctx, mlpState)
|
||||
@@ -458,7 +442,7 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positions, p
|
||||
|
||||
routingWeights, selectedExperts := l.Router.Forward(ctx, hiddenState, opts)
|
||||
moeState := l.MoENorm.Forward(ctx, hiddenState, opts.eps)
|
||||
moeState = l.MoE.Forward(ctx, moeState, routingWeights, selectedExperts, l.Router.PerExpertScale, opts)
|
||||
moeState = l.MoE.Forward(ctx, moeState, routingWeights, selectedExperts, opts)
|
||||
moeState = l.PostMoENorm.Forward(ctx, moeState, opts.eps)
|
||||
|
||||
// Combine MLP + MoE, apply outer post-FFN norm, then add residual
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
@@ -11,33 +13,116 @@ const batchSize = 1
|
||||
|
||||
// ClippableLinear is a linear layer with optional input/output clamping.
|
||||
// Required by Gemma4 vision encoder for numerical stability with F16 weights.
|
||||
// Clamp values are populated by VisionModel.InitClamp from the packed v.clamp_data tensor.
|
||||
type ClippableLinear struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
|
||||
InputMin ml.Tensor `gguf:"input_min"`
|
||||
InputMax ml.Tensor `gguf:"input_max"`
|
||||
OutputMin ml.Tensor `gguf:"output_min"`
|
||||
OutputMax ml.Tensor `gguf:"output_max"`
|
||||
|
||||
inMin, inMax, outMin, outMax float32
|
||||
hasClamp bool
|
||||
clampsLoaded bool
|
||||
}
|
||||
|
||||
func scalarValue(t ml.Tensor) (float32, bool) {
|
||||
if t == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
data := t.BackendGet()
|
||||
if len(data) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return data[0], true
|
||||
}
|
||||
|
||||
func (l *ClippableLinear) loadClampFromScalars() {
|
||||
if l.clampsLoaded {
|
||||
return
|
||||
}
|
||||
l.clampsLoaded = true
|
||||
|
||||
const (
|
||||
defaultMin = -math.MaxFloat32
|
||||
defaultMax = math.MaxFloat32
|
||||
)
|
||||
|
||||
inMin, hasInMin := scalarValue(l.InputMin)
|
||||
inMax, hasInMax := scalarValue(l.InputMax)
|
||||
outMin, hasOutMin := scalarValue(l.OutputMin)
|
||||
outMax, hasOutMax := scalarValue(l.OutputMax)
|
||||
|
||||
if !(hasInMin || hasInMax || hasOutMin || hasOutMax) {
|
||||
return
|
||||
}
|
||||
|
||||
l.hasClamp = true
|
||||
l.inMin = defaultMin
|
||||
l.inMax = defaultMax
|
||||
l.outMin = defaultMin
|
||||
l.outMax = defaultMax
|
||||
|
||||
if hasInMin {
|
||||
l.inMin = inMin
|
||||
}
|
||||
if hasInMax {
|
||||
l.inMax = inMax
|
||||
}
|
||||
if hasOutMin {
|
||||
l.outMin = outMin
|
||||
}
|
||||
if hasOutMax {
|
||||
l.outMax = outMax
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
|
||||
if l.inMax != 0 {
|
||||
if l.hasClamp {
|
||||
x = x.Clamp(ctx, l.inMin, l.inMax)
|
||||
}
|
||||
out := l.Weight.Mulmat(ctx, x)
|
||||
if l.outMax != 0 {
|
||||
if l.hasClamp {
|
||||
out = out.Clamp(ctx, l.outMin, l.outMax)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// InitClamp distributes packed clamp values from v.clamp_data to ClippableLinear structs.
|
||||
// If scalar clamp tensors (input_min/max, output_min/max) are present, they are used too.
|
||||
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
|
||||
// then 4 floats for the projector.
|
||||
func (m *VisionModel) InitClamp(proj *MultiModalProjector) {
|
||||
if m.clampInitDone || m.ClampData == nil {
|
||||
if m.clampInitDone {
|
||||
return
|
||||
}
|
||||
m.clampInitDone = true
|
||||
|
||||
linears := func(l *VisionEncoderLayer) []*ClippableLinear {
|
||||
return []*ClippableLinear{
|
||||
l.SelfAttention.Query, l.SelfAttention.Key, l.SelfAttention.Value,
|
||||
l.SelfAttention.Output, l.MLP.Gate, l.MLP.Up, l.MLP.Down,
|
||||
}
|
||||
}
|
||||
|
||||
for i := range m.Layers {
|
||||
for _, cl := range linears(&m.Layers[i]) {
|
||||
if cl != nil {
|
||||
cl.loadClampFromScalars()
|
||||
}
|
||||
}
|
||||
}
|
||||
if proj != nil && proj.Projection != nil {
|
||||
proj.Projection.loadClampFromScalars()
|
||||
}
|
||||
|
||||
// Load packed clamp data when present (legacy Ollama format).
|
||||
if m.ClampData == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Read all clamp values from packed F32 tensor
|
||||
data := m.ClampData.BackendGet()
|
||||
if len(data) == 0 {
|
||||
@@ -45,32 +130,31 @@ func (m *VisionModel) InitClamp(proj *MultiModalProjector) {
|
||||
}
|
||||
|
||||
// Distribute to layer linears: 7 per layer × 4 values each
|
||||
linears := func(l *VisionEncoderLayer) []*ClippableLinear {
|
||||
return []*ClippableLinear{
|
||||
l.SelfAttention.Query, l.SelfAttention.Key, l.SelfAttention.Value,
|
||||
l.SelfAttention.Output, l.MLP.Gate, l.MLP.Up, l.MLP.Down,
|
||||
}
|
||||
}
|
||||
for i := range m.Layers {
|
||||
for li, cl := range linears(&m.Layers[i]) {
|
||||
if cl == nil {
|
||||
continue
|
||||
}
|
||||
idx := (i*7 + li) * 4
|
||||
if idx+3 < len(data) {
|
||||
cl.inMin = data[idx]
|
||||
cl.inMax = data[idx+1]
|
||||
cl.outMin = data[idx+2]
|
||||
cl.outMax = data[idx+3]
|
||||
cl.hasClamp = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Projector clamp values (last 4 floats)
|
||||
if proj != nil {
|
||||
if proj != nil && proj.Projection != nil {
|
||||
projIdx := len(m.Layers) * 7 * 4
|
||||
if projIdx+3 < len(data) {
|
||||
proj.Projection.inMin = data[projIdx]
|
||||
proj.Projection.inMax = data[projIdx+1]
|
||||
proj.Projection.outMin = data[projIdx+2]
|
||||
proj.Projection.outMax = data[projIdx+3]
|
||||
proj.Projection.hasClamp = true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -199,6 +283,8 @@ type VisionModel struct {
|
||||
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
|
||||
PositionEmbedding ml.Tensor `gguf:"position_embd.weight"`
|
||||
ClampData ml.Tensor `gguf:"clamp_data"`
|
||||
StdBias ml.Tensor `gguf:"std_bias"`
|
||||
StdScale ml.Tensor `gguf:"std_scale"`
|
||||
|
||||
Layers []VisionEncoderLayer `gguf:"blk"`
|
||||
|
||||
@@ -274,7 +360,7 @@ func visionTokenCount(imageWidth, imageHeight, patchSize, nMerge int) int {
|
||||
return mergedX * mergedY
|
||||
}
|
||||
|
||||
func visionPoolAndProject(ctx ml.Context, hiddenState ml.Tensor, numPatchesX, numPatchesY int, opts *VisionModelOptions, proj *MultiModalProjector) ml.Tensor {
|
||||
func visionPoolAndProject(ctx ml.Context, hiddenState ml.Tensor, numPatchesX, numPatchesY int, opts *VisionModelOptions, proj *MultiModalProjector, stdBias, stdScale ml.Tensor) ml.Tensor {
|
||||
hiddenSize := opts.hiddenSize
|
||||
|
||||
// Reshape from [hiddenSize, numPatches] to spatial layout for pooling
|
||||
@@ -290,12 +376,14 @@ func visionPoolAndProject(ctx ml.Context, hiddenState ml.Tensor, numPatchesX, nu
|
||||
hiddenState = hiddenState.Reshape(ctx, mergedX*mergedY, hiddenSize)
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Ensure F32 for the projection Mulmat. The Metal mul_mm kernel for F16×F32
|
||||
// casts F32 activations to F16 in shared memory, so values must stay within
|
||||
// F16 range (≤65504). The sqrt(hiddenSize) scaling from the HF reference is
|
||||
// omitted because it's normalized out by the unweighted RMSNorm that follows
|
||||
// the projection: RMSNorm(√d·x) = x/rms(x) = RMSNorm(x).
|
||||
hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(hiddenSize)))
|
||||
|
||||
// Optional vision standardization before projection.
|
||||
if stdBias != nil && stdScale != nil {
|
||||
hiddenState = hiddenState.Sub(ctx, stdBias)
|
||||
hiddenState = hiddenState.Mul(ctx, stdScale)
|
||||
}
|
||||
|
||||
// Project to text embedding dimension
|
||||
hiddenState = proj.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
@@ -344,49 +344,62 @@ func parseGemma4ToolCall(content string) (api.ToolCall, error) {
|
||||
}
|
||||
|
||||
// gemma4ArgsToJSON converts Gemma 4's custom argument format to valid JSON.
|
||||
// The format uses <|"|> for string delimiters and bare identifier keys.
|
||||
// Example: {location:<|"|>Paris<|"|>,count:42} → {"location":"Paris","count":42}
|
||||
func gemma4ArgsToJSON(s string) string {
|
||||
// Step 1: Replace <|"|> with "
|
||||
s = strings.ReplaceAll(s, `<|"|>`, `"`)
|
||||
|
||||
// Step 2: Quote bare keys (identifiers followed by : that aren't inside strings)
|
||||
var buf strings.Builder
|
||||
buf.Grow(len(s) + 32)
|
||||
inString := false
|
||||
hex := "0123456789abcdef"
|
||||
i := 0
|
||||
for i < len(s) {
|
||||
ch := s[i]
|
||||
if ch == '"' && !inString {
|
||||
inString = true
|
||||
buf.WriteByte(ch)
|
||||
|
||||
if ch == '"' {
|
||||
inString = !inString
|
||||
buf.WriteByte('"')
|
||||
i++
|
||||
// Write until closing quote
|
||||
for i < len(s) {
|
||||
buf.WriteByte(s[i])
|
||||
if s[i] == '"' {
|
||||
inString = false
|
||||
i++
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if inString {
|
||||
switch ch {
|
||||
case '\\':
|
||||
buf.WriteString(`\\`)
|
||||
case '\n':
|
||||
buf.WriteString(`\n`)
|
||||
case '\r':
|
||||
buf.WriteString(`\r`)
|
||||
case '\t':
|
||||
buf.WriteString(`\t`)
|
||||
case '\b':
|
||||
buf.WriteString(`\b`)
|
||||
case '\f':
|
||||
buf.WriteString(`\f`)
|
||||
default:
|
||||
if ch < 0x20 {
|
||||
buf.WriteString(`\u00`)
|
||||
buf.WriteByte(hex[ch>>4])
|
||||
buf.WriteByte(hex[ch&0x0f])
|
||||
} else {
|
||||
buf.WriteByte(ch)
|
||||
}
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if !inString && isIdentStart(ch) {
|
||||
// Read the full identifier
|
||||
j := i + 1
|
||||
for j < len(s) && isIdentPart(s[j]) {
|
||||
j++
|
||||
}
|
||||
word := s[i:j]
|
||||
if j < len(s) && s[j] == ':' {
|
||||
// It's an object key — quote it
|
||||
buf.WriteByte('"')
|
||||
buf.WriteString(word)
|
||||
buf.WriteByte('"')
|
||||
} else {
|
||||
// It's a bare value (true, false, null, etc.)
|
||||
buf.WriteString(word)
|
||||
}
|
||||
i = j
|
||||
|
||||
@@ -133,6 +133,21 @@ func TestGemma4Parser(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool_call_with_multiline_string_arg",
|
||||
input: `<|tool_call>call:bash{command:<|"|>date
|
||||
<|"|>}<tool_call|>`,
|
||||
expectedToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bash",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"command": "date\n",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple_tool_calls",
|
||||
input: `<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|><|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`,
|
||||
@@ -410,6 +425,12 @@ func TestGemma4ArgsToJSON(t *testing.T) {
|
||||
input: `{value:null}`,
|
||||
expected: `{"value":null}`,
|
||||
},
|
||||
{
|
||||
name: "multiline_string_value",
|
||||
input: `{command:<|"|>date
|
||||
<|"|>}`,
|
||||
expected: `{"command":"date\n"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -1258,6 +1258,12 @@ func (s *Server) loadModel() {
|
||||
panic(fmt.Errorf("failed to load model: %v", err))
|
||||
}
|
||||
|
||||
if postLoader, ok := s.model.(model.PostLoader); ok {
|
||||
if err := postLoader.PostLoad(); err != nil {
|
||||
panic(fmt.Errorf("failed to finalize model initialization: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
s.status = llm.ServerStatusReady
|
||||
s.ready.Done()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user