
[MLX] Add native MLX execution backend for Apple Silicon Mac#20342

Merged
hnyls2002 merged 18 commits into sgl-project:main from yeahdongcn:xd/mlx_lm
Mar 26, 2026

Conversation

@yeahdongcn
Collaborator

@yeahdongcn yeahdongcn commented Mar 11, 2026

Motivation

This PR introduces MlxModelRunner and MlxTpModelWorker under python/sglang/srt/hardware_backend/mlx, enabling end-to-end model inference via MLX on Apple Silicon. The backend is activated with SGLANG_USE_MLX=1.

Modifications

  • MlxModelRunner replaces the entire PyTorch model pipeline with native MLX execution, bridging only final logits back to PyTorch for sampling.
  • MlxTpModelWorker subclasses TpModelWorker, keeping the base worker and scheduler free of MLX-specific code (see the dispatch sketch after this list). Stale request cleanup is handled automatically.
  • bench_one_batch.py uses a runner abstraction (_BenchRunner / _MlxBenchRunner) to unify the benchmark loop.
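
For illustration, the scheduler-side dispatch that keeps MLX out of the base worker could look roughly like the sketch below. The function name follows the PR's init_tp_model_worker; the constructor arguments are placeholders, not the actual signature.

import os

from sglang.srt.managers.tp_worker import TpModelWorker


def init_tp_model_worker(*args, **kwargs):
    # Sketch: pick the MLX worker only when SGLANG_USE_MLX=1 is set;
    # every other path keeps the default PyTorch TpModelWorker.
    if os.environ.get("SGLANG_USE_MLX") == "1":
        # Imported lazily so non-macOS deployments never touch MLX.
        from sglang.srt.hardware_backend.mlx.tp_worker import MlxTpModelWorker

        return MlxTpModelWorker(*args, **kwargs)
    return TpModelWorker(*args, **kwargs)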

Accuracy Tests


Benchmarking and Profiling

With SGLANG_USE_MLX=0 (default):

> uv run python -m sglang.bench_one_batch --model-path /Users/yexiaodong/.cache/modelscope/hub/models/Qwen/Qwen3-0.6B --trust-remote-code --disable-radix-cache --disable-cuda-graph --tp-size 1 --batch-size 1 --input-len 60 --output-len 100 --port 43440
W0311 16:28:24.133000 95531 sglang-diffusion/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
/Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/python/sglang/srt/layers/attention/fla/utils.py:212: UserWarning: Triton is not supported on current platform, roll back to CPU.
  warnings.warn(
/Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/python/sglang/srt/layers/quantization/awq.py:87: UserWarning: Only CUDA, HIP and XPU support AWQ currently.
  warnings.warn(f"Only CUDA, HIP and XPU support AWQ currently.")
/Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/python/sglang/srt/layers/quantization/gguf.py:48: UserWarning: Only CUDA and MUSA support GGUF quantization currently.
  warnings.warn(f"Only CUDA and MUSA support GGUF quantization currently.")
[2026-03-11 16:28:25] INFO server_args.py:2140: Attention backend not specified. Use torch_native backend by default.
[2026-03-11 16:28:25] WARNING server_args.py:2146: Cuda graph is disabled because of using torch native attention backend
[2026-03-11 16:28:25] WARNING common.py:1221: Fail to set RLIMIT_STACK: current limit exceeds maximum limit
[2026-03-11 16:28:25 TP0] Init torch distributed begin.
[2026-03-11 16:28:25 TP0] Init torch distributed ends. elapsed=0.05 s, mem usage=0.00 GB
[2026-03-11 16:28:25 TP0] Ignore import error when loading sglang.srt.models.bailing_moe_linear: No module named 'vllm'
[2026-03-11 16:28:25 TP0] Ignore import error when loading sglang.srt.models.bailing_moe_nextn: No module named 'vllm'
[2026-03-11 16:28:26 TP0] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr'
[2026-03-11 16:28:26 TP0] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr'
[2026-03-11 16:28:26 TP0] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/sglang-diffusion/lib/python3.11/site-packages/transformers/__init__.py)
/Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/python/sglang/srt/layers/attention/fla/utils.py:212: UserWarning: Triton is not supported on current platform, roll back to CPU.
  warnings.warn(
[2026-03-11 16:28:26 TP0] Load weight begin. avail mem=4.10 GB
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
[2026-03-11 16:28:26 TP0] Parameter lm_head.weight not found in params_dict
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:02<00:00,  2.24s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:02<00:00,  2.25s/it]

[2026-03-11 16:28:28 TP0] Load weight end. elapsed=2.67 s, type=Qwen3ForCausalLM, avail mem=2.22 GB, mem usage=1.88 GB.
[2026-03-11 16:28:28 TP0] Using KV cache dtype: torch.bfloat16
[2026-03-11 16:28:29 TP0] KV Cache is allocated. #tokens: 21290, K size: 1.14 GB, V size: 1.14 GB
[2026-03-11 16:28:29 TP0] Memory pool end. avail mem=2.77 GB
[2026-03-11 16:28:29 TP0] Disable piecewise CUDA graph because --disable-piecewise-cuda-graph is set
max_total_num_tokens=21290
Warmup ...
Prefill. latency: 0.56373 s, throughput:    106.43 token/s
Decode 0. Batch size: 1, latency: 1.12418 s, throughput:      0.89 token/s
Decode 1. Batch size: 1, latency: 0.13345 s, throughput:      7.49 token/s
Decode 2. Batch size: 1, latency: 0.16514 s, throughput:      6.06 token/s
Decode 3. Batch size: 1, latency: 0.12836 s, throughput:      7.79 token/s
Decode 4. Batch size: 1, latency: 0.11836 s, throughput:      8.45 token/s
Decode.  median latency: 0.12508 s, median throughput:      8.00 token/s
Total. latency:  5.513 s, throughput:     16.69 token/s
Benchmark ...
Prefill. latency: 0.25097 s, throughput:    239.07 token/s
Decode 0. Batch size: 1, latency: 0.10205 s, throughput:      9.80 token/s
Decode 1. Batch size: 1, latency: 0.10695 s, throughput:      9.35 token/s
Decode 2. Batch size: 1, latency: 0.11190 s, throughput:      8.94 token/s
Decode 3. Batch size: 1, latency: 0.10398 s, throughput:      9.62 token/s
Decode 4. Batch size: 1, latency: 0.10563 s, throughput:      9.47 token/s
Decode.  median latency: 0.13053 s, median throughput:      7.66 token/s
Total. latency: 13.187 s, throughput:     12.13 token/s

With SGLANG_USE_MLX=1:

> export SGLANG_USE_MLX=1
> uv run python -m sglang.bench_one_batch --model-path /Users/yexiaodong/.cache/modelscope/hub/models/Qwen/Qwen3-0.6B --trust-remote-code --disable-radix-cache --disable-cuda-graph --tp-size 1 --batch-size 1 --input-len 60 --output-len 100 --port 43440
W0311 16:29:29.661000 96241 sglang-diffusion/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
/Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/python/sglang/srt/layers/attention/fla/utils.py:212: UserWarning: Triton is not supported on current platform, roll back to CPU.
  warnings.warn(
/Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/python/sglang/srt/layers/quantization/awq.py:87: UserWarning: Only CUDA, HIP and XPU support AWQ currently.
  warnings.warn(f"Only CUDA, HIP and XPU support AWQ currently.")
/Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/python/sglang/srt/layers/quantization/gguf.py:48: UserWarning: Only CUDA and MUSA support GGUF quantization currently.
  warnings.warn(f"Only CUDA and MUSA support GGUF quantization currently.")
[2026-03-11 16:29:31] INFO server_args.py:2140: Attention backend not specified. Use torch_native backend by default.
[2026-03-11 16:29:31] WARNING server_args.py:2146: Cuda graph is disabled because of using torch native attention backend
[2026-03-11 16:29:31] WARNING common.py:1221: Fail to set RLIMIT_STACK: current limit exceeds maximum limit
[2026-03-11 16:29:31 TP0] Init torch distributed begin.
[2026-03-11 16:29:31 TP0] Init torch distributed ends. elapsed=0.08 s, mem usage=0.02 GB
[2026-03-11 16:29:31 TP0] Ignore import error when loading sglang.srt.models.bailing_moe_linear: No module named 'vllm'
[2026-03-11 16:29:31 TP0] Ignore import error when loading sglang.srt.models.bailing_moe_nextn: No module named 'vllm'
[2026-03-11 16:29:31 TP0] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr'
[2026-03-11 16:29:31 TP0] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr'
[2026-03-11 16:29:31 TP0] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/sglang-diffusion/lib/python3.11/site-packages/transformers/__init__.py)
/Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/python/sglang/srt/layers/attention/fla/utils.py:212: UserWarning: Triton is not supported on current platform, roll back to CPU.
  warnings.warn(
[2026-03-11 16:29:31 TP0] Load weight begin. avail mem=5.35 GB
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
[2026-03-11 16:29:31 TP0] Parameter lm_head.weight not found in params_dict
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.45s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.45s/it]

[2026-03-11 16:29:33 TP0] Load weight end. elapsed=1.74 s, type=Qwen3ForCausalLM, avail mem=2.74 GB, mem usage=2.61 GB.
[2026-03-11 16:29:33 TP0] Using KV cache dtype: torch.bfloat16
[2026-03-11 16:29:33 TP0] KV Cache is allocated. #tokens: 21471, K size: 1.15 GB, V size: 1.15 GB
[2026-03-11 16:29:33 TP0] Memory pool end. avail mem=1.81 GB
[2026-03-11 16:29:33 TP0] Disable piecewise CUDA graph because --disable-piecewise-cuda-graph is set
max_total_num_tokens=21471
Initializing MlxModelRunner for end-to-end MLX inference
[2026-03-11 16:29:33 TP0] Loading MLX model: /Users/yexiaodong/.cache/modelscope/hub/models/Qwen/Qwen3-0.6B
[2026-03-11 16:29:34 TP0] MLX model loaded in 0.79s
Warmup ...
Prefill. latency: 0.35995 s, throughput:    166.69 token/s
Decode 0. Batch size: 1, latency: 0.03179 s, throughput:     31.45 token/s
Decode 1. Batch size: 1, latency: 0.02628 s, throughput:     38.06 token/s
Decode 2. Batch size: 1, latency: 0.02691 s, throughput:     37.16 token/s
Decode 3. Batch size: 1, latency: 0.02608 s, throughput:     38.35 token/s
Decode 4. Batch size: 1, latency: 0.02762 s, throughput:     36.21 token/s
Decode.  median latency: 0.02608 s, median throughput:     38.35 token/s
Total. latency:  1.176 s, throughput:     78.21 token/s
Benchmark ...
Prefill. latency: 0.05681 s, throughput:   1056.14 token/s
Decode 0. Batch size: 1, latency: 0.02790 s, throughput:     35.85 token/s
Decode 1. Batch size: 1, latency: 0.02710 s, throughput:     36.90 token/s
Decode 2. Batch size: 1, latency: 0.02440 s, throughput:     40.99 token/s
Decode 3. Batch size: 1, latency: 0.02513 s, throughput:     39.79 token/s
Decode 4. Batch size: 1, latency: 0.02408 s, throughput:     41.53 token/s
Decode.  median latency: 0.02832 s, median throughput:     35.31 token/s
Total. latency:  2.928 s, throughput:     54.65 token/s

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@yeahdongcn yeahdongcn requested a review from alexnails March 11, 2026 08:31
@yeahdongcn yeahdongcn marked this pull request as ready for review March 11, 2026 08:31
@yeahdongcn
Collaborator Author

Perf data for ref: #20221

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances SGLang's performance on Apple Silicon Macs by integrating a native MLX execution backend. The changes allow models to run entirely within MLX, leveraging its optimizations for Apple hardware, while maintaining compatibility with SGLang's existing sampling pipeline by selectively bridging logits to PyTorch. This provides a more efficient inference path for users on Apple Silicon, as demonstrated by the provided benchmarks showing substantial throughput improvements.

Highlights

  • Native MLX Execution Backend: Introduced a native MLX execution backend for Apple Silicon Macs, enabling end-to-end model inference via MLX, bypassing PyTorch MPS for core computations.
  • MlxModelRunner Implementation: Added MlxModelRunner to manage MLX model loading, prefill, and batched decode operations, bridging only final logits back to PyTorch for sampling compatibility (a KV-cache sketch follows this list).
  • MlxTpModelWorker Integration: Created MlxTpModelWorker which subclasses TpModelWorker to integrate the MLX model runner into the existing SGLang worker architecture, handling request state cleanup automatically.
  • Benchmarking Abstraction: Refactored bench_one_batch.py with _BenchRunner and _MlxBenchRunner classes to provide a unified abstraction for benchmarking both PyTorch and MLX execution paths.
  • Centralized Tensor Bridging: Implemented sglang.srt.utils.tensor_bridge.py to centralize PyTorch-MLX tensor conversions, improving code reusability and handling MPS memory considerations.
  • Activation via Environment Variable: Enabled the MLX backend to be activated by setting the SGLANG_USE_MLX=1 environment variable.
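
As background for the batched-decode highlight above, merging and extracting per-request KV caches in MLX tends to look like the sketch below. The function names mirror the PR's _merge_kv_caches/_extract_kv_cache, but the shapes ([seq_len, n_heads, head_dim]) and the zero-padding strategy are assumptions, not the PR's implementation.

import mlx.core as mx


def _merge_kv_caches(caches):
    # Pad each per-request cache to the longest sequence length, then
    # stack along a new batch axis so one kernel can serve the batch.
    max_len = max(c.shape[0] for c in caches)
    padded = [
        mx.pad(c, [(0, max_len - c.shape[0])] + [(0, 0)] * (c.ndim - 1))
        for c in caches
    ]
    return mx.stack(padded, axis=0)


def _extract_kv_cache(batched, index, length):
    # Slice one request's unpadded cache back out of the merged batch.
    return batched[index, :length]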


Changelog
  • python/sglang/bench_one_batch.py
    • Imported the envs module for environment variable access.
    • Integrated MlxModelRunner when SGLANG_USE_MLX is enabled.
    • Introduced _BenchRunner and _MlxBenchRunner classes to abstract model execution for benchmarking.
    • Updated correctness_test and latency_test_run_once functions to utilize the new runner abstraction.
    • Removed the device parameter from latency_test_run_once and latency_test as device handling is now encapsulated within the runner.
  • python/sglang/jit_kernel/diffusion/triton/mps_fallback.py
    • Imported mlx_to_torch and torch_to_mlx from the new sglang.srt.utils.tensor_bridge module.
    • Removed redundant internal _torch_to_mlx and _mlx_to_torch functions.
    • Updated norm_infer_native, triton_one_pass_rms_norm_native, and rms_norm_fn_native to use the new centralized tensor bridging utilities.
  • python/sglang/srt/hardware_backend/mlx/model_runner.py
    • Added a new file defining the MlxModelRunner class for end-to-end MLX inference.
    • Implemented MlxRequestState for managing per-request state during MLX inference.
    • Provided utility functions _merge_kv_caches and _extract_kv_cache for batched KV cache management.
    • Implemented methods for MLX model loading, prefill, single and batched decode, logits extraction, and request state management.
  • python/sglang/srt/hardware_backend/mlx/tp_worker.py
    • Added a new file defining the MlxTpModelWorker class.
    • Subclassed TpModelWorker to override the forward pass for MLX execution.
    • Implemented _forward_batch_generation_mlx to handle prefill and decode operations using the MlxModelRunner.
    • Included logic for automatic cleanup of stale MLX request states.
  • python/sglang/srt/managers/scheduler.py
    • Modified init_tp_model_worker to conditionally initialize MlxTpModelWorker if the SGLANG_USE_MLX environment variable is set, otherwise defaulting to TpModelWorker.
  • python/sglang/srt/utils/tensor_bridge.py
    • Added a new file providing a centralized module for converting tensors between PyTorch and MLX.
    • Defined MLX_TO_TORCH_DTYPE and TORCH_TO_MLX_DTYPE mappings for various data types.
    • Implemented torch_to_mlx and mlx_to_torch functions, including considerations for MPS device memory limits (a conversion sketch follows this changelog).
    • Included sync_mlx and sync_torch functions for explicit synchronization between MLX and PyTorch.
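
To make the bridging idea concrete, the sketch below shows one common way to convert between the two frameworks via NumPy. The names torch_to_mlx/mlx_to_torch come from the changelog; the bodies, the bfloat16 workaround, and the device argument are guesses rather than the PR's actual code.

import mlx.core as mx
import numpy as np
import torch


def torch_to_mlx(t: torch.Tensor) -> mx.array:
    # MPS tensors must hop through host memory, and NumPy has no
    # bfloat16, so upcast first (assumed workaround).
    if t.dtype == torch.bfloat16:
        t = t.to(torch.float32)
    return mx.array(t.detach().cpu().numpy())


def mlx_to_torch(a: mx.array, device: str = "cpu") -> torch.Tensor:
    mx.eval(a)  # materialize MLX's lazy computation before converting
    return torch.from_numpy(np.array(a)).to(device)
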
Activity
  • The author introduced MlxModelRunner and MlxTpModelWorker to enable MLX execution on Apple Silicon.
  • Benchmarking results were provided, showing significant throughput improvements when SGLANG_USE_MLX=1 is enabled, particularly in decode operations (e.g., decode throughput increased from ~9 token/s to ~35 token/s in benchmark runs).
  • The author detailed modifications across several files to integrate the new MLX backend and abstract benchmarking logic.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a native MLX execution backend for Apple Silicon, which is a significant and well-implemented feature. The changes are well-structured, including a new MlxModelRunner for end-to-end MLX inference, a MlxTpModelWorker for integration with the scheduler, and updates to the benchmark script to support both PyTorch and MLX backends through a clean abstraction. The refactoring of tensor conversion logic into a centralized tensor_bridge.py is also a good improvement.

I have a few suggestions to enhance the code quality and robustness. Specifically, I've pointed out a local import that could be moved, some duplicated code that could be refactored into a helper method, and a more critical issue regarding the lack of support for ForwardMode.MIXED in the new MlxTpModelWorker, which could lead to errors with features like chunked prefill.

Comment thread python/sglang/srt/hardware_backend/mlx/tp_worker.py
Comment thread python/sglang/srt/hardware_backend/mlx/model_runner.py Outdated
Comment thread python/sglang/srt/hardware_backend/mlx/model_runner.py Outdated
@yeahdongcn
Collaborator Author

yeahdongcn commented Mar 11, 2026

@gemini-code-assist review

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a native MLX execution backend for Apple Silicon, which is a significant and valuable feature for improving performance on Macs. The implementation is well-structured, using abstractions to support both the existing PyTorch path and the new MLX path. The core logic for the MLX runner and its integration into the scheduler seems solid. My review focuses on a few key areas for improvement: memory efficiency, benchmark accuracy, and type hint correctness. Overall, this is a strong contribution that will benefit users on Apple Silicon.

Note: Security Review did not run due to the size of the PR.

Comment thread python/sglang/srt/hardware_backend/mlx/tp_worker.py Outdated
Comment thread python/sglang/bench_one_batch.py
Comment thread python/sglang/bench_one_batch.py Outdated
Comment thread python/sglang/srt/hardware_backend/mlx/model_runner.py Outdated
@mickqian
Collaborator

mickqian commented Mar 11, 2026

reproduced on my MacBook:

SGLANG_USE_MLX=1 uv run python -m sglang.bench_one_batch --model-path /Users/xxx/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/c1899de289a04d12100db370d81485cdf75e47ca/ --trust-remote-code --disable-radix-cache --disable-cuda-graph --tp-size 1 --batch-size 1 --input-len 60 --output-len 100 --port 30000

Warmup ...
Prefill. latency: 0.11839 s, throughput:    506.82 token/s
Decode 0. Batch size: 1, latency: 0.00890 s, throughput:    112.41 token/s
Decode 1. Batch size: 1, latency: 0.00776 s, throughput:    128.79 token/s
Decode 2. Batch size: 1, latency: 0.00754 s, throughput:    132.54 token/s
Decode 3. Batch size: 1, latency: 0.00774 s, throughput:    129.17 token/s
Decode 4. Batch size: 1, latency: 0.00726 s, throughput:    137.79 token/s
Decode.  median latency: 0.00741 s, median throughput:    134.90 token/s
Total. latency:  0.350 s, throughput:    263.00 token/s
Benchmark ...
Prefill. latency: 0.01814 s, throughput:   3306.82 token/s
Decode 0. Batch size: 1, latency: 0.00722 s, throughput:    138.50 token/s
Decode 1. Batch size: 1, latency: 0.00737 s, throughput:    135.67 token/s
Decode 2. Batch size: 1, latency: 0.00710 s, throughput:    140.89 token/s
Decode 3. Batch size: 1, latency: 0.00711 s, throughput:    140.58 token/s
Decode 4. Batch size: 1, latency: 0.00715 s, throughput:    139.78 token/s
Decode.  median latency: 0.00692 s, median throughput:    144.48 token/s
Total. latency:  0.707 s, throughput:    226.24 token/s

@changminbark
Contributor

@yeahdongcn After this gets merged, I can start looking into performance optimizations related to the scheduler. How does that sound?

@yeahdongcn
Collaborator Author

/gemini summary

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

Comment thread python/sglang/bench_one_batch.py
Comment thread python/sglang/srt/hardware_backend/mlx/model_runner.py Outdated
Comment thread python/sglang/jit_kernel/diffusion/triton/mps_fallback.py Outdated
Comment thread python/sglang/bench_one_batch.py
@yeahdongcn
Collaborator Author

yeahdongcn commented Mar 13, 2026

test_mps_fallback_norms.py (AI-generated):

"""Tests for MPS fallback norm functions — verifies MLX-accelerated versions
match the pure-PyTorch reference for various input ranks (2D, 3D, 4D)."""

import pytest
import torch


def _requires_mlx():
    try:
        import mlx.core as mx  # noqa: F401

        return True
    except ImportError:
        return False


pytestmark = pytest.mark.skipif(not _requires_mlx(), reason="MLX not available")

# ── Reference implementations (pure math, no reshaping) ──────────────────────


def _ref_rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
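    # RMSNorm reference: y = (x / sqrt(mean(x^2, axis=-1) + eps)) * w,
    # computed in float32 and cast back to the input dtype.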
    x_f = x.float()
    variance = x_f.pow(2).mean(dim=-1, keepdim=True)
    x_hat = x_f * torch.rsqrt(variance + eps)
    return (x_hat * w.float()).to(x.dtype)


def _ref_rms_norm_fn(x, weight, bias, residual, eps, residual_in_fp32, zero_centered_weight, out_dtype):
    orig_dtype = x.dtype
    x_f = x.float()
    if residual is not None:
        x_f = x_f + residual.float()
        residual_out_val = x_f.to(torch.float32 if residual_in_fp32 else orig_dtype)
    else:
        residual_out_val = None
    variance = x_f.pow(2).mean(dim=-1, keepdim=True)
    x_hat = x_f * torch.rsqrt(variance + eps)
    if weight is not None:
        w = weight.float()
        if zero_centered_weight:
            w = w + 1.0
        x_hat = x_hat * w
    if bias is not None:
        x_hat = x_hat + bias.float()
    final_dtype = out_dtype if out_dtype is not None else orig_dtype
    y = x_hat.to(final_dtype)
    if residual is not None and residual_out_val is not None:
        return y, residual_out_val
    return y


# ── Test shapes ───────────────────────────────────────────────────────────────

SHAPES_2D = [(4, 64), (16, 128), (1, 256)]
SHAPES_3D = [(2, 8, 64), (1, 16, 128), (4, 4, 256)]
SHAPES_4D = [(1, 2, 8, 64), (2, 2, 4, 128)]
ALL_SHAPES = SHAPES_2D + SHAPES_3D + SHAPES_4D

EPS = 1e-6
ATOL = 1e-4
RTOL = 1e-3


# ── Tests for triton_one_pass_rms_norm_native (MLX) ──────────────────────────


@pytest.mark.parametrize("shape", ALL_SHAPES)
def test_triton_one_pass_rms_norm_native_shapes(shape):
    """MLX triton_one_pass_rms_norm_native matches reference for N-D inputs."""
    from sglang.jit_kernel.diffusion.triton.mps_fallback import (
        triton_one_pass_rms_norm_native,
    )

    dim = shape[-1]
    x = torch.randn(shape, dtype=torch.float32)
    w = torch.randn(dim, dtype=torch.float32)

    result = triton_one_pass_rms_norm_native(x, w, eps=EPS)
    expected = _ref_rms_norm(x, w, EPS)

    assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}"
    torch.testing.assert_close(result, expected, atol=ATOL, rtol=RTOL)


# ── Tests for rms_norm_fn_native (MLX) ───────────────────────────────────────


@pytest.mark.parametrize("shape", ALL_SHAPES)
def test_rms_norm_fn_native_basic(shape):
    """MLX rms_norm_fn_native matches reference for N-D inputs (no residual)."""
    from sglang.jit_kernel.diffusion.triton.mps_fallback import rms_norm_fn_native

    dim = shape[-1]
    x = torch.randn(shape, dtype=torch.float32)
    w = torch.randn(dim, dtype=torch.float32)

    result = rms_norm_fn_native(x, w, bias=None, eps=EPS)
    expected = _ref_rms_norm_fn(x, w, None, None, EPS, False, False, None)

    assert result.shape == expected.shape
    torch.testing.assert_close(result, expected, atol=ATOL, rtol=RTOL)


@pytest.mark.parametrize("shape", ALL_SHAPES)
def test_rms_norm_fn_native_with_residual(shape):
    """MLX rms_norm_fn_native matches reference with residual for N-D inputs."""
    from sglang.jit_kernel.diffusion.triton.mps_fallback import rms_norm_fn_native

    dim = shape[-1]
    x = torch.randn(shape, dtype=torch.float32)
    w = torch.randn(dim, dtype=torch.float32)
    residual = torch.randn(shape, dtype=torch.float32)

    result_y, result_res = rms_norm_fn_native(x, w, bias=None, residual=residual, eps=EPS)
    expected_y, expected_res = _ref_rms_norm_fn(x, w, None, residual, EPS, False, False, None)

    assert result_y.shape == expected_y.shape
    assert result_res.shape == expected_res.shape
    torch.testing.assert_close(result_y, expected_y, atol=ATOL, rtol=RTOL)
    torch.testing.assert_close(result_res, expected_res, atol=ATOL, rtol=RTOL)


@pytest.mark.parametrize("shape", SHAPES_3D)
def test_rms_norm_fn_native_with_bias_and_zero_centered(shape):
    """MLX rms_norm_fn_native with bias and zero_centered_weight."""
    from sglang.jit_kernel.diffusion.triton.mps_fallback import rms_norm_fn_native

    dim = shape[-1]
    x = torch.randn(shape, dtype=torch.float32)
    w = torch.randn(dim, dtype=torch.float32)
    b = torch.randn(dim, dtype=torch.float32)

    result = rms_norm_fn_native(x, w, bias=b, eps=EPS, zero_centered_weight=True)
    expected = _ref_rms_norm_fn(x, w, b, None, EPS, False, True, None)

    assert result.shape == expected.shape
    torch.testing.assert_close(result, expected, atol=ATOL, rtol=RTOL)


# ── Tests for norm_infer_native (MLX) ────────────────────────────────────────


@pytest.mark.parametrize("shape", ALL_SHAPES)
@pytest.mark.parametrize("is_rms_norm", [True, False])
def test_norm_infer_native(shape, is_rms_norm):
    """MLX norm_infer_native matches PyTorch reference for N-D inputs."""
    from sglang.jit_kernel.diffusion.triton.mps_fallback import norm_infer_native

    dim = shape[-1]
    x = torch.randn(shape, dtype=torch.float32)
    w = torch.randn(dim, dtype=torch.float32)
    b = None if is_rms_norm else torch.randn(dim, dtype=torch.float32)

    result = norm_infer_native(x, w, b, EPS, is_rms_norm=is_rms_norm)

    # Compute reference
    x_f = x.float()
    if is_rms_norm:
        variance = x_f.pow(2).mean(dim=-1, keepdim=True)
        x_hat = x_f * torch.rsqrt(variance + EPS)
        expected = (x_hat * w.float()).to(x.dtype)
    else:
        mean = x_f.mean(dim=-1, keepdim=True)
        variance = (x_f - mean).pow(2).mean(dim=-1, keepdim=True)
        x_hat = (x_f - mean) * torch.rsqrt(variance + EPS)
        expected = (x_hat * w.float() + b.float()).to(x.dtype)

    assert result.shape == expected.shape
    torch.testing.assert_close(result, expected, atol=ATOL, rtol=RTOL)


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
> uv run python python/sglang/jit_kernel/tests/test_mps_fallback_norms.py
================================================================ test session starts =================================================================
platform darwin -- Python 3.11.15, pytest-9.0.2, pluggy-1.6.0 -- /Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/sglang-diffusion/bin/python3
cachedir: .pytest_cache
rootdir: /Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/python
configfile: pyproject.toml
plugins: anyio-4.12.1
collected 43 items                                                                                                                                   

python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape0] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape1] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape2] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape3] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape4] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape5] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape6] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape7] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape0] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape1] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape2] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape3] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape4] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape5] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape6] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape7] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape0] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape1] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape2] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape3] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape4] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape5] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape6] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape7] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_bias_and_zero_centered[shape0] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_bias_and_zero_centered[shape1] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_bias_and_zero_centered[shape2] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape0] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape1] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape2] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape3] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape4] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape5] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape6] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape7] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape0] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape1] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape2] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape3] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape4] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape5] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape6] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape7] PASSED

================================================================= 43 passed in 0.78s =================================================================

@yeahdongcn
Collaborator Author

/gemini summary

@yeahdongcn
Collaborator Author

yeahdongcn commented Mar 16, 2026

It looks like we need CI coverage for macOS @Kangyan-Zhou. #19997 appears to have broken macOS functionality (just fixed in the latest commit).

@yeahdongcn
Collaborator Author

Compared with vllm-metal on my MBP (M1+16G), the performance looks similar for Qwen3-0.6B (vllm: 27.09 tok/s; sglang with SGLANG_USE_MLX=1: 28.21 tok/s).

cuttini added a commit to cuttini/sglang that referenced this pull request Mar 18, 2026
Adds native MLX execution backend for Apple Silicon:
- MlxModelRunner: full model inference in MLX, bridges logits to PyTorch for sampling
- MlxModelRunnerStub: skips PyTorch weight loading when using MLX
- MlxTpModelWorker: subclasses TpModelWorker for MLX-specific cleanup
- tensor_bridge.py: SGLANG_USE_MLX=1 env var to activate

Source: sgl-project#20342
Combined with our --pp-layer-start/--pp-layer-end for layer sharding.

Usage on Apple Silicon:
  SGLANG_USE_MLX=1 python -m sglang.launch_server \
    --model Qwen/Qwen2.5-0.5B-Instruct \
    --pp-layer-start 0 --pp-layer-end 8
Comment thread python/sglang/srt/hardware_backend/mlx/model_runner_stub.py
Comment thread python/sglang/srt/hardware_backend/mlx/model_runner_stub.py
@alexnails
Collaborator

Other than those LGTM!

[with caveats that we will need to dedup this for multi hardware for Metal (whenever that effort starts in full)]

@hnyls2002
Collaborator

/rerun-failed-ci

@yeahdongcn
Collaborator Author

/rerun-failed-ci

2 similar comments
@mickqian
Collaborator

/rerun-failed-ci

@yeahdongcn
Collaborator Author

/rerun-failed-ci

@hnyls2002 hnyls2002 merged commit a305964 into sgl-project:main Mar 26, 2026
483 of 555 checks passed
satyamk7054 pushed a commit to satyamk7054/sglang that referenced this pull request Apr 3, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026

Labels

dependencies, jit-kernel, macos, run-ci


5 participants