mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 09:03:35 -04:00
mlx: fuse sigmoid router head in glm4_moe_lite
DeepSeek-V2-style aux-loss-free routing computes sigmoid(gates) once but needs it twice: the raw sigmoid output is gathered after top-k, while the post-bias negation is the argpartition key. Fuse into a single multi-output Compiled kernel returning both, saving two launches on the routing path per token. Exposed as a general SigmoidRouter since the same pattern is shared across DeepSeek-V2 descendants. Improves glm4.7 generation performance by approximately 1%.
This commit is contained in:
@@ -62,3 +62,25 @@ var LogitSoftcap = Compile2(
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// sigmoidRouterFused traces the DeepSeek-V2 / GLM-MoE aux-loss-free router
|
||||
// head. Two outputs are returned so the pre-bias sigmoid (used to gather
|
||||
// per-expert scores after top-k) and the post-bias negation (used as the
|
||||
// argpartition key for top-k) share a single kernel.
|
||||
var sigmoidRouterFused = Compile(
|
||||
"SigmoidRouter",
|
||||
func(in ...*Array) []*Array {
|
||||
gates, bias := in[0], in[1]
|
||||
orig := gates.Sigmoid()
|
||||
neg := orig.Add(bias).Negative()
|
||||
return []*Array{orig, neg}
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SigmoidRouter returns (sigmoid(gates), -(sigmoid(gates)+bias)) as a fused
|
||||
// kernel — the DeepSeek-V2 / GLM-MoE aux-loss-free router head.
|
||||
func SigmoidRouter(gates, bias *Array) (origScores, negScores *Array) {
|
||||
out := sigmoidRouterFused(gates, bias)
|
||||
return out[0], out[1]
|
||||
}
|
||||
|
||||
@@ -161,21 +161,21 @@ type MoEGate struct {
|
||||
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
|
||||
gates := g.Gate.Forward(x)
|
||||
|
||||
scores := mlx.Sigmoid(gates)
|
||||
origScores := scores
|
||||
|
||||
var origScores, negScores *mlx.Array
|
||||
if g.EScoreCorrectionBias != nil {
|
||||
scores = mlx.Add(scores, g.EScoreCorrectionBias)
|
||||
origScores, negScores = mlx.SigmoidRouter(gates, g.EScoreCorrectionBias)
|
||||
} else {
|
||||
origScores = mlx.Sigmoid(gates)
|
||||
negScores = mlx.Neg(origScores)
|
||||
}
|
||||
|
||||
topK := cfg.NumExpertsPerTok
|
||||
negScores := mlx.Neg(scores)
|
||||
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
|
||||
|
||||
dims := inds.Dims()
|
||||
inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})
|
||||
|
||||
scores = mlx.TakeAlongAxis(origScores, inds, -1)
|
||||
scores := mlx.TakeAlongAxis(origScores, inds, -1)
|
||||
|
||||
if topK > 1 && cfg.NormTopKProb {
|
||||
sumScores := mlx.Sum(scores, -1, true)
|
||||
|
||||
Reference in New Issue
Block a user