models: add nemotronh architecture support (#14356)

This commit is contained in:
Jeffrey Morgan
2026-02-22 15:09:14 -08:00
committed by GitHub
parent 06edabdde1
commit 0ade9205cc
22 changed files with 3196 additions and 4 deletions

View File

@@ -257,10 +257,11 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
if err != nil {
return nil, nil, err
}
bts = sanitizeNonFiniteJSON(bts)
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("parse config.json: %w", err)
}
if len(p.Architectures) < 1 {
@@ -319,12 +320,14 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &lfm2Model{}
case "Qwen3NextForCausalLM":
conv = &qwen3NextModel{}
case "NemotronHForCausalLM":
conv = &nemotronHModel{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("parse config.json for %q: %w", p.Architectures[0], err)
}
if t, ok := conv.(moreParser); ok {

View File

@@ -0,0 +1,385 @@
package convert
import (
"cmp"
"encoding/json"
"fmt"
"io/fs"
"math"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type hybridPattern string
func (p *hybridPattern) UnmarshalJSON(data []byte) error {
if string(data) == "null" {
*p = ""
return nil
}
var single string
if err := json.Unmarshal(data, &single); err == nil {
*p = hybridPattern(strings.TrimSpace(single))
return nil
}
var parts []string
if err := json.Unmarshal(data, &parts); err == nil {
*p = hybridPattern(strings.Join(parts, ""))
return nil
}
return fmt.Errorf("hybrid_override_pattern must be a string or string array")
}
type nemotronHModel struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_eps"`
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
ConvKernel uint32 `json:"conv_kernel"`
SSMStateSize uint32 `json:"ssm_state_size"`
MambaNumHeads uint32 `json:"mamba_num_heads"`
MambaHeadDim uint32 `json:"mamba_head_dim"`
NGroups uint32 `json:"n_groups"`
IntermediateSize uint32 `json:"intermediate_size"`
HybridOverridePattern hybridPattern `json:"hybrid_override_pattern"`
// MoE
NumExperts uint32 `json:"num_experts"`
NumSharedExperts uint32 `json:"num_shared_experts"`
NRoutedExperts uint32 `json:"n_routed_experts"`
NSharedExperts uint32 `json:"n_shared_experts"`
NumExpertsPerTok uint32 `json:"num_experts_per_tok"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
MoESharedExpertIntermediate uint32 `json:"moe_shared_expert_intermediate_size"`
NormTopKProb bool `json:"norm_topk_prob"`
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
ExpertGroupCount uint32 `json:"n_group"`
ExpertGroupUsedCount uint32 `json:"topk_group"`
}
var _ ModelConverter = (*nemotronHModel)(nil)
func (n *nemotronHModel) parseMore(_ fs.FS) error {
if n.NumHiddenLayers == 0 {
return fmt.Errorf("nemotron_h: num_hidden_layers must be set")
}
if n.HiddenSize == 0 {
return fmt.Errorf("nemotron_h: hidden_size must be set")
}
if n.NumAttentionHeads == 0 {
return fmt.Errorf("nemotron_h: num_attention_heads must be set")
}
if n.HeadDim == 0 {
if n.HiddenSize%n.NumAttentionHeads != 0 {
return fmt.Errorf("nemotron_h: hidden_size (%d) must be divisible by num_attention_heads (%d)", n.HiddenSize, n.NumAttentionHeads)
}
n.HeadDim = n.HiddenSize / n.NumAttentionHeads
}
if n.NumKeyValueHeads == 0 {
n.NumKeyValueHeads = n.NumAttentionHeads
}
if n.ConvKernel == 0 {
return fmt.Errorf("nemotron_h: conv_kernel must be set")
}
if n.SSMStateSize == 0 {
return fmt.Errorf("nemotron_h: ssm_state_size must be set")
}
if n.ssmHeadCount() == 0 {
return fmt.Errorf("nemotron_h: mamba_num_heads must be set")
}
if n.MambaHeadDim == 0 {
return fmt.Errorf("nemotron_h: mamba_head_dim must be set")
}
if n.NGroups == 0 {
n.NGroups = 1
}
if _, _, err := n.layerArrays(); err != nil {
return err
}
if n.isMoE() {
if n.routedExpertCount() == 0 {
return fmt.Errorf("nemotron_h: routed expert count must be set for MoE models")
}
if n.NumExpertsPerTok == 0 {
return fmt.Errorf("nemotron_h: num_experts_per_tok must be set for MoE models")
}
if n.NumExpertsPerTok > n.routedExpertCount() {
return fmt.Errorf("nemotron_h: num_experts_per_tok (%d) cannot exceed expert_count (%d)", n.NumExpertsPerTok, n.routedExpertCount())
}
if n.moeIntermediateSize() == 0 {
return fmt.Errorf("nemotron_h: moe_intermediate_size must be set for MoE models")
}
}
return nil
}
func (n *nemotronHModel) isMoE() bool {
return cmp.Or(n.routedExpertCount(), n.NumExpertsPerTok, n.MoEIntermediateSize) > 0
}
func (n *nemotronHModel) routedExpertCount() uint32 {
return cmp.Or(n.NRoutedExperts, n.NumExperts)
}
func (n *nemotronHModel) sharedExpertCount() uint32 {
return cmp.Or(n.NSharedExperts, n.NumSharedExperts)
}
func (n *nemotronHModel) ssmHeadCount() uint32 {
return n.MambaNumHeads
}
func (n *nemotronHModel) ssmInnerSize() uint32 {
return n.MambaHeadDim * n.ssmHeadCount()
}
func (n *nemotronHModel) epsilon() float32 {
return cmp.Or(n.NormEpsilon, n.LayerNormEpsilon, float32(1e-5))
}
func (n *nemotronHModel) moeIntermediateSize() uint32 {
return cmp.Or(n.MoEIntermediateSize, n.IntermediateSize)
}
func (n *nemotronHModel) denseIntermediateSize() uint32 {
return cmp.Or(n.IntermediateSize, n.MoEIntermediateSize)
}
func (n *nemotronHModel) layerArrays() (headCountKV []uint32, ffnLengths []uint32, err error) {
pattern := strings.TrimSpace(string(n.HybridOverridePattern))
if pattern == "" {
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern must be set")
}
runes := []rune(pattern)
if len(runes) != int(n.NumHiddenLayers) {
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern length (%d) must match num_hidden_layers (%d)", len(runes), n.NumHiddenLayers)
}
headCountKV = make([]uint32, n.NumHiddenLayers)
ffnLengths = make([]uint32, n.NumHiddenLayers)
attnKVHeads := cmp.Or(n.NumKeyValueHeads, n.NumAttentionHeads)
moeFFN := n.moeIntermediateSize()
denseFFN := n.denseIntermediateSize()
for i, layerType := range runes {
switch layerType {
case 'M':
// Recurrent layer: no KV heads and no FFN.
case '*', 'A':
// Attention-only layer.
headCountKV[i] = attnKVHeads
case 'E':
// MoE layer.
if moeFFN == 0 {
return nil, nil, fmt.Errorf("nemotron_h: moe layer at index %d but moe_intermediate_size is zero", i)
}
ffnLengths[i] = moeFFN
case '-':
// Dense FFN layer.
if denseFFN == 0 {
return nil, nil, fmt.Errorf("nemotron_h: dense FFN layer at index %d but intermediate_size is zero", i)
}
ffnLengths[i] = denseFFN
default:
return nil, nil, fmt.Errorf("nemotron_h: unsupported layer type %q in hybrid_override_pattern at index %d", layerType, i)
}
}
return headCountKV, ffnLengths, nil
}
func (n *nemotronHModel) KV(t *Tokenizer) KV {
kv := n.ModelParameters.KV(t)
arch := "nemotron_h"
if n.isMoE() {
arch = "nemotron_h_moe"
}
kv["general.architecture"] = arch
kv["block_count"] = n.NumHiddenLayers
kv["context_length"] = n.MaxPositionEmbeddings
kv["embedding_length"] = n.HiddenSize
kv["attention.head_count"] = n.NumAttentionHeads
kv["attention.key_length"] = n.HeadDim
kv["attention.value_length"] = n.HeadDim
kv["attention.layer_norm_epsilon"] = n.epsilon()
kv["attention.layer_norm_rms_epsilon"] = n.epsilon()
kv["rope.freq_base"] = cmp.Or(n.RopeTheta, float32(10000))
if n.PartialRotaryFactor > 0 && n.PartialRotaryFactor <= 1 {
kv["rope.dimension_count"] = uint32(float32(n.HeadDim) * n.PartialRotaryFactor)
}
if headCountKV, ffnLengths, err := n.layerArrays(); err == nil {
kv["attention.head_count_kv"] = headCountKV
kv["feed_forward_length"] = ffnLengths
}
kv["ssm.conv_kernel"] = n.ConvKernel
kv["ssm.inner_size"] = n.ssmInnerSize()
kv["ssm.state_size"] = n.SSMStateSize
kv["ssm.group_count"] = n.NGroups
kv["ssm.time_step_rank"] = n.ssmHeadCount()
if n.isMoE() {
kv["expert_count"] = n.routedExpertCount()
kv["expert_used_count"] = n.NumExpertsPerTok
kv["expert_feed_forward_length"] = n.moeIntermediateSize()
if n.sharedExpertCount() > 0 {
kv["expert_shared_count"] = n.sharedExpertCount()
}
if n.MoESharedExpertIntermediate > 0 {
kv["expert_shared_feed_forward_length"] = n.MoESharedExpertIntermediate
}
kv["expert_weights_norm"] = n.NormTopKProb
kv["expert_weights_scale"] = n.RoutedScalingFactor
if n.ExpertGroupCount > 0 {
kv["expert_group_count"] = n.ExpertGroupCount
}
if n.ExpertGroupUsedCount > 0 {
kv["expert_group_used_count"] = n.ExpertGroupUsedCount
}
}
return kv
}
func normalizeVectorShapeToColumn(shape []uint64) []uint64 {
switch len(shape) {
case 1:
return []uint64{shape[0], 1}
case 2:
if shape[0] == 1 && shape[1] > 1 {
return []uint64{shape[1], 1}
}
if shape[1] == 1 && shape[0] > 1 {
return []uint64{shape[0], 1}
}
}
return slices.Clone(shape)
}
func (n *nemotronHModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
remaining := ts
if n.isMoE() {
merges := make([]merge, 0, n.NumHiddenLayers*2)
for i := range n.NumHiddenLayers {
merges = append(merges, merge{
fmt.Sprintf("blk.%d.mixer.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.mixer.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
})
}
merged, rest := mergeTensors(ts, merges...)
out = append(out, merged...)
remaining = rest
}
nGroups := uint64(cmp.Or(n.NGroups, uint32(1)))
for _, t := range remaining {
name := t.Name()
shape := slices.Clone(t.Shape())
switch {
case strings.HasSuffix(name, ".ssm_a"):
shape = normalizeVectorShapeToColumn(shape)
t.SetRepacker(func(_ string, data []float32, _ []uint64) ([]float32, error) {
out := make([]float32, len(data))
for i, v := range data {
out[i] = -float32(math.Exp(float64(v)))
}
return out, nil
})
case strings.HasSuffix(name, ".ssm_d"):
shape = normalizeVectorShapeToColumn(shape)
case strings.HasSuffix(name, ".ssm_norm.weight"):
switch len(shape) {
case 1:
if nGroups > 0 && shape[0]%nGroups == 0 {
shape = []uint64{nGroups, shape[0] / nGroups}
}
case 2:
if shape[0] == 1 && nGroups > 0 && shape[1]%nGroups == 0 {
shape = []uint64{nGroups, shape[1] / nGroups}
}
}
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
if len(shape) == 3 {
if shape[0] == 1 {
shape = []uint64{shape[1], shape[2]}
} else if shape[1] == 1 {
shape = []uint64{shape[0], shape[2]}
}
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: shape,
WriterTo: t,
})
}
return out
}
func (n *nemotronHModel) Replacements() []string {
return []string{
// Embedding and output
"lm_head", "output",
"backbone.embeddings", "token_embd",
"backbone.norm_f", "output_norm",
"backbone.layers", "blk",
// Recurrent (Mamba2) tensors
"mixer.in_proj", "ssm_in",
"mixer.out_proj", "ssm_out",
"mixer.dt_bias", "ssm_dt.bias",
"mixer.A_log", "ssm_a",
"mixer.D", "ssm_d",
"mixer.conv1d", "ssm_conv1d",
"mixer.norm.weight", "ssm_norm.weight",
// Attention tensors
"mixer.q_proj", "attn_q",
"mixer.k_proj", "attn_k",
"mixer.v_proj", "attn_v",
"mixer.o_proj", "attn_output",
// FFN / MoE tensors
"mixer.gate.e_score_correction_bias", "exp_probs_b.bias",
"mixer.gate", "ffn_gate_inp",
"mixer.fc1_latent_proj", "ffn_latent_in",
"mixer.fc2_latent_proj", "ffn_latent_out",
"mixer.shared_experts.up_proj", "ffn_up_shexp",
"mixer.shared_experts.down_proj", "ffn_down_shexp",
"mixer.up_proj", "ffn_up",
"mixer.down_proj", "ffn_down",
// Per-layer pre-norm
".norm.weight", ".attn_norm.weight",
}
}

View File

@@ -0,0 +1,230 @@
package convert
import (
"bytes"
"encoding/binary"
"encoding/json"
"io"
"os"
"path/filepath"
"slices"
"strings"
"testing"
)
func TestHybridPatternUnmarshal(t *testing.T) {
t.Run("string", func(t *testing.T) {
var p hybridPattern
if err := json.Unmarshal([]byte(`"MEM*"`), &p); err != nil {
t.Fatal(err)
}
if got, want := string(p), "MEM*"; got != want {
t.Fatalf("unexpected pattern: got %q want %q", got, want)
}
})
t.Run("array", func(t *testing.T) {
var p hybridPattern
if err := json.Unmarshal([]byte(`["M","E","M","*"]`), &p); err != nil {
t.Fatal(err)
}
if got, want := string(p), "MEM*"; got != want {
t.Fatalf("unexpected pattern: got %q want %q", got, want)
}
})
}
func TestNemotronHLayerArrays(t *testing.T) {
m := &nemotronHModel{
NumHiddenLayers: 5,
NumAttentionHeads: 32,
NumKeyValueHeads: 8,
HybridOverridePattern: "MEM*E",
NRoutedExperts: 128,
NumExpertsPerTok: 6,
MoEIntermediateSize: 1856,
}
headsKV, ffn, err := m.layerArrays()
if err != nil {
t.Fatal(err)
}
if got, want := headsKV, []uint32{0, 0, 0, 8, 0}; !slices.Equal(got, want) {
t.Fatalf("unexpected head_count_kv: got %v want %v", got, want)
}
if got, want := ffn, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
}
}
func TestNemotronHKV(t *testing.T) {
m := &nemotronHModel{
MaxPositionEmbeddings: 1048576,
HiddenSize: 2688,
NumHiddenLayers: 5,
NumAttentionHeads: 32,
NumKeyValueHeads: 2,
HeadDim: 128,
LayerNormEpsilon: 1e-5,
RopeTheta: 10000,
PartialRotaryFactor: 0.5,
ConvKernel: 4,
SSMStateSize: 128,
MambaNumHeads: 64,
MambaHeadDim: 64,
NGroups: 8,
HybridOverridePattern: "MEM*E",
NRoutedExperts: 128,
NSharedExperts: 1,
NumExpertsPerTok: 6,
MoEIntermediateSize: 1856,
MoESharedExpertIntermediate: 3712,
NormTopKProb: true,
RoutedScalingFactor: 2.5,
}
if err := m.parseMore(nil); err != nil {
t.Fatal(err)
}
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
if got, want := kv["general.architecture"], "nemotron_h_moe"; got != want {
t.Fatalf("unexpected architecture: got %v want %v", got, want)
}
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
if !ok {
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
}
if got, want := headCountKV, []uint32{0, 0, 0, 2, 0}; !slices.Equal(got, want) {
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
}
ffnLength, ok := kv["feed_forward_length"].([]uint32)
if !ok {
t.Fatalf("feed_forward_length has unexpected type: %T", kv["feed_forward_length"])
}
if got, want := ffnLength, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
}
}
func TestNemotronHTensorsTransforms(t *testing.T) {
m := &nemotronHModel{NGroups: 8}
in := []Tensor{
&fakeTensor{
name: "blk.0.ssm_a",
shape: []uint64{4},
data: []float32{0, 1, 2, 3},
},
&fakeTensor{
name: "blk.0.ssm_d",
shape: []uint64{4},
data: []float32{0, 1, 2, 3},
},
&fakeTensor{
name: "blk.0.ssm_norm.weight",
shape: []uint64{16},
data: make([]float32, 16),
},
&fakeTensor{
name: "blk.0.ssm_conv1d.weight",
shape: []uint64{10, 1, 4},
data: make([]float32, 40),
},
}
out := m.Tensors(in)
if len(out) != len(in) {
t.Fatalf("unexpected output tensor count: got %d want %d", len(out), len(in))
}
got := map[string]struct {
shape []uint64
writer io.WriterTo
}{}
for _, t := range out {
got[t.Name] = struct {
shape []uint64
writer io.WriterTo
}{shape: t.Shape, writer: t.WriterTo}
}
if shape := got["blk.0.ssm_a"].shape; !slices.Equal(shape, []uint64{4, 1}) {
t.Fatalf("unexpected ssm_a shape: %v", shape)
}
if shape := got["blk.0.ssm_d"].shape; !slices.Equal(shape, []uint64{4, 1}) {
t.Fatalf("unexpected ssm_d shape: %v", shape)
}
if shape := got["blk.0.ssm_norm.weight"].shape; !slices.Equal(shape, []uint64{8, 2}) {
t.Fatalf("unexpected ssm_norm shape: %v", shape)
}
if shape := got["blk.0.ssm_conv1d.weight"].shape; !slices.Equal(shape, []uint64{10, 4}) {
t.Fatalf("unexpected ssm_conv1d shape: %v", shape)
}
var b bytes.Buffer
if _, err := got["blk.0.ssm_a"].writer.WriteTo(&b); err != nil {
t.Fatal(err)
}
values := make([]float32, 4)
if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
t.Fatal(err)
}
// 0 -> -exp(0) == -1
if values[0] != -1 {
t.Fatalf("unexpected transformed ssm_a[0]: got %v want -1", values[0])
}
}
func TestNemotronHLoadModelMetadata(t *testing.T) {
tempDir := t.TempDir()
config := `{
"architectures": ["NemotronHForCausalLM"],
"model_type": "nemotron_h",
"num_hidden_layers": 4,
"hidden_size": 512,
"max_position_embeddings": 32768,
"num_attention_heads": 8,
"num_key_value_heads": 2,
"head_dim": 64,
"layer_norm_epsilon": 1e-5,
"conv_kernel": 4,
"ssm_state_size": 128,
"mamba_num_heads": 16,
"mamba_head_dim": 32,
"n_groups": 8,
"hybrid_override_pattern": "ME*M",
"n_routed_experts": 16,
"num_experts_per_tok": 4,
"moe_intermediate_size": 256
}`
if err := os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(config), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tempDir, "tokenizer.json"), []byte(`{}`), 0o644); err != nil {
t.Fatal(err)
}
kv, _, err := LoadModelMetadata(os.DirFS(tempDir))
if err != nil {
t.Fatal(err)
}
if _, ok := kv.(*nemotronHModel); !ok {
t.Fatalf("unexpected converter type: %T", kv)
}
}
func TestNemotronHReplacementsLatentProjections(t *testing.T) {
m := &nemotronHModel{}
r := strings.NewReplacer(m.Replacements()...)
if got, want := r.Replace("backbone.layers.1.mixer.fc1_latent_proj.weight"), "blk.1.ffn_latent_in.weight"; got != want {
t.Fatalf("unexpected fc1 replacement: got %q want %q", got, want)
}
if got, want := r.Replace("backbone.layers.1.mixer.fc2_latent_proj.weight"), "blk.1.ffn_latent_out.weight"; got != want {
t.Fatalf("unexpected fc2 replacement: got %q want %q", got, want)
}
}

97
convert/json_compat.go Normal file
View File

@@ -0,0 +1,97 @@
package convert
// sanitizeNonFiniteJSON rewrites non-standard JSON numeric tokens that some
// HF configs emit (Infinity, -Infinity, NaN) into standard JSON numbers.
//
// This is intentionally conservative:
// - only runs outside quoted strings
// - only rewrites full tokens
//
// We map these values to 0 because encoding/json rejects non-finite values,
// and these fields are typically model-side metadata not consumed by the
// converter.
func sanitizeNonFiniteJSON(in []byte) []byte {
if len(in) == 0 {
return in
}
out := make([]byte, 0, len(in))
inString := false
escape := false
for i := 0; i < len(in); {
c := in[i]
if inString {
out = append(out, c)
if escape {
escape = false
} else if c == '\\' {
escape = true
} else if c == '"' {
inString = false
}
i++
continue
}
if c == '"' {
inString = true
out = append(out, c)
i++
continue
}
if hasToken(in, i, "-Infinity") {
out = append(out, '0')
i += len("-Infinity")
continue
}
if hasToken(in, i, "Infinity") {
out = append(out, '0')
i += len("Infinity")
continue
}
if hasToken(in, i, "NaN") {
out = append(out, '0')
i += len("NaN")
continue
}
out = append(out, c)
i++
}
return out
}
func hasToken(in []byte, at int, tok string) bool {
end := at + len(tok)
if at < 0 || end > len(in) {
return false
}
if string(in[at:end]) != tok {
return false
}
if at > 0 && !isJSONValuePrefixBoundary(in[at-1]) {
return false
}
if end < len(in) && !isJSONValueSuffixBoundary(in[end]) {
return false
}
return true
}
func isJSONWhitespace(b byte) bool {
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
func isJSONValuePrefixBoundary(b byte) bool {
return isJSONWhitespace(b) || b == ':' || b == ',' || b == '['
}
func isJSONValueSuffixBoundary(b byte) bool {
return isJSONWhitespace(b) || b == ',' || b == ']' || b == '}'
}

View File

@@ -0,0 +1,46 @@
package convert
import "testing"
func TestSanitizeNonFiniteJSON(t *testing.T) {
tests := []struct {
name string
in string
want string
}{
{
name: "infinity token",
in: `{"a":[0,Infinity,1]}`,
want: `{"a":[0,0,1]}`,
},
{
name: "negative infinity token",
in: `{"a":-Infinity}`,
want: `{"a":0}`,
},
{
name: "nan token",
in: `{"a":NaN}`,
want: `{"a":0}`,
},
{
name: "tokens inside strings untouched",
in: `{"a":"Infinity -Infinity NaN","b":Infinity}`,
want: `{"a":"Infinity -Infinity NaN","b":0}`,
},
{
name: "identifier-like token untouched",
in: `{"a":InfinityValue}`,
want: `{"a":InfinityValue}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := string(sanitizeNonFiniteJSON([]byte(tt.in)))
if got != tt.want {
t.Fatalf("sanitizeNonFiniteJSON() = %q, want %q", got, tt.want)
}
})
}
}

