
[MPS] Triton stub sub-module resolution broken on macOS with torch 2.9+ / Python 3.12+ #21548

@karanb192


Motivation

On macOS, SGLang installs a mock triton module via _triton_stub.py so that code importing triton doesn't crash. However, torch 2.9+ imports the triton.compiler and triton.compiler.compiler sub-modules, which the stub does not provide, so bench_one_batch and server startup fail.

Problem

The import chain that triggers the crash:

torchvision → torch._dynamo → torch._inductor.runtime.hints → import triton.compiler.compiler

Because the triton stub makes has_triton_package() return True (since triton exists in sys.modules), torch proceeds to import triton.compiler.compiler — which doesn't exist in the mock, causing ModuleNotFoundError.

ModuleNotFoundError: No module named 'triton.compiler'
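The failure can be reproduced without torch. The sketch below assumes the stub registers triton as an empty package (a module with __path__ = []); the helper name reproduce_missing_submodule is hypothetical and not part of SGLang. It swaps in a bare mock, asks importlib for triton.compiler.compiler, reports which module name the import system could not find, and then restores sys.modules.

```python
import importlib
import sys
import types


def reproduce_missing_submodule():
    """Install a bare mock 'triton' package and attempt the import torch performs."""
    # Set aside any real triton modules so the demo is isolated.
    saved = {k: v for k, v in sys.modules.items()
             if k == "triton" or k.startswith("triton.")}
    for k in saved:
        del sys.modules[k]
    try:
        mock = types.ModuleType("triton")
        mock.__path__ = []  # assumption: the stub is a package with no search locations
        sys.modules["triton"] = mock
        try:
            importlib.import_module("triton.compiler.compiler")
        except ModuleNotFoundError as exc:
            return exc.name  # the first module name that could not be located
        return None
    finally:
        # Drop the mock (and anything it pulled in), then restore the originals.
        for k in [k for k in sys.modules if k == "triton" or k.startswith("triton.")]:
            del sys.modules[k]
        sys.modules.update(saved)
```

This returns 'triton.compiler', matching the error above: the parent package exists, but no finder can locate its compiler child.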

The stub currently only registers:

  • triton
  • triton.backends
  • triton.backends.compiler

But not triton.compiler or triton.compiler.compiler, which torch 2.9+ expects.
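One way to add the missing entries is to register every stub module in sys.modules up front, parents before children, and bind each child as an attribute of its parent the way a real import would. This is a sketch that assumes the stub can pre-register plain empty packages; install_triton_stub and STUB_MODULES are hypothetical names, not the actual contents of _triton_stub.py. If torch also does `from triton.compiler.compiler import <name>` for a name the empty module lacks, it gets an ImportError rather than a ModuleNotFoundError; which attributes torch 2.9 actually reads is not covered here.

```python
import sys
import types

# Hypothetical list: the three modules the stub already registers plus the two torch 2.9+ needs.
STUB_MODULES = (
    "triton",
    "triton.backends",
    "triton.backends.compiler",
    "triton.compiler",
    "triton.compiler.compiler",
)


def install_triton_stub(names=STUB_MODULES):
    """Register empty mock packages in sys.modules, parents before children."""
    for name in sorted(names, key=lambda n: n.count(".")):
        if name in sys.modules:
            continue  # leave anything already registered (e.g. by a previous call) untouched
        mod = types.ModuleType(name)
        mod.__path__ = []  # treat every stub as a package so deeper imports are allowed
        sys.modules[name] = mod
        parent, _, child = name.rpartition(".")
        if parent:
            # Bind the child on its parent, as the import system does for real packages.
            setattr(sys.modules[parent], child, mod)
```

After install_triton_stub() runs, `import triton.compiler.compiler` resolves straight from sys.modules without consulting any finder.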

Additional Issue

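The _TritonFinder meta-path finder uses the deprecated find_module/load_module protocol. Python 3.12 removed the import system's fallback to find_module on meta-path finders, so a finder that does not implement find_spec is skipped entirely. It should be updated to find_spec (PEP 451).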

Suggested Fix

  1. Register triton.compiler and triton.compiler.compiler as mock sub-modules in _triton_stub.py
  2. Update _TritonFinder to use the modern find_spec API

Environment

  • macOS (Apple Silicon M3 Pro)
  • Python 3.12.11
  • torch 2.9.1
  • torchvision 0.24.0
  • SGLang main branch (commit e8d46f1)
