go-mlx

forge.lthn.ai/core/go-mlx provides native Apple Metal GPU inference and LoRA fine-tuning for Go. It wraps Apple's MLX framework through the mlx-c C API, implementing the inference.Backend interface from forge.lthn.ai/core/go-inference.

Platform: darwin/arm64 only (Apple Silicon M1-M4). A stub provides MetalAvailable() bool returning false on all other platforms.

Quick Start

import (
    "context"
    "fmt"

    "forge.lthn.ai/core/go-inference"
    _ "forge.lthn.ai/core/go-mlx" // registers "metal" backend via init()
)

func main() {
    m, err := inference.LoadModel("/path/to/safetensors/model/")
    if err != nil {
        panic(err)
    }
    defer m.Close()

    ctx := context.Background()
    for tok := range m.Generate(ctx, "What is 2+2?", inference.WithMaxTokens(128)) {
        fmt.Print(tok.Text)
    }
    if err := m.Err(); err != nil {
        panic(err)
    }
}

The blank import (_ "forge.lthn.ai/core/go-mlx") auto-registers the Metal backend. All interaction goes through the go-inference interfaces -- go-mlx itself exports only Metal-specific memory controls.

Features

  • Streaming inference -- token-by-token generation via iter.Seq[Token] (range-over-func)
  • Multi-turn chat -- native chat templates for Gemma 3, Qwen 2/3, and Llama 3
  • Batch inference -- Classify (prefill-only) and BatchGenerate (autoregressive) for multiple prompts
  • LoRA fine-tuning -- low-rank adaptation with AdamW optimiser and gradient checkpointing
  • Quantisation -- transparent support for 4-bit and 8-bit quantised models via QuantizedMatmul
  • Attention inspection -- extract post-RoPE K vectors from the KV cache for analysis
  • Performance metrics -- prefill/decode tokens per second, GPU memory usage

Supported Models

Models must be in HuggingFace safetensors format (not GGUF). Architecture is auto-detected from config.json:

Architecture   model_type values             Tested sizes
Gemma 3        gemma3, gemma3_text, gemma2   1B, 4B, 27B
Qwen 3         qwen3, qwen2                  8B+
Llama 3        llama                         8B+
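The detection key is the model_type field in config.json. A sketch of what that lookup amounts to, using the mapping from the table above (architectureFor is a hypothetical helper for illustration; the real detection lives inside the loader):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// architectureFor maps a HuggingFace model_type value to the
// architecture family it selects, per the table above.
func architectureFor(configJSON []byte) (string, error) {
	var cfg struct {
		ModelType string `json:"model_type"`
	}
	if err := json.Unmarshal(configJSON, &cfg); err != nil {
		return "", err
	}
	switch cfg.ModelType {
	case "gemma3", "gemma3_text", "gemma2":
		return "Gemma 3", nil
	case "qwen3", "qwen2":
		return "Qwen 3", nil
	case "llama":
		return "Llama 3", nil
	}
	return "", fmt.Errorf("unsupported model_type %q", cfg.ModelType)
}

func main() {
	arch, err := architectureFor([]byte(`{"model_type": "gemma3_text"}`))
	if err != nil {
		panic(err)
	}
	fmt.Println(arch) // Gemma 3
}
```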

Package Layout

Package          Purpose
Root (mlx)       Public API: Metal backend registration, memory controls, training type exports
internal/metal/  All CGO code: array ops, model loaders, generation, training primitives
mlxlm/           Alternative subprocess backend via Python's mlx-lm (no CGO required)

Metal Memory Controls

These control the Metal allocator directly, not individual models:

import mlx "forge.lthn.ai/core/go-mlx"

mlx.SetCacheLimit(4 << 30)   // 4 GB cache limit
mlx.SetMemoryLimit(32 << 30) // 32 GB hard limit
mlx.ClearCache()             // release cached memory between chat turns

fmt.Printf("active: %d MB, peak: %d MB\n",
    mlx.GetActiveMemory()/1024/1024,
    mlx.GetPeakMemory()/1024/1024)

Function               Purpose
SetCacheLimit(bytes)   Soft limit on the allocator cache
SetMemoryLimit(bytes)  Hard ceiling on Metal memory
SetWiredLimit(bytes)   Wired memory limit
GetActiveMemory()      Current live allocations in bytes
GetPeakMemory()        High-water mark since last reset
GetCacheMemory()       Cached (not yet freed) memory
ClearCache()           Release cached memory to the OS
ResetPeakMemory()      Reset the high-water mark
GetDeviceInfo()        Metal GPU hardware information

Performance Baseline

Measured on M3 Ultra (60-core GPU, 96 GB unified memory):

Operation                              Throughput
Gemma3-1B 4-bit prefill                246 tok/s
Gemma3-1B 4-bit decode                 82 tok/s
Gemma3-1B 4-bit classify (4 prompts)   152 prompts/s
DeepSeek R1 7B 4-bit decode            27 tok/s
Llama 3.1 8B 4-bit decode              30 tok/s

Documentation

  • Architecture -- CGO binding layer, lazy evaluation, memory model, attention, KV cache
  • Models -- model loading, supported architectures, tokenisation, chat templates
  • Training -- LoRA fine-tuning, gradient computation, AdamW optimiser, loss functions
  • Build Guide -- prerequisites, CMake setup, build tags, testing

Downstream Consumers

Package                     Role
forge.lthn.ai/core/go-ml    Imports go-inference + go-mlx for the Metal backend training loop
forge.lthn.ai/core/go-i18n  Gemma3-1B domain classification (Phase 2a)
forge.lthn.ai/core/go-rocm  Sibling AMD GPU backend, same go-inference interfaces

Licence

EUPL-1.2