View File

@@ -160,6 +160,27 @@ func (kv KV) SSMGroupCount() uint64 {
return uint64(kv.Uint("ssm.group_count"))
}
func (kv KV) FFNLength() []uint64 {
ffnLengthDefault := uint32(0)
ffnLength := kv.UintOrArrayValueAsArray("feed_forward_length", ffnLengthDefault)
if len(ffnLength) == 1 {
ffnLengthDefault = ffnLength[0]
}
nLayers := int(kv.BlockCount())
if len(ffnLength) > nLayers {
slog.Warn("got more elements of feed_forward_length than layers", "len(ffnLength)", len(ffnLength), "layers", nLayers)
}
out := make([]uint64, nLayers)
for i := range nLayers {
if i >= len(ffnLength) {
out[i] = uint64(ffnLengthDefault)
} else {
out[i] = uint64(ffnLength[i])
}
}
return out
}
// general types
func (kv KV) String(key string, defaultValue ...string) string {
@@ -264,6 +285,7 @@ func (kv KV) OllamaEngineRequired() bool {
"llama4",
"mistral3",
"mllama",
"nemotron_h", "nemotron_h_moe",
"nomic-bert",
"olmo3",
"qwen25vl",
@@ -865,6 +887,7 @@ func (f GGML) FlashAttention() bool {
"gptoss", "gpt-oss",
"lfm2",
"mistral3",
"nemotron_h", "nemotron_h_moe",
"olmo3",
"qwen3", "qwen3moe",
"qwen3next",

752
kvcache/recurrent.go Normal file
View File

@@ -0,0 +1,752 @@
package kvcache
import (
"errors"
"fmt"
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
const (
DefaultCheckpointCount = 32
DefaultCheckpointMinPos = int32(16)
DefaultCheckpointInterval = int32(1280)
)
var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")
// Config configures a shared hybrid recurrent cache.
type RecurrentConfig struct {
Shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
ConvDim int
ConvChannels int
RecurrentStateSize int
CheckpointLogPrefix string
}
var (
_ Cache = (*Recurrent)(nil)
_ CheckpointCache = (*Recurrent)(nil)
)
// Cache stores:
// - a standard causal KV cache
// - per-sequence conv state for recurrent operators
// - per-sequence recurrent state for recurrent operators
//
// Conv state shape (per layer, per sequence): [convDim, convChannels]
// Recurrent state shape (per layer, per sequence): [recurrentStateSize]
type Recurrent struct {
kv *Causal
backend ml.Backend
dtype ml.DType
maxSequences int
// Conv state dimensions
convDim int
convChannels int
// Recurrent state dimensions
recurrentStateSize int
logPrefix string
// slot mapping for recurrent state (copy-on-write)
slotForSeq map[int]int
refCount []int
freeSlots []int
seqCounts map[int]int
slotScratch [1]int32
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
// per-layer recurrent state buffers (allocated lazily)
recurrentCtxs map[int]ml.Context
recurrentStates map[int]ml.Tensor // [recurrentStateSize, maxSlots]
// recurrent checkpoints (per slot)
checkpointCount int
checkpointMinPos int32
checkpointInterval int32
checkpointCtxSize int
checkpoints map[int]*slotCheckpointStore
pendingRestore map[int]checkpointRestore
curCheckpointPos []int32
curCheckpointSlots map[int]int
reserveCheckpoints bool
checkpointConvCtxs map[int]ml.Context
checkpointRecurCtxs map[int]ml.Context
checkpointReserved map[int]struct{}
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
writableError error
}
func NewRecurrentCache(config RecurrentConfig) *Recurrent {
return &Recurrent{
kv: NewCausalCache(config.Shift),
convDim: config.ConvDim,
convChannels: config.ConvChannels,
recurrentStateSize: config.RecurrentStateSize,
logPrefix: config.CheckpointLogPrefix,
slotForSeq: make(map[int]int),
seqCounts: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
recurrentCtxs: make(map[int]ml.Context),
recurrentStates: make(map[int]ml.Tensor),
checkpointCount: DefaultCheckpointCount,
checkpointMinPos: DefaultCheckpointMinPos,
checkpointInterval: DefaultCheckpointInterval,
checkpoints: make(map[int]*slotCheckpointStore),
pendingRestore: make(map[int]checkpointRestore),
curCheckpointSlots: make(map[int]int),
checkpointConvCtxs: make(map[int]ml.Context),
checkpointRecurCtxs: make(map[int]ml.Context),
checkpointReserved: make(map[int]struct{}),
}
}
func (c *Recurrent) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
c.checkpoints = make(map[int]*slotCheckpointStore)
c.pendingRestore = make(map[int]checkpointRestore)
c.curCheckpointPos = c.curCheckpointPos[:0]
c.curCheckpointSlots = make(map[int]int)
c.checkpointReserved = make(map[int]struct{})
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
if c.checkpointCtxSize < 8 {
c.checkpointCtxSize = 8
}
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
func (c *Recurrent) Close() {
for _, ctx := range c.convCtxs {
ctx.Close()
}
for _, ctx := range c.recurrentCtxs {
ctx.Close()
}
for _, ctx := range c.checkpointConvCtxs {
ctx.Close()
}
for _, ctx := range c.checkpointRecurCtxs {
ctx.Close()
}
c.kv.Close()
}
func (c *Recurrent) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
func (c *Recurrent) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
func (c *Recurrent) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
func (c *Recurrent) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
func (c *Recurrent) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
nTokens := len(batch.Sequences)
if nTokens == 0 {
c.curSeqs = c.curSeqs[:0]
c.curSlots = c.curSlots[:0]
c.curSlotsInput = nil
c.curSeqTokens = 0
c.reserveCheckpoints = false
c.writableEnsured = false
c.writableError = nil
return nil
}
// Fast path for single-sequence batches (common during decode and prefill).
firstSeq := batch.Sequences[0]
singleSeq := true
for _, s := range batch.Sequences[1:] {
if s != firstSeq {
singleSeq = false
break
}
}
if singleSeq {
return c.startForwardSingleSeq(ctx, firstSeq, nTokens, batch, reserve)
}
// Derive equal-length sequence layout for recurrent layers.
seqCounts := c.seqCounts
for s := range seqCounts {
delete(seqCounts, s)
}
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if seqCounts[s] == 0 {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return ErrNotSupported
}
}
c.curSeqTokens = want
if reserve {
c.curSlots = c.curSlots[:0]
for i := range nSeqs {
c.curSlots = append(c.curSlots, i)
}
c.finalizeStartForward(ctx, batch, true)
return nil
}
// Ensure slots exist for sequences in this batch.
c.curSlots = c.curSlots[:0]
var newSlots []int
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
c.curSlots = append(c.curSlots, slot)
}
if len(newSlots) > 0 {
c.zeroSlots(ctx, newSlots)
}
c.finalizeStartForward(ctx, batch, false)
return nil
}
func (c *Recurrent) startForwardSingleSeq(ctx ml.Context, seq, seqTokens int, batch input.Batch, reserve bool) error {
c.curSeqs = append(c.curSeqs[:0], seq)
c.curSeqTokens = seqTokens
if reserve {
c.curSlots = append(c.curSlots[:0], 0)
c.finalizeStartForward(ctx, batch, true)
return nil
}
slot, ok := c.slotForSeq[seq]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[seq] = slot
c.refCount[slot] = 1
slotList := [1]int{slot}
c.zeroSlots(ctx, slotList[:])
}
c.curSlots = append(c.curSlots[:0], slot)
c.finalizeStartForward(ctx, batch, false)
return nil
}
func (c *Recurrent) finalizeStartForward(ctx ml.Context, batch input.Batch, reserve bool) {
c.setCurSlotsInput(ctx)
c.writableEnsured = false
c.writableError = nil
c.reserveCheckpoints = reserve
c.planCheckpoints(batch)
}
func (c *Recurrent) setCurSlotsInput(ctx ml.Context) {
c.curSlotsInput = c.slotsInput(ctx, c.curSlots)
}
func (c *Recurrent) slotsInput(ctx ml.Context, slots []int) ml.Tensor {
switch len(slots) {
case 0:
return nil
case 1:
c.slotScratch[0] = int32(slots[0])
return ctx.Input().FromInts(c.slotScratch[:], 1)
default:
slotIndices := make([]int32, len(slots))
for i, v := range slots {
slotIndices[i] = int32(v)
}
return ctx.Input().FromInts(slotIndices, len(slotIndices))
}
}
func (c *Recurrent) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
func (c *Recurrent) freeSlot(slot int) {
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// zeroSlots zeros recurrent state for the given slots across all cached layers.
func (c *Recurrent) zeroSlots(ctx ml.Context, slots []int) {
if len(slots) == 0 {
return
}
inputCtx := ctx.Input()
slotsTensor := c.slotsInput(ctx, slots)
if len(c.convStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
for _, buf := range c.convStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
if len(c.recurrentStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.recurrentStateSize, len(slots))
for _, buf := range c.recurrentStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
}
// EnsureWritable ensures sequences have private slots (copy-on-write).
func (c *Recurrent) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
}
c.setCurSlotsInput(ctx)
return nil
}
func (c *Recurrent) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
for _, buf := range c.convStates {
rows := buf.Rows(ctx, src)
if rows.DType() != ml.DTypeF32 {
rows = rows.Cast(ctx, ml.DTypeF32)
}
ctx.Forward(buf.SetRows(ctx, rows, dst))
}
for _, buf := range c.recurrentStates {
rows := buf.Rows(ctx, src)
if rows.DType() != ml.DTypeF32 {
rows = rows.Cast(ctx, ml.DTypeF32)
}
ctx.Forward(buf.SetRows(ctx, rows, dst))
}
}
func (c *Recurrent) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
if c.validSlot(dstSlot) {
c.refCount[dstSlot]--
if c.refCount[dstSlot] <= 0 {
c.refCount[dstSlot] = 0
c.freeSlot(dstSlot)
}
}
delete(c.slotForSeq, dstSeq)
}
srcSlot, ok := c.slotForSeq[srcSeq]
if !ok {
return
}
if c.validSlot(srcSlot) {
c.slotForSeq[dstSeq] = srcSlot
c.refCount[srcSlot]++
}
}
func (c *Recurrent) CanResume(seq int, pos int32) bool {
if !c.kv.CanResume(seq, pos) {
return false
}
if pos == 0 {
return true
}
return c.hasCheckpoint(seq, pos)
}
func (c *Recurrent) Remove(seq int, beginIndex, endIndex int32) error {
if beginIndex > 0 && endIndex != math.MaxInt32 {
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
delete(c.pendingRestore, seq)
slot, ok := c.slotForSeq[seq]
if !ok || !c.validSlot(slot) {
return nil
}
// Detach shared recurrent state/checkpoints before mutating checkpoint positions.
if c.refCount[slot] > 1 {
newSlot, err := c.allocSlot()
if err != nil {
return err
}
ctx := c.backend.NewContext()
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
ctx.Compute()
}
ctx.Close()
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
slot = newSlot
}
c.shiftCheckpoints(slot, beginIndex, endIndex)
return nil
}
if beginIndex > 0 {
restore, ok := c.pendingRestore[seq]
if !ok || restore.pos+1 != beginIndex {
return ErrNotSupported
}
if !c.restoreComplete(restore) {
return ErrNotSupported
}
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
newSlot, err := c.allocSlot()
if err != nil {
return err
}
ctx := c.backend.NewContext()
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
ctx.Compute()
}
ctx.Close()
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
restore.slot = newSlot
c.pendingRestore[seq] = restore
}
}
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
if beginIndex > 0 {
restore := c.pendingRestore[seq]
delete(c.pendingRestore, seq)
return c.applyCheckpointRestore(restore)
}
slot, ok := c.slotForSeq[seq]
delete(c.pendingRestore, seq)
if !ok {
return nil
}
if !c.validSlot(slot) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.clearCheckpoints(slot)
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
}
func (c *Recurrent) validSlot(slot int) bool {
return slot >= 0 && slot < len(c.refCount)
}
func (c *Recurrent) SlotsTensor() ml.Tensor {
return c.curSlotsInput
}
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
func (c *Recurrent) contiguousSlots() (int, bool) {
if len(c.curSlots) == 0 {
return 0, false
}
start := c.curSlots[0]
for i, s := range c.curSlots {
if s != start+i {
return 0, false
}
}
return start, true
}
func (c *Recurrent) SeqTokens() int {
return c.curSeqTokens
}
func (c *Recurrent) NumSeqs() int {
return len(c.curSeqs)
}
func (c *Recurrent) convBuffer(layer int) ml.Tensor {
if buf, ok := c.convStates[layer]; ok {
return buf
}
if _, ok := c.convCtxs[layer]; !ok {
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
c.convStates[layer] = buf
return buf
}
func (c *Recurrent) recurrentBuffer(layer int) ml.Tensor {
if buf, ok := c.recurrentStates[layer]; ok {
return buf
}
if _, ok := c.recurrentCtxs[layer]; !ok {
c.recurrentCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
buf := c.recurrentCtxs[layer].Zeros(ml.DTypeF32, c.recurrentStateSize, c.maxSequences)
c.recurrentStates[layer] = buf
return buf
}
func (c *Recurrent) ensureWritable(ctx ml.Context) error {
c.ensureWritableOnce(ctx)
return c.writableError
}
func (c *Recurrent) currentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int) ml.Tensor {
if start, ok := c.contiguousSlots(); ok {
offset := start * buf.Stride(1)
return buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
}
return buf.Rows(ctx, c.SlotsTensor())
}
func (c *Recurrent) writeCurrentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int, src ml.Tensor) {
if start, ok := c.contiguousSlots(); ok {
offset := start * buf.Stride(1)
view := buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
ctx.Forward(src.Copy(ctx, view))
return
}
ctx.Forward(buf.SetRows(ctx, src, c.SlotsTensor()))
}
func (c *Recurrent) ensureWritableOnce(ctx ml.Context) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
}
// ConvState returns conv state for current batch sequences as [convDim, convChannels, nSeqs].
func (c *Recurrent) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
if err := c.ensureWritable(ctx); err != nil {
return nil, err
}
buf := c.convBuffer(layer)
cur := c.currentSlotRows(ctx, buf, c.convDim*c.convChannels)
return cur.Reshape(ctx, c.convDim, c.convChannels, c.NumSeqs()), nil
}
// UpdateConvState writes new conv state for current batch sequences.
func (c *Recurrent) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(layer)
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.NumSeqs())
srcF32 := src
if src.DType() != ml.DTypeF32 {
srcF32 = src.Cast(ctx, ml.DTypeF32)
}
c.writeCurrentSlotRows(ctx, buf, c.convDim*c.convChannels, srcF32)
c.captureConvCheckpoint(ctx, layer, srcF32)
}
// RecurrentState returns recurrent state for current batch sequences with shape [dims..., nSeqs].
func (c *Recurrent) RecurrentState(ctx ml.Context, layer int, dims ...int) (ml.Tensor, error) {
if err := c.ensureWritable(ctx); err != nil {
return nil, err
}
if len(dims) == 0 {
return nil, ErrInvalidRecurrentShape
}
size := 1
for _, d := range dims {
if d <= 0 {
return nil, ErrInvalidRecurrentShape
}
size *= d
}
if size != c.recurrentStateSize {
return nil, fmt.Errorf("%w: got %v (size %d), want size %d", ErrInvalidRecurrentShape, dims, size, c.recurrentStateSize)
}
buf := c.recurrentBuffer(layer)
cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
shape := make([]int, 0, len(dims)+1)
shape = append(shape, dims...)
shape = append(shape, c.NumSeqs())
return cur.Reshape(ctx, shape...), nil
}
// RecurrentState4D returns recurrent state as [dim0, dim1, dim2, nSeqs].
func (c *Recurrent) RecurrentState4D(ctx ml.Context, layer int, dim0, dim1, dim2 int) (ml.Tensor, error) {
if err := c.ensureWritable(ctx); err != nil {
return nil, err
}
if dim0 <= 0 || dim1 <= 0 || dim2 <= 0 {
return nil, ErrInvalidRecurrentShape
}
size := dim0 * dim1 * dim2
if size != c.recurrentStateSize {
return nil, fmt.Errorf("%w: got [%d %d %d] (size %d), want size %d", ErrInvalidRecurrentShape, dim0, dim1, dim2, size, c.recurrentStateSize)
}
buf := c.recurrentBuffer(layer)
cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
return cur.Reshape(ctx, dim0, dim1, dim2, c.NumSeqs()), nil
}
// UpdateRecurrentState writes new recurrent state for current batch sequences.
func (c *Recurrent) UpdateRecurrentState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.recurrentBuffer(layer)
src := newState.Reshape(ctx, c.recurrentStateSize, c.NumSeqs())
srcF32 := src
if src.DType() != ml.DTypeF32 {
srcF32 = src.Cast(ctx, ml.DTypeF32)
}
c.writeCurrentSlotRows(ctx, buf, c.recurrentStateSize, srcF32)
c.captureRecurrentCheckpoint(ctx, layer, srcF32)
}
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
func (c *Recurrent) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *Recurrent) Seqs() []int {
return slices.Clone(c.curSeqs)
}

View File

@@ -0,0 +1,561 @@
package kvcache
import (
"log/slog"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
// memory usage while preserving prefix reuse for recurrent state.
type checkpointEntry struct {
pos int32
conv map[int]ml.Tensor
recurrent map[int]ml.Tensor
}
type slotCheckpointStore struct {
entries []checkpointEntry
size int
next int
lastPos int32
}
type checkpointRestore struct {
slot int
idx int
pos int32
}
func newSlotCheckpointStore(n int) *slotCheckpointStore {
entries := make([]checkpointEntry, n)
for i := range entries {
entries[i].pos = -1
}
return &slotCheckpointStore{
entries: entries,
lastPos: -1,
}
}
func (s *slotCheckpointStore) reset() {
s.size = 0
s.next = 0
s.lastPos = -1
for i := range s.entries {
s.entries[i].pos = -1
}
}
func (s *slotCheckpointStore) record(pos int32) int {
if len(s.entries) == 0 {
return -1
}
idx := s.next
s.next = (s.next + 1) % len(s.entries)
if s.size < len(s.entries) {
s.size++
}
s.entries[idx].pos = pos
s.lastPos = pos
return idx
}
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
bestIdx := -1
bestPos := int32(-1)
for i := range s.entries {
pos := s.entries[i].pos
if pos < 0 || pos >= targetPos {
continue
}
if pos > bestPos {
bestPos = pos
bestIdx = i
}
}
if bestIdx < 0 {
return -1, -1, false
}
return bestIdx, bestPos, true
}
func (s *slotCheckpointStore) pruneAfter(pos int32) {
if len(s.entries) == 0 {
s.size = 0
s.next = 0
s.lastPos = -1
return
}
size := 0
next := -1
minPos := int32(math.MaxInt32)
minIdx := 0
for i := range s.entries {
if s.entries[i].pos > pos {
s.entries[i].pos = -1
}
if s.entries[i].pos >= 0 {
size++
if s.entries[i].pos < minPos {
minPos = s.entries[i].pos
minIdx = i
}
} else if next == -1 {
next = i
}
}
s.size = size
if size == 0 {
s.next = 0
s.lastPos = -1
return
}
if next != -1 {
s.next = next
} else {
// Full ring: overwrite the oldest checkpoint next.
s.next = minIdx
}
s.lastPos = pos
}
func (s *slotCheckpointStore) shiftRange(beginIndex, endIndex int32) {
if len(s.entries) == 0 {
s.size = 0
s.next = 0
s.lastPos = -1
return
}
offset := beginIndex - endIndex
size := 0
next := -1
minPos := int32(math.MaxInt32)
maxPos := int32(-1)
minIdx := 0
for i := range s.entries {
pos := s.entries[i].pos
if pos >= 0 {
if pos >= beginIndex && pos < endIndex {
s.entries[i].pos = -1
} else if pos >= endIndex {
s.entries[i].pos = pos + offset
}
}
pos = s.entries[i].pos
if pos >= 0 {
size++
if pos < minPos {
minPos = pos
minIdx = i
}
if pos > maxPos {
maxPos = pos
}
} else if next == -1 {
next = i
}
}
s.size = size
if size == 0 {
s.next = 0
s.lastPos = -1
return
}
if next != -1 {
s.next = next
} else {
// Full ring: overwrite the oldest checkpoint next.
s.next = minIdx
}
s.lastPos = maxPos
}
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
minPos = int32(math.MaxInt32)
maxPos = int32(-1)
for i := range s.entries {
pos := s.entries[i].pos
if pos < 0 {
continue
}
size++
if pos < minPos {
minPos = pos
}
if pos > maxPos {
maxPos = pos
}
}
if size == 0 {
minPos = -1
maxPos = -1
}
return size, minPos, maxPos, s.lastPos
}
func (c *Recurrent) checkpointTag() string {
if c.logPrefix == "" {
return "kvcache.recurrent"
}
return c.logPrefix
}
func (c *Recurrent) planCheckpoints(batch input.Batch) {
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
c.curCheckpointPos = c.curCheckpointPos[:0]
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
return
}
if cap(c.curCheckpointPos) < len(c.curSeqs) {
c.curCheckpointPos = make([]int32, len(c.curSeqs))
} else {
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
}
for i := range c.curCheckpointPos {
c.curCheckpointPos[i] = -1
}
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
posMax := make(map[int]int32, len(c.curSeqs))
for i, seq := range batch.Sequences {
pos := batch.Positions[i]
if cur, ok := posMax[seq]; !ok || pos > cur {
posMax[seq] = pos
}
}
for i, seq := range c.curSeqs {
pos, ok := posMax[seq]
if !ok {
continue
}
if pos < c.checkpointMinPos {
continue
}
slot := c.curSlots[i]
store := c.checkpointStore(slot)
lastPos := store.lastPos
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
c.curCheckpointPos[i] = pos
}
}
}
func (c *Recurrent) checkpointStore(slot int) *slotCheckpointStore {
store, ok := c.checkpoints[slot]
if ok {
return store
}
store = newSlotCheckpointStore(c.checkpointCount)
c.checkpoints[slot] = store
return store
}
func (c *Recurrent) checkpointIndexForSlot(slot int, pos int32) int {
if c.checkpointCount == 0 {
return -1
}
if idx, ok := c.curCheckpointSlots[slot]; ok {
return idx
}
store := c.checkpointStore(slot)
idx := store.record(pos)
if idx >= 0 {
c.curCheckpointSlots[slot] = idx
}
return idx
}
func (c *Recurrent) hasCheckpoint(seq int, pos int32) bool {
if pos <= 0 {
return false
}
slot, ok := c.slotForSeq[seq]
if !ok {
return false
}
store, ok := c.checkpoints[slot]
if !ok {
return false
}
_, _, ok = store.bestIndex(pos)
return ok
}
func (c *Recurrent) PrepareRestore(seq int, targetPos int32) (int32, bool) {
if targetPos <= 0 {
return 0, false
}
slot, ok := c.slotForSeq[seq]
if !ok {
return 0, false
}
store, ok := c.checkpoints[slot]
if !ok {
slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
return 0, false
}
idx, pos, ok := store.bestIndex(targetPos)
if !ok {
size, minPos, maxPos, lastPos := store.window()
slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
"min", minPos, "max", maxPos, "last", lastPos)
return 0, false
}
c.pendingRestore[seq] = checkpointRestore{
slot: slot,
idx: idx,
pos: pos,
}
return pos + 1, true
}
func (c *Recurrent) applyCheckpointRestore(restore checkpointRestore) error {
entry, ok := c.restoreEntry(restore)
if !ok {
return ErrNotSupported
}
ctx := c.backend.NewContext()
defer ctx.Close()
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
for layer, src := range entry.conv {
buf := c.convBuffer(layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
for layer, src := range entry.recurrent {
buf := c.recurrentBuffer(layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
if len(entry.conv) > 0 || len(entry.recurrent) > 0 {
ctx.Compute()
}
store := c.checkpoints[restore.slot]
store.pruneAfter(restore.pos)
return nil
}
func (c *Recurrent) restoreComplete(restore checkpointRestore) bool {
_, ok := c.restoreEntry(restore)
return ok
}
func (c *Recurrent) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
store, ok := c.checkpoints[restore.slot]
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
return nil, false
}
entry := &store.entries[restore.idx]
if entry.pos < 0 {
return nil, false
}
if !c.entryComplete(entry) {
return nil, false
}
return entry, true
}
func (c *Recurrent) entryComplete(entry *checkpointEntry) bool {
for layer := range c.convStates {
if entry.conv == nil || entry.conv[layer] == nil {
return false
}
}
for layer := range c.recurrentStates {
if entry.recurrent == nil || entry.recurrent[layer] == nil {
return false
}
}
return true
}
func (c *Recurrent) clearCheckpoints(slot int) {
if store, ok := c.checkpoints[slot]; ok {
store.reset()
}
}
func (c *Recurrent) shiftCheckpoints(slot int, beginIndex, endIndex int32) {
if store, ok := c.checkpoints[slot]; ok {
store.shiftRange(beginIndex, endIndex)
}
}
func (c *Recurrent) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
if c.checkpointCount == 0 {
return
}
srcStore, ok := c.checkpoints[srcSlot]
if !ok || srcStore.size == 0 {
return
}
dstStore := c.checkpointStore(dstSlot)
dstStore.size = srcStore.size
dstStore.next = srcStore.next
dstStore.lastPos = srcStore.lastPos
for i := range srcStore.entries {
srcEntry := &srcStore.entries[i]
dstEntry := &dstStore.entries[i]
dstEntry.pos = srcEntry.pos
if srcEntry.conv != nil {
if dstEntry.conv == nil {
dstEntry.conv = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.conv {
dst := c.ensureCheckpointConv(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
if srcEntry.recurrent != nil {
if dstEntry.recurrent == nil {
dstEntry.recurrent = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.recurrent {
dst := c.ensureCheckpointRecurrent(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
}
}
func (c *Recurrent) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointConv(layer)
return
}
if len(c.curCheckpointPos) == 0 {
return
}
for i, pos := range c.curCheckpointPos {
if pos < 0 {
continue
}
slot := c.curSlots[i]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
continue
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointConv(layer, entry)
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
ctx.Forward(seqSlice.Copy(ctx, dst))
}
}
func (c *Recurrent) captureRecurrentCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointRecurrent(layer)
return
}
if len(c.curCheckpointPos) == 0 {
return
}
for i, pos := range c.curCheckpointPos {
if pos < 0 {
continue
}
slot := c.curSlots[i]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
continue
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointRecurrent(layer, entry)
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
ctx.Forward(seqSlice.Copy(ctx, dst))
}
}
func (c *Recurrent) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
if entry.conv == nil {
entry.conv = make(map[int]ml.Tensor)
}
if t, ok := entry.conv[layer]; ok {
return t
}
ctx, ok := c.checkpointConvCtxs[layer]
if !ok {
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
c.checkpointConvCtxs[layer] = ctx
}
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
entry.conv[layer] = t
return t
}
func (c *Recurrent) ensureCheckpointRecurrent(layer int, entry *checkpointEntry) ml.Tensor {
if entry.recurrent == nil {
entry.recurrent = make(map[int]ml.Tensor)
}
if t, ok := entry.recurrent[layer]; ok {
return t
}
ctx, ok := c.checkpointRecurCtxs[layer]
if !ok {
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
c.checkpointRecurCtxs[layer] = ctx
}
t := ctx.Zeros(ml.DTypeF32, c.recurrentStateSize, 1)
entry.recurrent[layer] = t
return t
}
func (c *Recurrent) reserveCheckpointConv(layer int) {
key := checkpointReserveKey(layer, 0)
if _, ok := c.checkpointReserved[key]; ok {
return
}
for slot := range c.maxSequences {
store := c.checkpointStore(slot)
for i := range store.entries {
entry := &store.entries[i]
_ = c.ensureCheckpointConv(layer, entry)
}
}
c.checkpointReserved[key] = struct{}{}
}
func (c *Recurrent) reserveCheckpointRecurrent(layer int) {
key := checkpointReserveKey(layer, 1)
if _, ok := c.checkpointReserved[key]; ok {
return
}
for slot := range c.maxSequences {
store := c.checkpointStore(slot)
for i := range store.entries {
entry := &store.entries[i]
_ = c.ensureCheckpointRecurrent(layer, entry)
}
}
c.checkpointReserved[key] = struct{}{}
}
func checkpointReserveKey(layer int, kind int) int {
return layer*2 + kind
}

View File

@@ -0,0 +1,288 @@
package kvcache
import (
"errors"
"math"
"slices"
"testing"
"github.com/ollama/ollama/ml"
)
func newTestCache() *Recurrent {
return NewRecurrentCache(RecurrentConfig{ConvDim: 1, ConvChannels: 2, RecurrentStateSize: 2})
}
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
store := newSlotCheckpointStore(2)
store.record(10)
store.record(20)
_, pos, ok := store.bestIndex(15)
if !ok || pos != 10 {
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
}
store.record(30) // overwrite oldest (10)
if _, _, ok := store.bestIndex(15); ok {
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
}
_, pos, ok = store.bestIndex(40)
if !ok || pos != 30 {
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
}
}
func TestCachePrepareRestore(t *testing.T) {
cache := newTestCache()
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
store := cache.checkpointStore(0)
store.record(5)
store.record(9)
store.record(15)
restorePos, ok := cache.PrepareRestore(1, 12)
if !ok {
t.Fatalf("expected restore ok")
}
if restorePos != 10 {
t.Fatalf("expected restorePos 10, got %d", restorePos)
}
rest, ok := cache.pendingRestore[1]
if !ok {
t.Fatalf("expected pending restore entry")
}
if rest.pos != 9 {
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
}
}
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
store.pruneAfter(20)
if store.lastPos != 20 {
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
}
_, pos, ok := store.bestIndex(25)
if !ok || pos != 20 {
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
}
_, pos, ok = store.bestIndex(35)
if !ok || pos != 20 {
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
}
}
func TestCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
cache := newTestCache()
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
// Simulate layer 0 requires both conv and recurrent checkpoints.
cache.convStates[0] = nil
cache.recurrentStates[0] = nil
store := cache.checkpointStore(0)
idx := store.record(9)
entry := &store.entries[idx]
entry.conv = map[int]ml.Tensor{0: nil}
// entry.recurrent intentionally missing
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
err := cache.Remove(1, 10, math.MaxInt32)
if !errors.Is(err, ErrNotSupported) {
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
}
}
func TestCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
cache := newTestCache()
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
store := cache.checkpointStore(0)
idx := store.record(9)
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
restore := cache.pendingRestore[1]
if !cache.restoreComplete(restore) {
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
}
}
func TestCacheRecurrentStateShapeValidation(t *testing.T) {
cache := newTestCache()
_, err := cache.RecurrentState(nil, 0, 3)
if !errors.Is(err, ErrInvalidRecurrentShape) {
t.Fatalf("expected ErrInvalidRecurrentShape, got %v", err)
}
}
func TestSlotCheckpointStoreShiftRange(t *testing.T) {
store := newSlotCheckpointStore(5)
store.record(1)
store.record(4)
store.record(7)
store.record(10)
store.shiftRange(2, 6)
var positions []int32
for i := range store.entries {
if store.entries[i].pos >= 0 {
positions = append(positions, store.entries[i].pos)
}
}
slices.Sort(positions)
want := []int32{1, 3, 6}
if !slices.Equal(positions, want) {
t.Fatalf("unexpected shifted positions: got=%v want=%v", positions, want)
}
if store.lastPos != 6 {
t.Fatalf("expected lastPos 6, got %d", store.lastPos)
}
}
func TestCacheRemoveMiddleShiftsCheckpoints(t *testing.T) {
cache := newTestCache()
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: 0, pos: 1}
store := cache.checkpointStore(0)
store.record(1)
store.record(4)
store.record(7)
store.record(10)
if err := cache.Remove(1, 2, 6); err != nil {
t.Fatalf("expected middle remove to succeed, got %v", err)
}
if _, ok := cache.pendingRestore[1]; ok {
t.Fatalf("expected pending restore to be cleared after middle remove")
}
var positions []int32
for i := range store.entries {
if store.entries[i].pos >= 0 {
positions = append(positions, store.entries[i].pos)
}
}
slices.Sort(positions)
want := []int32{1, 3, 6}
if !slices.Equal(positions, want) {
t.Fatalf("unexpected checkpoint positions after remove: got=%v want=%v", positions, want)
}
}
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
store.entries[0].conv = make(map[int]ml.Tensor)
store.entries[0].conv[0] = nil
store.entries[0].recurrent = make(map[int]ml.Tensor)
store.entries[0].recurrent[0] = nil
store.record(40)
if store.entries[0].conv == nil {
t.Fatalf("expected conv map to be preserved on reuse")
}
if store.entries[0].recurrent == nil {
t.Fatalf("expected recurrent map to be preserved on reuse")
}
if store.entries[0].pos != 40 {
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
}
}
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
store := newSlotCheckpointStore(2)
idx1 := store.record(10)
idx2 := store.record(20)
if idx1 != 0 || idx2 != 1 {
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
}
if store.size != 2 {
t.Fatalf("expected size 2, got %d", store.size)
}
_, pos1, ok1 := store.bestIndex(15)
_, pos2, ok2 := store.bestIndex(25)
if !ok1 || pos1 != 10 {
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
}
if !ok2 || pos2 != 20 {
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
}
}
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
store := newSlotCheckpointStore(0)
idx := store.record(10)
if idx != -1 {
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
}
_, _, ok := store.bestIndex(15)
if ok {
t.Fatalf("expected no checkpoint for empty buffer")
}
}
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
store.pruneAfter(5)
if store.size != 0 {
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
}
if store.lastPos != -1 {
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
}
_, _, ok := store.bestIndex(100)
if ok {
t.Fatalf("expected no checkpoint after pruning all")
}
}

