mlxrunner: fuse top-P and top-K into a single sort pass

When both filters are active, avoid paying for a full sort in top-P
and a partial sort in top-K. Single-filter paths are unchanged.
Improves generation throughput on gemma4:e4b by 1.5%.
This commit is contained in:
Jesse Gross
2026-04-16 13:42:39 -07:00
parent ef8c885bd7
commit 149f45800d

View File

@@ -54,18 +54,23 @@ func New(opts Options) *Sampler {
transforms = append(transforms, penalty)
}
if opts.TopP > 0 && opts.TopP < 1 {
transforms = append(transforms, topP)
hasTopP := opts.TopP > 0 && opts.TopP < 1
hasTopK := opts.TopK > 0
switch {
case hasTopP:
// topKTopP always does a full descending sort for the top-P
// cumulative mask and opportunistically masks top-K during the
// same pass when it is also configured.
transforms = append(transforms, topKTopP)
case hasTopK:
// Argpartition (partial sort) is cheaper than a full sort.
transforms = append(transforms, topK)
}
if opts.MinP != 0 {
transforms = append(transforms, minP)
}
if opts.TopK > 0 {
transforms = append(transforms, topK)
}
if opts.Temperature == 0 {
transforms = append(transforms, greedy)
} else {
@@ -173,18 +178,38 @@ func temperature(s *Sampler, scores *mlx.Array) *mlx.Array {
return mlx.DivScalar(scores, s.Temperature).Categorical(-1)
}
func topP(s *Sampler, scores *mlx.Array) *mlx.Array {
if s.TopP <= 0 || s.TopP >= 1 {
return scores
}
// topKTopP applies top-P in a descending sort pass and, when top-K is also
// configured, masks any surviving value below the K-th largest in the same
// pass. Callers dispatch here whenever top-P is enabled — the top-K-only
// case uses a cheaper partial sort via the topK transform.
func topKTopP(s *Sampler, scores *mlx.Array) *mlx.Array {
vocab := scores.Dim(scores.NumDims() - 1)
applyTopK := s.TopK > 0 && s.TopK < vocab
order := scores.Negative().ArgsortAxis(-1)
sortedScores := scores.TakeAlongAxis(order, -1)
sortedProbs := mlx.SoftmaxAxis(sortedScores, -1, true)
prevCumProbs := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
sorted := scores.TakeAlongAxis(order, -1)
negInf := mlx.FromValue(float32(math.Inf(-1)))
// Top-P: in descending order, keep tokens whose exclusive cumulative
// probability is still below s.TopP.
probs := mlx.SoftmaxAxis(sorted, -1, true)
prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs)
keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
filtered := mlx.Where(keep, sortedScores, mlx.FromValue(float32(math.Inf(-1))))
return scores.PutAlongAxis(order, filtered, -1)
sorted = mlx.Where(keep, sorted, negInf)
out := scores.PutAlongAxis(order, sorted, -1)
// Top-K: sorted is already in descending order, so positions [K, V)
// are the ones to drop. Scatter -inf through their original-layout
// indices (order[K:]). Positional (not value-based) so exactly K
// tokens survive — ties at the K-th logit get broken by the sort
// order rather than promoted through the filter.
if applyTopK {
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
out = out.PutAlongAxis(dropOrder, negInf, -1)
}
return out
}
func minP(s *Sampler, scores *mlx.Array) *mlx.Array {