Files
ollama/x/imagegen/vae/tiling.go
Daniel Hiltgen 10e51c5177 MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows

Avoid building with v7 - more changes are needed

* MLX: add header vendoring and remove go build tag

This switches to using a vendoring approach for the mlx-c headers so that Go
can build without requiring a cmake run first.  This enables building the new MLX
based code by default.  Every time cmake runs, the headers are refreshed, so we
can easily keep them in sync when we bump mlx versions.  Basic Windows
and Linux support has been verified.

* ci: harden for flaky choco repo servers

CI sometimes fails because choco does not actually install the cache.  Since the cache only speeds up the build, we can proceed without it.

* review comments
2026-03-09 17:24:45 -07:00

214 lines
6.1 KiB
Go

// Package vae provides shared utilities for VAE (Variational Autoencoder) operations.
package vae
import (
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TilingConfig controls tiled VAE decoding, a general technique for reducing
// memory usage when decoding large latents: the latent is split into
// overlapping tiles that are decoded one at a time and blended together.
//
// Both fields are measured in latent pixels, not output pixels.
type TilingConfig struct {
	// TileSize is the tile edge length in latent space
	// (e.g. 64 latent pixels become 512 output pixels with an 8x VAE).
	TileSize int32
	// Overlap is how many latent pixels neighboring tiles share
	// (e.g. 16 latent pixels is 25% of a 64-pixel tile).
	Overlap int32
}

// DefaultTilingConfig returns a fresh config with the diffusers defaults
// (tile_latent_min_size=64, tile_overlap_factor=0.25): 64-pixel latent tiles
// that overlap their neighbors by a quarter of a tile, i.e. 16 latent pixels.
func DefaultTilingConfig() *TilingConfig {
	cfg := new(TilingConfig)
	cfg.TileSize = 64
	cfg.Overlap = cfg.TileSize / 4
	return cfg
}
// decodedTile is one decoded tile pulled out of mlx as a flat float32 slice.
// Pixels are laid out row by row with three interleaved channel values each,
// so channel c of pixel (y, x) lives at data[(y*width+x)*3+c].
type decodedTile struct {
	data          []float32
	height, width int32
}
// DecodeTiled decodes latents using tiled processing with overlap blending.
// This reduces memory usage for large images by processing in overlapping tiles.
//
// The latent is split into tiles whose origins are TileSize-Overlap latent
// pixels apart. Each tile is decoded on its own. It is then blended with its
// neighbors above and to the left, and cropped to its non-overlapping region
// before all tiles are stitched together.
//
// Parameters:
//   - latents: [1, H, W, C] latent tensor in NHWC format
//   - cfg: tiling configuration (tile size and overlap); nil selects
//     DefaultTilingConfig. An Overlap outside [0, TileSize) is clamped into
//     that range, and a non-positive TileSize disables tiling.
//   - decoder: function to decode a single tile [1, h, w, C] -> [1, h*scale, w*scale, 3]
//
// The decoder's upscale factor is inferred from the first decoded tile rather
// than assumed to be 8, so blend and crop sizes stay consistent for any
// integer-scale decoder.
//
// Returns: [1, 3, H*scale, W*scale] decoded image in NCHW format, clipped to
// [0, 1] via mlx.ClipScalar.
func DecodeTiled(latents *mlx.Array, cfg *TilingConfig, decoder func(*mlx.Array) *mlx.Array) *mlx.Array {
	if cfg == nil {
		cfg = DefaultTilingConfig()
	}

	shape := latents.Shape()
	H := shape[1] // latent height
	W := shape[2] // latent width
	C := shape[3]

	tileLatentSize := cfg.TileSize
	// Keep the overlap in [0, tileLatentSize) so the stride between tile
	// origins is positive. Otherwise the tile loops below would never advance
	// (overlap >= tile size) or would skip latent pixels (negative overlap).
	overlapLatent := max(0, min(cfg.Overlap, tileLatentSize-1))

	// Decode in one pass when the latent fits in a single tile, when tiling is
	// disabled (non-positive tile size), or when there is nothing to tile.
	if tileLatentSize <= 0 || H <= 0 || W <= 0 || (H <= tileLatentSize && W <= tileLatentSize) {
		decoded := decoder(latents)
		decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
		decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
		decoded = mlx.Transpose(decoded, 0, 3, 1, 2) // NHWC -> NCHW
		return decoded
	}

	// Distance between tile origins in latent space; >= 1 thanks to the clamp.
	stride := tileLatentSize - overlapLatent

	// Phase 1: decode every tile and keep its pixels in a 2D grid. Each decoded
	// mlx array is evaluated, copied out, and freed before the next tile.
	rows := make([][]decodedTile, 0, (H+stride-1)/stride)
	for i := int32(0); i < H; i += stride {
		row := make([]decodedTile, 0, (W+stride-1)/stride)
		for j := int32(0); j < W; j += stride {
			// Tiles on the bottom/right edges may be smaller than tileLatentSize.
			i2 := min(i+tileLatentSize, H)
			j2 := min(j+tileLatentSize, W)
			tile := mlx.Slice(latents, []int32{0, i, j, 0}, []int32{1, i2, j2, C})
			decoded := decoder(tile)
			decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
			mlx.Eval(decoded)
			decodedShape := decoded.Shape()
			tileH := decodedShape[1]
			tileW := decodedShape[2]
			tileData := decoded.Data()
			decoded.Free()
			row = append(row, decodedTile{data: tileData, height: tileH, width: tileW})
		}
		rows = append(rows, row)
	}

	// Infer the decoder's upscale factor from the first tile, whose latent
	// height is min(tileLatentSize, H). For the usual 8x VAE this yields 8.
	scale := max(rows[0][0].height/min(tileLatentSize, H), 1)
	// Tile size, blend region, and kept (non-overlapping) region, in pixels.
	tileSampleSize := tileLatentSize * scale
	blendExtent := overlapLatent * scale
	rowLimit := tileSampleSize - blendExtent

	// Phase 2: blend each tile in place with its neighbors above and to the
	// left. Tiles are visited in raster order, so those neighbors have already
	// been blended themselves.
	for i := range rows {
		for j := range rows[i] {
			tile := &rows[i][j]
			if i > 0 {
				blendV(&rows[i-1][j], tile, blendExtent)
			}
			if j > 0 {
				blendH(&rows[i][j-1], tile, blendExtent)
			}
		}
	}

	// Phase 3: crop each tile to at most rowLimit pixels per axis, like
	// diffusers' tile[:, :, :row_limit, :row_limit]. Edge tiles that are
	// already smaller than rowLimit keep their full extent.
	colWidths := make([]int32, len(rows[0]))
	var totalW int32
	for j, tile := range rows[0] {
		colWidths[j] = min(rowLimit, tile.width)
		totalW += colWidths[j]
	}
	rowHeights := make([]int32, len(rows))
	var totalH int32
	for i, row := range rows {
		rowHeights[i] = min(rowLimit, row[0].height)
		totalH += rowHeights[i]
	}

	// Phase 4: stitch the cropped tiles into one NHWC buffer. Output rows are
	// produced in order, and each is the concatenation of one contiguous run
	// per tile, so the destination offset only ever advances.
	const channels = 3 // the decoder is expected to emit 3-channel pixels
	// Sizes and offsets use int so very large outputs cannot overflow int32.
	finalData := make([]float32, int(totalH)*int(totalW)*channels)
	dst := 0
	for i, row := range rows {
		for y := int32(0); y < rowHeights[i]; y++ {
			for j, tile := range row {
				n := int(colWidths[j]) * channels
				src := int(y) * int(tile.width) * channels
				copy(finalData[dst:dst+n], tile.data[src:src+n])
				dst += n
			}
		}
	}

	// Create mlx array [1, H, W, 3] then transpose to NCHW [1, 3, H, W].
	result := mlx.NewArray(finalData, []int32{1, totalH, totalW, channels})
	result = mlx.Transpose(result, 0, 3, 1, 2)
	result = mlx.ClipScalar(result, 0.0, 1.0, true, true)
	return result
}
// blendV cross-fades the last rows of above into the first rows of current.
// It overwrites current.data in place and only reads above. Overlap row k
// (0 <= k < n) becomes above*(1-k/n) + current*(k/n), following the diffusers
// blend_v formula, where n is blendExtent clamped to both tile heights. Only
// the columns present in both tiles are touched.
func blendV(above, current *decodedTile, blendExtent int32) {
	n := min(blendExtent, above.height, current.height)
	if n <= 0 {
		return
	}
	// Number of float values per row to blend: 3 channels per shared column.
	span := min(above.width, current.width) * 3
	firstSrcRow := above.height - n
	for k := int32(0); k < n; k++ {
		weight := float32(k) / float32(n)
		srcBase := (firstSrcRow + k) * above.width * 3
		dstBase := k * current.width * 3
		for off := int32(0); off < span; off++ {
			current.data[dstBase+off] = above.data[srcBase+off]*(1-weight) + current.data[dstBase+off]*weight
		}
	}
}
// blendH cross-fades the rightmost columns of left into the leftmost columns
// of current. It overwrites current.data in place and only reads left.
// Overlap column k (0 <= k < n) becomes left*(1-k/n) + current*(k/n),
// following the diffusers blend_h formula, where n is blendExtent clamped to
// both tile widths. Only the rows present in both tiles are touched.
func blendH(left, current *decodedTile, blendExtent int32) {
	n := min(blendExtent, left.width, current.width)
	if n <= 0 {
		return
	}
	sharedRows := min(left.height, current.height)
	firstSrcCol := left.width - n
	for y := int32(0); y < sharedRows; y++ {
		srcRowStart := y * left.width
		dstRowStart := y * current.width
		for k := int32(0); k < n; k++ {
			weight := float32(k) / float32(n)
			src := (srcRowStart + firstSrcCol + k) * 3
			dst := (dstRowStart + k) * 3
			for ch := int32(0); ch < 3; ch++ {
				current.data[dst+ch] = left.data[src+ch]*(1-weight) + current.data[dst+ch]*weight
			}
		}
	}
}