View File

@@ -0,0 +1,37 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Sun, 22 Feb 2026 14:12:30 -0800
Subject: [PATCH] ggml-metal: guard mul_mat_id map0 and add ne20=22
specialization
---
ggml/src/ggml-metal/ggml-metal-ops.cpp | 3 ++-
ggml/src/ggml-metal/ggml-metal.metal | 1 +
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 4ac135603..ac5ad53db 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
// ne21 = n_rows (batch size)
const int ne21_mm_id_min = 32;
- if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
+ if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
+ (ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
//switch (op->src[0]->type) {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c37447a10..4f338aa13 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
+template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(

View File

@@ -163,6 +163,7 @@ type Tensor interface {
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
SSMConv(ctx Context, kernel Tensor) Tensor
SSMScan(ctx Context, x, dt, A, B, C, ids Tensor) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

View File

@@ -1662,6 +1662,13 @@ func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) SSMScan(ctx ml.Context, x, dt, A, B, C, ids ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_ssm_scan(ctx.(*Context).ctx, t.t, x.(*Tensor).t, dt.(*Tensor).t, A.(*Tensor).t, B.(*Tensor).t, C.(*Tensor).t, ids.(*Tensor).t),
}
}
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
return &Tensor{
b: t.b,

View File

@@ -12249,6 +12249,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(

View File

@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
// ne21 = n_rows (batch size)
const int ne21_mm_id_min = 32;
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
(ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
//switch (op->src[0]->type) {

View File

@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(

View File

@@ -67,6 +67,7 @@ func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
func (f *fakeTensor) SSMScan(ctx ml.Context, _, _, _, _, _, _ ml.Tensor) ml.Tensor { return f }
func (m *fakeBackend) Get(name string) ml.Tensor {
if slices.Contains(m.names, name) {

View File

@@ -15,6 +15,7 @@ import (
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
_ "github.com/ollama/ollama/model/models/nemotronh"
_ "github.com/ollama/ollama/model/models/nomicbert"
_ "github.com/ollama/ollama/model/models/olmo3"
_ "github.com/ollama/ollama/model/models/qwen2"

View File

@@ -0,0 +1,88 @@
package nemotronh
import (
"fmt"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// Attention implements simple attention without RoPE for Nemotron-H.
// Unlike Qwen3Next, Nemotron-H attention has:
// - No RoPE (position info comes from Mamba2 layers)
// - Standard scaled dot-product attention
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
hiddenDim := hiddenStates.Dim(0)
nSeqTokens := hiddenStates.Dim(1)
switch hiddenStates.Dim(2) {
case 0:
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
case 1:
default:
return nil, ErrUnsupportedBatchLayout
}
// Nemotron-H is currently clamped to num_parallel=1.
if cache != nil && cache.IsSupportedForBatch() {
if cache.numSeqs() != 1 {
return nil, ErrUnsupportedBatchLayout
}
if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
return nil, ErrUnsupportedBatchLayout
}
}
batchSize := nSeqTokens
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize)
headDim := opts.getHeadDim()
if headDim <= 0 {
return nil, fmt.Errorf("nemotronh: invalid attention head dimension %d", headDim)
}
// Q projection
query := a.Query.Forward(ctx, hiddenStates)
if query.Dim(0)%headDim != 0 {
return nil, fmt.Errorf("nemotronh: query dim %d not divisible by head dim %d", query.Dim(0), headDim)
}
numHeads := query.Dim(0) / headDim
query = query.Reshape(ctx, headDim, numHeads, batchSize)
// K projection
key := a.Key.Forward(ctx, hiddenStates)
if key.Dim(0)%headDim != 0 {
return nil, fmt.Errorf("nemotronh: key dim %d not divisible by head dim %d", key.Dim(0), headDim)
}
numKVHeads := key.Dim(0) / headDim
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
// V projection
value := a.Value.Forward(ctx, hiddenStates)
if value.Dim(0)%headDim != 0 {
return nil, fmt.Errorf("nemotronh: value dim %d not divisible by head dim %d", value.Dim(0), headDim)
}
if value.Dim(0)/headDim != numKVHeads {
return nil, fmt.Errorf("nemotronh: key heads %d and value heads %d do not match", numKVHeads, value.Dim(0)/headDim)
}
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
// Standard attention computation (no RoPE)
scale := opts.attentionScale
if scale == 0 {
scale = 1.0 / math.Sqrt(float64(headDim))
}
attention := nn.Attention(ctx, query, key, value, scale, cache)
// Flatten heads
attention = attention.Reshape(ctx, headDim*numHeads, batchSize)
// Output projection
return a.Output.Forward(ctx, attention), nil
}

View File

@@ -0,0 +1,55 @@
package nemotronh
import (
"errors"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
// with the layer requirements.
var ErrUnsupportedBatchLayout = errors.New("nemotronh: unsupported batch layout")
var (
_ kvcache.Cache = (*HybridCache)(nil)
_ kvcache.CheckpointCache = (*HybridCache)(nil)
)
// HybridCache adapts the shared recurrent cache base for Nemotron-H naming.
type HybridCache struct {
*kvcache.Recurrent
}
func NewHybridCache(convDim, convChannels, ssmStateSize int) *HybridCache {
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
Shift: Shift,
ConvDim: convDim,
ConvChannels: convChannels,
RecurrentStateSize: ssmStateSize,
CheckpointLogPrefix: "nemotronh",
})
return &HybridCache{Recurrent: base}
}
// SSMState returns the SSM state for current batch sequences.
func (c *HybridCache) SSMState(ctx ml.Context, layer int, dState, headDim, nHead int) (ml.Tensor, error) {
return c.RecurrentState4D(ctx, layer, dState, headDim, nHead)
}
// UpdateSSMState writes a new SSM state for current batch sequences.
func (c *HybridCache) UpdateSSMState(ctx ml.Context, layer int, newState ml.Tensor) {
c.UpdateRecurrentState(ctx, layer, newState)
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.SlotsTensor()
}
func (c *HybridCache) seqTokens() int {
return c.SeqTokens()
}
func (c *HybridCache) numSeqs() int {
return c.NumSeqs()
}

View File

@@ -0,0 +1,197 @@
package nemotronh
import (
"log/slog"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// convKernel wraps the 1D convolution kernel tensor
type convKernel struct {
Weight ml.Tensor `gguf:"weight"`
}
// Mamba2 implements the Mamba2 SSM layer for Nemotron-H.
// The forward pass follows llama.cpp's build_mamba2_layer:
// 1. Input projection: zxBCdt = SSMIn @ hidden
// 2. Split: z, xBC, dt
// 3. Concat with conv state, apply SSMConv, save new conv state
// 4. Apply SiLU to convolved xBC
// 5. Split: x, B, C
// 6. Add dt bias
// 7. SSMScan: y = SSMScan(state, x, dt, A, B, C, ids)
// 8. D skip: y = y + x * D
// 9. Swiglu with z: y = z * silu(y)
// 10. Group RMSNorm
// 11. Output projection
type Mamba2 struct {
SSMIn *nn.Linear `gguf:"ssm_in"` // n_embd → d_in_proj (2*d_inner + 2*n_group*d_state + n_head)
SSMConv1D *convKernel `gguf:"ssm_conv1d"` // conv kernel
SSMConv1DB ml.Tensor `gguf:"ssm_conv1d.bias"`
SSMDtB ml.Tensor `gguf:"ssm_dt.bias"` // dt bias [n_head]
SSMA ml.Tensor `gguf:"ssm_a"` // A parameter [1, n_head]
SSMD ml.Tensor `gguf:"ssm_d"` // D skip connection [1, n_head]
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"` // group norm
SSMOut *nn.Linear `gguf:"ssm_out"` // output projection
Layer int
}
func (m *Mamba2) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
layer := m.Layer
hiddenDim := hiddenStates.Dim(0)
nSeqTokens := hiddenStates.Dim(1)
switch hiddenStates.Dim(2) {
case 0:
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
case 1:
default:
return nil, ErrUnsupportedBatchLayout
}
// Nemotron-H is currently clamped to num_parallel=1.
if cache != nil && cache.IsSupportedForBatch() {
if cache.numSeqs() != 1 {
return nil, ErrUnsupportedBatchLayout
}
if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
return nil, ErrUnsupportedBatchLayout
}
}
nSeqs := 1
dConv := opts.ssmDConv
dInner := opts.ssmDInner
dState := opts.ssmDState
nHead := opts.ssmNHead
headDim := dInner / nHead
nGroup := opts.ssmNGroup
// {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
// d_in_proj = 2*d_inner + 2*n_group*d_state + n_head
zxBCdt := m.SSMIn.Forward(ctx, hiddenStates)
// Split into z, xBC, dt
// z: [head_dim, n_head, n_seq_tokens, n_seqs]
z := zxBCdt.Slice(ctx, 0, 0, dInner, 1)
z = z.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
// xBC: [d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs]
xBCSize := dInner + 2*nGroup*dState
xBC := zxBCdt.Slice(ctx, 0, dInner, dInner+xBCSize, 1)
if nSeqTokens == 1 {
xBC = xBC.Reshape(ctx, xBCSize, 1, nSeqs)
}
// dt: [n_head, n_seq_tokens, n_seqs]
dt := zxBCdt.Slice(ctx, 0, 2*dInner+2*nGroup*dState, 2*dInner+2*nGroup*dState+nHead, 1)
if nSeqTokens == 1 {
dt = dt.Reshape(ctx, nHead, 1, nSeqs)
} else {
dt = dt.Contiguous(ctx, nHead, nSeqTokens, nSeqs)
}
// Get conv state from cache
convStates, err := cache.ConvState(ctx, layer)
if err != nil {
slog.Warn("nemotronh: failed to get conv state, using zeros", "layer", layer, "error", err)
convStates = ctx.Input().Zeros(ml.DTypeF32, dConv-1, xBCSize, nSeqs)
}
// Reshape conv states: [d_conv-1, xBCSize, n_seqs]
convStates = convStates.Reshape(ctx, dConv-1, xBCSize, nSeqs)
// For decode (n_seq_tokens == 1), reshape avoids a transpose/contiguous pair.
var xBCT ml.Tensor
if nSeqTokens == 1 {
xBCT = xBC.Reshape(ctx, 1, xBCSize, nSeqs)
} else {
// Prefill path: [xBCSize, n_seq_tokens, n_seqs] -> [n_seq_tokens, xBCSize, n_seqs]
xBCT = xBC.Permute(ctx, 1, 0, 2, 3)
}
// Concatenate with conv state: [d_conv-1 + n_seq_tokens, xBCSize, n_seqs]
convInput := convStates.Concat(ctx, xBCT, 0)
// Save new conv state (last d_conv-1 columns)
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+dConv-1, 1)
cache.UpdateConvState(ctx, layer, lastConvStates)
// Apply SSM convolution
xBC = convInput.SSMConv(ctx, m.SSMConv1D.Weight)
// Add conv bias
if m.SSMConv1DB != nil {
xBC = xBC.Add(ctx, m.SSMConv1DB)
}
// Apply SiLU
xBC = xBC.SILU(ctx)
// Split xBC into x, B, C
// x: [head_dim, n_head, n_seq_tokens, n_seqs]
x := xBC.Slice(ctx, 0, 0, dInner, 1)
x = x.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
// B: [d_state, n_group, n_seq_tokens, n_seqs]
B := xBC.Slice(ctx, 0, dInner, dInner+nGroup*dState, 1)
B = B.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)
// C: [d_state, n_group, n_seq_tokens, n_seqs]
C := xBC.Slice(ctx, 0, dInner+nGroup*dState, dInner+2*nGroup*dState, 1)
C = C.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)
// Add dt bias
dt = dt.Add(ctx, m.SSMDtB)
// Get SSM state from cache
state, err := cache.SSMState(ctx, layer, dState, headDim, nHead)
if err != nil {
slog.Warn("nemotronh: failed to get SSM state, using zeros", "layer", layer, "error", err)
state = ctx.Input().Zeros(ml.DTypeF32, dState, headDim, nHead, nSeqs)
}
// SSMScan
// state: [d_state, head_dim, n_head, n_seqs]
// returns: [head_dim, n_head, n_seq_tokens, n_seqs] concatenated with new state
ySsm := state.SSMScan(ctx, x, dt, m.SSMA, B, C, cache.slotsTensor())
// ySsm is a packed 1D buffer: [y (nSeqTokens*headDim*nHead*nSeqs), newState]
yElems := headDim * nHead * nSeqTokens * nSeqs
y := ySsm.View(ctx, 0, yElems).Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
stateOffsetBytes := yElems * x.Stride(0)
stateElems := dState * headDim * nHead * nSeqs
newState := ySsm.View(ctx, stateOffsetBytes, stateElems)
newState = newState.Reshape(ctx, dState, headDim, nHead, nSeqs)
// Update SSM state in cache
cache.UpdateSSMState(ctx, layer, newState)
// D skip connection: y = y + x * D
if m.SSMD != nil {
// SSMD shape: [1, n_head] -> broadcast to [head_dim, n_head, n_seq_tokens, n_seqs]
xD := x.Mul(ctx, m.SSMD)
y = y.Add(ctx, xD)
}
// Swiglu with z: y = z * silu(y)
y = z.SILU(ctx, y)
// Group RMSNorm
if m.SSMNorm != nil {
// Reshape for group norm: [d_inner/n_group, n_group, n_seq_tokens, n_seqs]
innerPerGroup := dInner / nGroup
y = y.Reshape(ctx, innerPerGroup, nGroup, nSeqTokens, nSeqs)
y = m.SSMNorm.Forward(ctx, y, opts.eps)
}
// Reshape back to [d_inner, n_seq_tokens, n_seqs]
y = y.Reshape(ctx, dInner, nSeqTokens, nSeqs)
// Output projection
out := m.SSMOut.Forward(ctx, y)
// Reshape to 2D for consistency with attention output
return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
}

View File

@@ -0,0 +1,417 @@
package nemotronh
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
// Options contains model configuration
type Options struct {
hiddenSize int
numHeads int // attention heads
numKVHeads int // KV heads for attention layers
headDim int
eps float32
// Mamba2 SSM config
ssmDConv int // conv kernel size
ssmDInner int // inner dimension (d_inner)
ssmDState int // state dimension
ssmNHead int // number of SSM heads (dt_rank)
ssmNGroup int // number of groups for B, C
// Per-layer configuration
isRecurrent []bool // true = Mamba2, false = attention or FFN
nFF []int // n_ff per layer (0 = attention-only)
// Attention scale
attentionScale float64
// MoE config
numExperts int
numExpertsUsed int
expertWeightsNorm bool
expertWeightsScale float32
expertWeightsNormClip float32
}
func (o Options) getHeadDim() int {
if o.headDim > 0 {
return o.headDim
}
if o.numHeads <= 0 {
return 0
}
return o.hiddenSize / o.numHeads
}
// Operator is the interface for layer operators (Mamba2 or Attention)
type Operator interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
}
// MLP is the interface for feedforward networks
type MLP interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
}
// Layer represents a single transformer layer
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Operator Operator // Mamba2, Attention, or nil (for FFN-only layers)
MLP MLP // Dense or MoE FFN, or nil
}
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
residual := hiddenStates
// Pre-layer norm
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
// Layer operator (Mamba2, Attention, or FFN)
if l.Operator != nil {
var err error
hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, cache, opts)
if err != nil {
return nil, err
}
} else if l.MLP != nil {
// FFN-only layer
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
}
// Output projection for last layer
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
// Residual connection
return hiddenStates.Add(ctx, residual), nil
}
// Model is the main Nemotron-H model
type Model struct {
model.Base
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Layers []Layer `gguf:"blk"`
*Options
}
// Shift is used for KV cache position shifting.
// Nemotron-H attention does not apply RoPE, so keys do not need to be transformed.
func Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
cache := m.Cache.(*HybridCache)
for i, layer := range m.Layers {
cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
var err error
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, outputs, cache, m.Options)
if err != nil {
return nil, err
}
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count"))
layers := make([]Layer, numLayers)
// Get per-layer configuration from GGUF metadata
// Use the same interface pattern as qwen3next
type perLayerConfig interface {
HeadCount() []uint64
HeadCountKV() []uint64
FFNLength() []uint64
}
var headCount []uint64
var headCountKV []uint64
var ffnLength []uint64
if plc, ok := c.(perLayerConfig); ok {
headCount = plc.HeadCount()
headCountKV = plc.HeadCountKV()
ffnLength = plc.FFNLength()
}
// Build per-layer arrays with defaults
isRecurrent := make([]bool, numLayers)
nFF := make([]int, numLayers)
for i := range numLayers {
// Get per-layer values
kvHeads := uint64(1) // Default non-zero
if i < len(headCountKV) {
kvHeads = headCountKV[i]
}
ff := uint64(0)
if i < len(ffnLength) {
ff = ffnLength[i]
}
nFF[i] = int(ff)
// A layer is recurrent IFF n_head_kv == 0 AND n_ff == 0
// This matches llama.cpp behavior for Nemotron-H
isRecurrent[i] = kvHeads == 0 && ff == 0
}
// Determine if MoE
isMoE := c.Uint("expert_count") > 0
for i := range layers {
if isRecurrent[i] {
// Mamba2 layer
layers[i].Operator = &Mamba2{Layer: i}
} else if nFF[i] == 0 {
// Attention-only layer (n_head_kv > 0, n_ff == 0)
layers[i].Operator = &Attention{}
} else {
// FFN layer (n_ff > 0)
if isMoE {
layers[i].MLP = &MoESparse{}
} else {
layers[i].MLP = &Dense{}
}
}
}
// Get attention head configuration
numHeads := int(c.Uint("attention.head_count"))
if numHeads == 0 {
for i := range numLayers {
if i < len(headCount) && i < len(headCountKV) && headCount[i] > 0 && headCountKV[i] > 0 {
numHeads = int(headCount[i])
break
}
}
}
numKVHeads := int(c.Uint("attention.head_count_kv"))
if numKVHeads == 0 {
for i := range numLayers {
if i < len(headCountKV) && i < len(ffnLength) && headCountKV[i] > 0 && ffnLength[i] == 0 {
numKVHeads = int(headCountKV[i])
break
}
}
if numKVHeads == 0 {
numKVHeads = numHeads
}
}
headDim := int(c.Uint("attention.head_dim"))
if headDim == 0 {
if keyLength := int(c.Uint("attention.key_length")); keyLength > 0 {
headDim = keyLength
} else if numHeads > 0 {
headDim = int(c.Uint("embedding_length")) / numHeads
}
}
if headDim <= 0 {
return nil, fmt.Errorf("nemotronh: invalid attention head dimension")
}
if numHeads <= 0 {
// Attention layers derive per-layer head counts from projection weights.
// Keep a non-zero default to avoid invalid option math.
numHeads = 1
}
numExperts := int(c.Uint("expert_count"))
numExpertsUsed := int(c.Uint("expert_used_count"))
if numExperts > 0 {
if numExpertsUsed <= 0 || numExpertsUsed > numExperts {
return nil, fmt.Errorf("nemotronh: invalid expert_used_count=%d for expert_count=%d", numExpertsUsed, numExperts)
}
}
opts := &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: numHeads,
numKVHeads: numKVHeads,
headDim: headDim,
eps: c.Float("attention.layer_norm_rms_epsilon"),
ssmDConv: int(c.Uint("ssm.conv_kernel")),
ssmDInner: int(c.Uint("ssm.inner_size")),
ssmDState: int(c.Uint("ssm.state_size")),
ssmNHead: int(c.Uint("ssm.time_step_rank")),
ssmNGroup: int(c.Uint("ssm.group_count")),
isRecurrent: isRecurrent,
nFF: nFF,
attentionScale: float64(c.Float("attention.scale")),
numExperts: numExperts,
numExpertsUsed: numExpertsUsed,
expertWeightsNorm: c.Bool("expert_weights_norm", false),
expertWeightsScale: c.Float("expert_weights_scale", 1.0),
expertWeightsNormClip: c.Float("expert_weights_norm_clip", 0),
}
// Calculate cache dimensions
convDim := max(0, opts.ssmDConv-1)
convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
ssmHeadDim := 0
if opts.ssmNHead > 0 {
ssmHeadDim = opts.ssmDInner / opts.ssmNHead
}
ssmStateSize := opts.ssmDState * ssmHeadDim * opts.ssmNHead
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
Layers: layers,
Options: opts,
}
m.Cache = NewHybridCache(convDim, convChannels, ssmStateSize)
return &m, nil
}
func init() {
model.Register("nemotron_h", New)
model.Register("nemotron_h_moe", New)
}
// Ensure Model implements model.Model
var _ model.Model = (*Model)(nil)
// Dense implements standard feedforward with ReLU-squared activation
type Dense struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (d *Dense) Forward(ctx ml.Context, x ml.Tensor, opts *Options) ml.Tensor {
// up -> ReLU-squared -> down
up := d.Up.Forward(ctx, x)
up = up.RELU(ctx)
up = up.Mul(ctx, up) // Square
return d.Down.Forward(ctx, up)
}
// MoESparse implements MoE with shared experts for Nemotron-H-MoE
type MoESparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
Bias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
LatentIn *nn.Linear `gguf:"ffn_latent_in"`
LatentOut *nn.Linear `gguf:"ffn_latent_out"`
// Shared experts
SharedUp *nn.Linear `gguf:"ffn_up_shexp"`
SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
}
func (m *MoESparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenDim := hiddenStates.Dim(0)
seqLen := hiddenStates.Dim(1)
batchSize := hiddenStates.Dim(2)
if batchSize == 0 {
batchSize = 1
}
hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, seqLen*batchSize)
// Router logits with sigmoid gating
routerLogits := m.Router.Forward(ctx, hiddenStates2D)
// Weights come from unbiased sigmoid probabilities.
probs := routerLogits.Sigmoid(ctx)
// Selection uses optional bias.
selectionProbs := probs
if m.Bias != nil {
selectionProbs = selectionProbs.Add(ctx, m.Bias)
}
// Select top-k experts
selectedExperts := selectionProbs.TopK(ctx, opts.numExpertsUsed)
routingWeights := probs.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)
// Normalize routing weights
if opts.expertWeightsNorm {
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
weightsSum := routingWeights.SumRows(ctx)
weightsSum = weightsSum.Clamp(ctx, 6.103515625e-5, float32(math.MaxFloat32))
routingWeights = routingWeights.Div(ctx, weightsSum)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
}
// Scale routing weights
if opts.expertWeightsScale != 1.0 {
routingWeights = routingWeights.Scale(ctx, float64(opts.expertWeightsScale))
}
routedInput := hiddenStates2D
if m.LatentIn != nil {
routedInput = m.LatentIn.Forward(ctx, routedInput)
}
hiddenStates3D := routedInput.Reshape(ctx, routedInput.Dim(0), 1, routedInput.Dim(1))
// Expert computation with ReLU-squared activation
upOut := m.Up.Forward(ctx, hiddenStates3D, selectedExperts)
upOut = upOut.RELU(ctx)
upOut = upOut.Mul(ctx, upOut) // Square
experts := m.Down.Forward(ctx, upOut, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
// Sum over experts
moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
if m.LatentOut != nil {
moeOut = m.LatentOut.Forward(ctx, moeOut)
}
// Add shared experts if present
if m.SharedUp != nil {
sharedUp := m.SharedUp.Forward(ctx, hiddenStates2D)
sharedUp = sharedUp.RELU(ctx)
sharedUp = sharedUp.Mul(ctx, sharedUp) // Square
sharedOut := m.SharedDown.Forward(ctx, sharedUp)
moeOut = moeOut.Add(ctx, sharedOut)
}
return moeOut
}

View File

@@ -447,7 +447,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
// Some architectures are not safe with num_parallel > 1.
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
numParallel = 1
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
}