[Feature] MLX Backend Implementation for SGLang #17846

@dougyster

[Feature] MLX Backend for Apple Silicon Support

Summary

Add Apple Silicon (MLX + Metal) backend support to SGLang, enabling native GPU-accelerated inference on M1/M2/M3/M4 Macs.

Motivation

  • vLLM has community-maintained Apple Silicon support via [vllm-metal](https://github.com/vllm-project/vllm-metal)
  • SGLang currently has no Apple Silicon support
  • MLX provides native Metal GPU acceleration with unified memory benefits
  • Enables local development/testing on Mac hardware

Project Structure

python/sglang/srt/hardware_backend/mlx/
├── __init__.py
├── mlx_attention.py          # Attention Backend (REQUIRED for MVP)
├── mlx_graph_runner.py       # Graph Runner (REQUIRED for MVP)
├── mlx_model_loader.py       # Model loading (REQUIRED for MVP)
├── quantization/             # Quantization (Phase 2)
│   └── mlx_quant_method.py
└── moe/                      # MoE support (Phase 3)
    └── mlx_moe_method.py

Registration Entry Points

1. Device Config

File: python/sglang/srt/configs/device_config.py - line 14

# CHANGE THIS LINE:
if device in ["cuda", "xpu", "hpu", "cpu", "npu"]:

# TO:
if device in ["cuda", "xpu", "hpu", "cpu", "npu", "mlx"]:
    self.device_type = device

2. Graph Runner Registration

File: python/sglang/srt/model_executor/model_runner.py - lines 2018-2025

# ADD IMPORT at ~line 111:
from sglang.srt.hardware_backend.mlx.mlx_graph_runner import MLXGraphRunner

# MODIFY at line 2018-2023:
graph_runners = defaultdict(
    lambda: CudaGraphRunner,
    {
        "cpu": CPUGraphRunner,
        "npu": NPUGraphRunner,
        "mlx": MLXGraphRunner,  # ← ADD THIS LINE
    },
)
self.graph_runner = graph_runners[self.device](self)

3. Attention Backend Registration

File: python/sglang/srt/layers/attention/attention_registry.py - line 45

@register_attention_backend("mlx")
def create_mlx_backend(runner):
    from sglang.srt.hardware_backend.mlx.mlx_attention import MLXAttentionBackend
    return MLXAttentionBackend(runner)

MLX-Specific Considerations

1. Model Runner / Device Placement

  • Intercept input tensors early to ensure they are never placed on CUDA devices (see the guard sketch below)
  • The attention backend and graph runner can assume tensors are already in the right place, since MLX uses unified memory (no CPU↔GPU transfers needed)
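
A minimal sketch of such a guard, assuming the model runner routes inputs through a hook like this; the function name is illustrative, not part of the proposal:

import torch

def ensure_host_tensor(t: torch.Tensor) -> torch.Tensor:
    # With unified memory, a CPU tensor is already visible to the Metal GPU,
    # so the only thing to prevent is an accidental move to a CUDA device.
    return t.cpu() if t.is_cuda else t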

2. Attention Backend

Core function: mx.fast.scaled_dot_product_attention(q, k, v, scale=..., mask=...) (with import mlx.core as mx)

  • Metal-accelerated attention for forward_decode() and forward_extend()
output = mx.fast.scaled_dot_product_attention(
    q, k, v,
    scale=1.0 / math.sqrt(head_dim),
    mask=attention_mask,
)

Tensor conversion: mx.array(numpy_array, dtype=mx.float16)

  • Bridges PyTorch (SGLang) and MLX
  • MVP: Convert on each forward pass
  • Future: Keep tensors in MLX format throughout, or use mlx-community models directly
q_mlx = mx.array(q.cpu().numpy(), dtype=mx.float16)

Lazy evaluation: mx.eval(*arrays)

  • MLX uses lazy evaluation; force computation before returning to PyTorch
output = mx.fast.scaled_dot_product_attention(q, k, v)
mx.eval(output)  # Force computation
return torch.from_numpy(np.array(output))
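
Putting the pieces above together, a minimal decode-path sketch; the function name, tensor layout, and the per-call torch↔MLX round-trip are illustrative assumptions rather than the final MLXAttentionBackend interface:

import math
import mlx.core as mx
import numpy as np
import torch

def mlx_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, head_dim: int) -> torch.Tensor:
    # Bridge into MLX (unified memory: this is a host-side copy, not a device transfer).
    q_mlx = mx.array(q.cpu().numpy(), dtype=mx.float16)
    k_mlx = mx.array(k.cpu().numpy(), dtype=mx.float16)
    v_mlx = mx.array(v.cpu().numpy(), dtype=mx.float16)

    # Metal-accelerated attention kernel.
    out = mx.fast.scaled_dot_product_attention(
        q_mlx, k_mlx, v_mlx, scale=1.0 / math.sqrt(head_dim)
    )

    mx.eval(out)  # force the lazy computation
    return torch.from_numpy(np.array(out))  # hand the result back to SGLang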

3. Graph Runner

JIT compilation: mx.compile(function)

import mlx.core as mx

class MLXGraphRunner:
    def __init__(self, model_runner):
        self.model = model_runner.model
        self.compiled_fn = mx.compile(self.model.forward)

    def forward(self, input_ids, positions, batch):
        return self.compiled_fn(input_ids, positions)

Note: This can be a thin wrapper. CUDA graph runners exist to amortize per-kernel launch overhead by capturing and replaying a whole forward pass; on MLX, mx.compile plays the analogous role by tracing and fusing the forward pass, so explicit capture/replay machinery is not needed.
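
One batching caveat worth noting (an observation about mx.compile behavior, not part of the proposal): the compiled trace is specialized on input shapes by default, so every change in decode batch size triggers a retrace unless the function is compiled shapeless. A sketch, where model_forward stands in for the model's forward function:

import mlx.core as mx

compiled_fn = mx.compile(model_forward, shapeless=True)  # avoid a retrace on every new batch size

logits = compiled_fn(input_ids, positions)
mx.eval(logits)  # computation is lazy; force it before the logits leave the runner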

4. Quantization (Phase 2)

# Quantize model weights (in place; replaces linear layers with quantized ones)
model = load_model(path)
mlx.nn.quantize(model, group_size=64, bits=4)
# Memory: 14GB FP16 → 3.5GB INT4 (~4x reduction)

# Pre-quantized linear layer
layer = mlx.nn.QuantizedLinear(
    input_dims=4096,
    output_dims=4096,
    bits=4,
    group_size=64,
)

Roadmap

Phase 1: MVP

Goal: Basic inference working on Apple Silicon

Required Files

File                                        Change
configs/device_config.py                    Add "mlx" to device list
model_runner.py                             Register MLXGraphRunner
attention_registry.py                       Register "mlx" attention backend
hardware_backend/mlx/__init__.py            New file
hardware_backend/mlx/mlx_attention.py       New file
hardware_backend/mlx/mlx_graph_runner.py    New file
hardware_backend/mlx/mlx_model_loader.py    New file

Core Interfaces

# mlx_attention.py
class MLXAttentionBackend(AttentionBackend):
    def __init__(self, model_runner):
        """Initialize MLX attention backend"""

    def init_forward_metadata(self, forward_batch):
        """Setup metadata before forward pass"""

    def forward_decode(self, q, k, v, layer, forward_batch):
        """Single-token generation (decode)"""

    def forward_extend(self, q, k, v, layer, forward_batch):
        """Multi-token prefill (extend)"""

# mlx_graph_runner.py
class MLXGraphRunner:
    def __init__(self, model_runner):
        """Initialize MLX graph runner"""

    def forward(self, input_ids, positions, forward_batch):
        """Execute model forward pass"""

# mlx_model_loader.py
def load_mlx_model(model_path, config):
    """
    Load model in MLX format.
    
    MVP: Expects mlx-community models (e.g., mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit)
    Future: Support PyTorch weight conversion
    """

Testing

# Test basic inference
python -m sglang.launch_server \
  --model-path mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit \
  --device mlx \
  --attention-backend mlx \
  --port 30000
# Send test request
curl http://localhost:30000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "TinyLlama", "prompt": "Hello, how are you?"}'

Phase 2: Optimization

Goal: Improve performance and memory usage

Features

  • Quantization support (mlx_4bit, mlx_8bit)
  • mx.compile() optimization in graph runner
  • Pre-quantized model loading from HuggingFace
  • KV cache optimization for unified memory (see the sketch after this list)
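
A minimal sketch of what a unified-memory KV cache could look like; the class and method names are assumptions for illustration, not a proposed interface:

import mlx.core as mx

class MLXKVCache:
    def __init__(self, max_tokens: int, num_kv_heads: int, head_dim: int, dtype=mx.float16):
        # Preallocated as MLX arrays; unified memory means the cache is visible to
        # the Metal kernels and host-side bookkeeping without explicit copies.
        self.keys = mx.zeros((max_tokens, num_kv_heads, head_dim), dtype=dtype)
        self.values = mx.zeros((max_tokens, num_kv_heads, head_dim), dtype=dtype)
        self.offset = 0

    def append(self, k: mx.array, v: mx.array):
        # Write newly generated tokens at the current offset (in-place slice update).
        n = k.shape[0]
        self.keys[self.offset:self.offset + n] = k
        self.values[self.offset:self.offset + n] = v
        self.offset += n

    def view(self):
        # Return the filled prefix for attention.
        return self.keys[: self.offset], self.values[: self.offset]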

Files to Add

hardware_backend/mlx/quantization/__init__.py
hardware_backend/mlx/quantization/mlx_quant_method.py

Testing

python -m sglang.launch_server \
  --model-path mlx-community/Llama-3.2-3B-Instruct-4bit \
  --device mlx \
  --attention-backend mlx \
  --quantization mlx_4bit

Phase 3: Advanced Features

Goal: Support advanced models and optimizations

Features

  • MoE support (Mixtral, DeepSeek-V2, Qwen2-MoE); a minimal gating sketch follows after this list
  • Custom Metal shaders for attention
  • Speculative decoding support
  • Multi-device support (if and when MLX adds it)
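
For reference, a naive sketch of the kind of block mlx_moe_method would need to cover; it routes densely (every expert runs) for brevity, whereas a real implementation would dispatch sparsely to the top-k experts. Names and shapes are assumptions for illustration:

import mlx.core as mx
import mlx.nn as nn

class NaiveMoE(nn.Module):
    def __init__(self, dim: int, hidden: int, num_experts: int = 8):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.experts = [
            nn.Sequential(nn.Linear(dim, hidden), nn.SiLU(), nn.Linear(hidden, dim))
            for _ in range(num_experts)
        ]

    def __call__(self, x: mx.array) -> mx.array:
        weights = mx.softmax(self.gate(x), axis=-1)                 # (tokens, experts)
        outputs = mx.stack([e(x) for e in self.experts], axis=-2)   # (tokens, experts, dim)
        # Mix expert outputs by their gate weights.
        return mx.sum(mx.expand_dims(weights, -1) * outputs, axis=-2)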

Files to Add

hardware_backend/mlx/moe/__init__.py
hardware_backend/mlx/moe/mlx_moe_method.py

Testing

python -m sglang.launch_server \
  --model-path mlx-community/Mixtral-8x7B-Instruct-4bit \
  --device mlx \
  --attention-backend mlx
