[MLX] Add native MLX execution backend for Apple Silicon Mac #20342
hnyls2002 merged 18 commits into sgl-project:main from
Conversation
|
Perf data for ref: #20221 |
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances SGLang's performance on Apple Silicon Macs by integrating a native MLX execution backend. The changes allow models to run entirely within MLX, leveraging its optimizations for Apple hardware, while maintaining compatibility with SGLang's existing sampling pipeline by selectively bridging logits to PyTorch. This provides a more efficient inference path for users on Apple Silicon, as demonstrated by the provided benchmarks showing substantial throughput improvements.
Highlights
Activity
|
Code Review
This pull request introduces a native MLX execution backend for Apple Silicon, which is a significant and well-implemented feature. The changes are well-structured, including a new MlxModelRunner for end-to-end MLX inference, a MlxTpModelWorker for integration with the scheduler, and updates to the benchmark script to support both PyTorch and MLX backends through a clean abstraction. The refactoring of tensor conversion logic into a centralized tensor_bridge.py is also a good improvement.
I have a few suggestions to enhance the code quality and robustness. Specifically, I've pointed out a local import that could be moved, some duplicated code that could be refactored into a helper method, and a more critical issue regarding the lack of support for ForwardMode.MIXED in the new MlxTpModelWorker, which could lead to errors with features like chunked prefill.
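The review's most critical point is the missing `ForwardMode.MIXED` handling. A minimal sketch of the kind of explicit guard being asked for might look like this; note that the `ForwardMode` enum below is a stand-in for illustration, not SGLang's actual definition:

```python
from enum import Enum


# Stand-in for SGLang's ForwardMode; the values here are illustrative only.
class ForwardMode(Enum):
    EXTEND = "extend"
    DECODE = "decode"
    MIXED = "mixed"


def check_forward_mode(mode: ForwardMode) -> ForwardMode:
    # Fail loudly instead of silently mishandling chunked-prefill batches,
    # which arrive as MIXED-mode forwards.
    if mode is ForwardMode.MIXED:
        raise NotImplementedError(
            "MlxTpModelWorker does not support ForwardMode.MIXED (chunked prefill)"
        )
    return mode
```

Raising `NotImplementedError` up front turns a subtle correctness bug into an immediate, diagnosable failure when chunked prefill is enabled.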
Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>
|
@gemini-code-assist review |
Code Review
This pull request introduces a native MLX execution backend for Apple Silicon, which is a significant and valuable feature for improving performance on Macs. The implementation is well-structured, using abstractions to support both the existing PyTorch path and the new MLX path. The core logic for the MLX runner and its integration into the scheduler seems solid. My review focuses on a few key areas for improvement: memory efficiency, benchmark accuracy, and type hint correctness. Overall, this is a strong contribution that will benefit users on Apple Silicon.
Note: Security Review did not run due to the size of the PR.
Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>
|
reproduced on my MacBook: |
…location Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>
Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>
|
@yeahdongcn After this gets merged, I can start looking into performance optimizations related to the scheduler. How does that sound? |
|
/gemini summary |
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>
|
"""Tests for MPS fallback norm functions — verifies MLX-accelerated versions
match the pure-PyTorch reference for various input ranks (2D, 3D, 4D)."""
import pytest
import torch
def _requires_mlx():
try:
import mlx.core as mx # noqa: F401
return True
except ImportError:
return False
pytestmark = pytest.mark.skipif(not _requires_mlx(), reason="MLX not available")
# ── Reference implementations (pure math, no reshaping) ──────────────────────
def _ref_rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
x_f = x.float()
variance = x_f.pow(2).mean(dim=-1, keepdim=True)
x_hat = x_f * torch.rsqrt(variance + eps)
return (x_hat * w.float()).to(x.dtype)
def _ref_rms_norm_fn(x, weight, bias, residual, eps, residual_in_fp32, zero_centered_weight, out_dtype):
orig_dtype = x.dtype
x_f = x.float()
if residual is not None:
x_f = x_f + residual.float()
residual_out_val = x_f.to(torch.float32 if residual_in_fp32 else orig_dtype)
else:
residual_out_val = None
variance = x_f.pow(2).mean(dim=-1, keepdim=True)
x_hat = x_f * torch.rsqrt(variance + eps)
if weight is not None:
w = weight.float()
if zero_centered_weight:
w = w + 1.0
x_hat = x_hat * w
if bias is not None:
x_hat = x_hat + bias.float()
final_dtype = out_dtype if out_dtype is not None else orig_dtype
y = x_hat.to(final_dtype)
if residual is not None and residual_out_val is not None:
return y, residual_out_val
return y
# ── Test shapes ───────────────────────────────────────────────────────────────
SHAPES_2D = [(4, 64), (16, 128), (1, 256)]
SHAPES_3D = [(2, 8, 64), (1, 16, 128), (4, 4, 256)]
SHAPES_4D = [(1, 2, 8, 64), (2, 2, 4, 128)]
ALL_SHAPES = SHAPES_2D + SHAPES_3D + SHAPES_4D
EPS = 1e-6
ATOL = 1e-4
RTOL = 1e-3
# ── Tests for triton_one_pass_rms_norm_native (MLX) ──────────────────────────
@pytest.mark.parametrize("shape", ALL_SHAPES)
def test_triton_one_pass_rms_norm_native_shapes(shape):
"""MLX triton_one_pass_rms_norm_native matches reference for N-D inputs."""
from sglang.jit_kernel.diffusion.triton.mps_fallback import (
triton_one_pass_rms_norm_native,
)
dim = shape[-1]
x = torch.randn(shape, dtype=torch.float32)
w = torch.randn(dim, dtype=torch.float32)
result = triton_one_pass_rms_norm_native(x, w, eps=EPS)
expected = _ref_rms_norm(x, w, EPS)
assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}"
torch.testing.assert_close(result, expected, atol=ATOL, rtol=RTOL)
# ── Tests for rms_norm_fn_native (MLX) ───────────────────────────────────────
@pytest.mark.parametrize("shape", ALL_SHAPES)
def test_rms_norm_fn_native_basic(shape):
"""MLX rms_norm_fn_native matches reference for N-D inputs (no residual)."""
from sglang.jit_kernel.diffusion.triton.mps_fallback import rms_norm_fn_native
dim = shape[-1]
x = torch.randn(shape, dtype=torch.float32)
w = torch.randn(dim, dtype=torch.float32)
result = rms_norm_fn_native(x, w, bias=None, eps=EPS)
expected = _ref_rms_norm_fn(x, w, None, None, EPS, False, False, None)
assert result.shape == expected.shape
torch.testing.assert_close(result, expected, atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize("shape", ALL_SHAPES)
def test_rms_norm_fn_native_with_residual(shape):
"""MLX rms_norm_fn_native matches reference with residual for N-D inputs."""
from sglang.jit_kernel.diffusion.triton.mps_fallback import rms_norm_fn_native
dim = shape[-1]
x = torch.randn(shape, dtype=torch.float32)
w = torch.randn(dim, dtype=torch.float32)
residual = torch.randn(shape, dtype=torch.float32)
result_y, result_res = rms_norm_fn_native(x, w, bias=None, residual=residual, eps=EPS)
expected_y, expected_res = _ref_rms_norm_fn(x, w, None, residual, EPS, False, False, None)
assert result_y.shape == expected_y.shape
assert result_res.shape == expected_res.shape
torch.testing.assert_close(result_y, expected_y, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(result_res, expected_res, atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize("shape", SHAPES_3D)
def test_rms_norm_fn_native_with_bias_and_zero_centered(shape):
"""MLX rms_norm_fn_native with bias and zero_centered_weight."""
from sglang.jit_kernel.diffusion.triton.mps_fallback import rms_norm_fn_native
dim = shape[-1]
x = torch.randn(shape, dtype=torch.float32)
w = torch.randn(dim, dtype=torch.float32)
b = torch.randn(dim, dtype=torch.float32)
result = rms_norm_fn_native(x, w, bias=b, eps=EPS, zero_centered_weight=True)
expected = _ref_rms_norm_fn(x, w, b, None, EPS, False, True, None)
assert result.shape == expected.shape
torch.testing.assert_close(result, expected, atol=ATOL, rtol=RTOL)
# ── Tests for norm_infer_native (MLX) ────────────────────────────────────────
@pytest.mark.parametrize("shape", ALL_SHAPES)
@pytest.mark.parametrize("is_rms_norm", [True, False])
def test_norm_infer_native(shape, is_rms_norm):
"""MLX norm_infer_native matches PyTorch reference for N-D inputs."""
from sglang.jit_kernel.diffusion.triton.mps_fallback import norm_infer_native
dim = shape[-1]
x = torch.randn(shape, dtype=torch.float32)
w = torch.randn(dim, dtype=torch.float32)
b = None if is_rms_norm else torch.randn(dim, dtype=torch.float32)
result = norm_infer_native(x, w, b, EPS, is_rms_norm=is_rms_norm)
# Compute reference
x_f = x.float()
if is_rms_norm:
variance = x_f.pow(2).mean(dim=-1, keepdim=True)
x_hat = x_f * torch.rsqrt(variance + EPS)
expected = (x_hat * w.float()).to(x.dtype)
else:
mean = x_f.mean(dim=-1, keepdim=True)
variance = (x_f - mean).pow(2).mean(dim=-1, keepdim=True)
x_hat = (x_f - mean) * torch.rsqrt(variance + EPS)
expected = (x_hat * w.float() + b.float()).to(x.dtype)
assert result.shape == expected.shape
torch.testing.assert_close(result, expected, atol=ATOL, rtol=RTOL)
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

> uv run python python/sglang/jit_kernel/tests/test_mps_fallback_norms.py
================================================================ test session starts =================================================================
platform darwin -- Python 3.11.15, pytest-9.0.2, pluggy-1.6.0 -- /Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/sglang-diffusion/bin/python3
cachedir: .pytest_cache
rootdir: /Users/yexiaodong/go/src/github.com/yeahdongcn/sglang/python
configfile: pyproject.toml
plugins: anyio-4.12.1
collected 43 items
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape0] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape1] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape2] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape3] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape4] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape5] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape6] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_triton_one_pass_rms_norm_native_shapes[shape7] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape0] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape1] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape2] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape3] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape4] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape5] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape6] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_basic[shape7] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape0] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape1] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape2] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape3] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape4] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape5] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape6] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_residual[shape7] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_bias_and_zero_centered[shape0] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_bias_and_zero_centered[shape1] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_rms_norm_fn_native_with_bias_and_zero_centered[shape2] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape0] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape1] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape2] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape3] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape4] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape5] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape6] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[True-shape7] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape0] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape1] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape2] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape3] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape4] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape5] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape6] PASSED
python/sglang/jit_kernel/tests/test_mps_fallback_norms.py::test_norm_infer_native[False-shape7] PASSED
================================================================= 43 passed in 0.78s ================================================================= |
|
/gemini summary |
Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>
|
It looks like we need CI coverage for macOS @Kangyan-Zhou. #19997 appears to have broken macOS functionality (just fixed in the latest commit). |
|
Compared with |
Adds native MLX execution backend for Apple Silicon:
- MlxModelRunner: full model inference in MLX, bridges logits to PyTorch for sampling
- MlxModelRunnerStub: skips PyTorch weight loading when using MLX
- MlxTpModelWorker: subclasses TpModelWorker for MLX-specific cleanup
- tensor_bridge.py: SGLANG_USE_MLX=1 env var to activate

Source: sgl-project#20342

Combined with our --pp-layer-start/--pp-layer-end for layer sharding. Usage on Apple Silicon:

SGLANG_USE_MLX=1 python -m sglang.launch_server \
  --model Qwen/Qwen2.5-0.5B-Instruct \
  --pp-layer-start 0 --pp-layer-end 8
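A minimal sketch of how the SGLANG_USE_MLX switch might be read; the strict `== "1"` semantics are an assumption here, and the actual parsing in tensor_bridge.py may differ:

```python
import os


def mlx_enabled() -> bool:
    # Assumed semantics of the SGLANG_USE_MLX switch: "1" turns the MLX
    # path on; anything else (or unset) keeps the default PyTorch path.
    return os.environ.get("SGLANG_USE_MLX", "0") == "1"
```

Gating on an environment variable rather than a CLI flag lets the switch reach code paths (worker startup, tensor conversion helpers) without threading a new argument through the server's argument parser.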
|
Other than those LGTM! [with caveats that we will need to dedup this for multi hardware for Metal (whenever that effort starts in full)] |
|
/rerun-failed-ci |
Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>
|
/rerun-failed-ci |
|
/rerun-failed-ci |
|
/rerun-failed-ci |
…ject#20342) Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>
Motivation
Introduces `MlxModelRunner` and `MlxTpModelWorker` under `python/sglang/srt/hardware_backend/mlx`, enabling end-to-end model inference via MLX on Apple Silicon. Activated with `SGLANG_USE_MLX=1`.

Modifications

- `MlxModelRunner` replaces the entire PyTorch model pipeline with native MLX execution, bridging only the final logits back to PyTorch for sampling.
- `MlxTpModelWorker` subclasses `TpModelWorker`, keeping the base worker and scheduler free of MLX-specific code. Stale request cleanup is handled automatically.
- `bench_one_batch.py` uses a runner abstraction (`_BenchRunner` / `_MlxBenchRunner`) to unify the benchmark loop.

Accuracy Tests
Benchmarking and Profiling
With `SGLANG_USE_MLX_ATTENTION=0` (default):

With `SGLANG_USE_MLX_ATTENTION=1`:

Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci