mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 09:03:35 -04:00
mlxrunner: fuse top-P and top-K into a single sort pass
When both filters are active, reuse the full descending sort that top-P already performs to apply top-K as well, instead of paying for a full sort in top-P plus a separate partial sort in top-K. Single-filter paths are unchanged. Improves generation throughput on gemma4:e4b by 1.5%.
This commit is contained in:
@@ -54,18 +54,23 @@ func New(opts Options) *Sampler {
|
||||
transforms = append(transforms, penalty)
|
||||
}
|
||||
|
||||
if opts.TopP > 0 && opts.TopP < 1 {
|
||||
transforms = append(transforms, topP)
|
||||
hasTopP := opts.TopP > 0 && opts.TopP < 1
|
||||
hasTopK := opts.TopK > 0
|
||||
switch {
|
||||
case hasTopP:
|
||||
// topKTopP always does a full descending sort for the top-P
|
||||
// cumulative mask and opportunistically masks top-K during the
|
||||
// same pass when it is also configured.
|
||||
transforms = append(transforms, topKTopP)
|
||||
case hasTopK:
|
||||
// Argpartition (partial sort) is cheaper than a full sort.
|
||||
transforms = append(transforms, topK)
|
||||
}
|
||||
|
||||
if opts.MinP != 0 {
|
||||
transforms = append(transforms, minP)
|
||||
}
|
||||
|
||||
if opts.TopK > 0 {
|
||||
transforms = append(transforms, topK)
|
||||
}
|
||||
|
||||
if opts.Temperature == 0 {
|
||||
transforms = append(transforms, greedy)
|
||||
} else {
|
||||
@@ -173,18 +178,38 @@ func temperature(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
return mlx.DivScalar(scores, s.Temperature).Categorical(-1)
|
||||
}
|
||||
|
||||
func topP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
if s.TopP <= 0 || s.TopP >= 1 {
|
||||
return scores
|
||||
}
|
||||
// topKTopP applies top-P in a descending sort pass and, when top-K is also
|
||||
// configured, reuses that sort order to mask every token ranked at position
|
||||
// K or later (positional, so ties at the K-th logit are broken by sort
|
||||
// order). Callers dispatch here whenever top-P is enabled — the top-K-only
// case uses a cheaper partial sort via the topK transform.
|
||||
func topKTopP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
vocab := scores.Dim(scores.NumDims() - 1)
|
||||
applyTopK := s.TopK > 0 && s.TopK < vocab
|
||||
|
||||
order := scores.Negative().ArgsortAxis(-1)
|
||||
sortedScores := scores.TakeAlongAxis(order, -1)
|
||||
sortedProbs := mlx.SoftmaxAxis(sortedScores, -1, true)
|
||||
prevCumProbs := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
|
||||
sorted := scores.TakeAlongAxis(order, -1)
|
||||
negInf := mlx.FromValue(float32(math.Inf(-1)))
|
||||
|
||||
// Top-P: in descending order, keep tokens whose exclusive cumulative
|
||||
// probability is still below s.TopP.
|
||||
probs := mlx.SoftmaxAxis(sorted, -1, true)
|
||||
prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs)
|
||||
keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
|
||||
filtered := mlx.Where(keep, sortedScores, mlx.FromValue(float32(math.Inf(-1))))
|
||||
return scores.PutAlongAxis(order, filtered, -1)
|
||||
sorted = mlx.Where(keep, sorted, negInf)
|
||||
|
||||
out := scores.PutAlongAxis(order, sorted, -1)
|
||||
|
||||
// Top-K: sorted is already in descending order, so positions [K, V)
|
||||
// are the ones to drop. Scatter -inf through their original-layout
|
||||
// indices (order[K:]). Positional (not value-based) so exactly K
|
||||
// tokens survive — ties at the K-th logit get broken by the sort
|
||||
// order rather than promoted through the filter.
|
||||
if applyTopK {
|
||||
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
||||
out = out.PutAlongAxis(dropOrder, negInf, -1)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func minP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
|
||||
Reference in New Issue
Block a user