
Llama 4 Scout FP8 does not work on SGLang #203

@zhyncs

Description

Describe the bug

Quantizing meta-llama/Llama-4-Scout-17B-16E-Instruct to FP8 with ModelOpt's hf_ptq.py and then serving the exported checkpoint with SGLang crashes during model load: FusedMoE.__init__ fails the assertion `assert self.quant_method is not None`, i.e. no quantization method is resolved for the MoE layers of the FP8 checkpoint, and the scheduler exits before serving starts.

Steps/Code to reproduce bug

python3 hf_ptq.py --pyt_ckpt_path meta-llama/Llama-4-Scout-17B-16E-Instruct --qformat fp8 --export_fmt hf --export_path Llama-4-Scout-17B-16E-Instruct-FP8 --trust_remote_code
python3 -m sglang.launch_server --model Llama-4-Scout-17B-16E-Instruct-FP8 --tp 8
[2025-05-27 21:20:09 TP7] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2322, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 280, in __init__
    self.tp_worker = TpWorkerClass(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 64, in __init__
    self.worker = TpModelWorker(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 78, in __init__
    self.model_runner = ModelRunner(
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 233, in __init__
    self.initialize(min_per_gpu_memory)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 274, in initialize
    self.load_model()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 541, in load_model
    self.model = get_model(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/__init__.py", line 22, in get_model
    return loader.load_model(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 376, in load_model
    model = _initialize_model(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 153, in _initialize_model
    return model_class(
  File "/sgl-workspace/sglang/python/sglang/srt/models/llama4.py", line 526, in __init__
    super().__init__(config, quant_config, prefix)
  File "/sgl-workspace/sglang/python/sglang/srt/models/llama.py", line 413, in __init__
    self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
  File "/sgl-workspace/sglang/python/sglang/srt/models/llama4.py", line 537, in _init_model
    return Llama4Model(config, quant_config=quant_config, prefix=prefix)
  File "/sgl-workspace/sglang/python/sglang/srt/models/llama4.py", line 470, in __init__
    self.layers = make_layers(
  File "/sgl-workspace/sglang/python/sglang/srt/utils.py", line 465, in make_layers
    + [
  File "/sgl-workspace/sglang/python/sglang/srt/utils.py", line 466, in <listcomp>
    maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
  File "/sgl-workspace/sglang/python/sglang/srt/models/llama4.py", line 472, in <lambda>
    lambda idx, prefix: Llama4DecoderLayer(
  File "/sgl-workspace/sglang/python/sglang/srt/models/llama4.py", line 372, in __init__
    self.feed_forward = Llama4MoE(
  File "/sgl-workspace/sglang/python/sglang/srt/models/llama4.py", line 107, in __init__
    self.experts = FusedMoE(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/layer.py", line 340, in __init__
    assert self.quant_method is not None
AssertionError
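
For triage: the assertion at sglang/srt/layers/moe/fused_moe_triton/layer.py line 340 means the active quantization config did not resolve a quantization method for the FusedMoE layer. A minimal sketch of that failure mode, using made-up class names rather than SGLang's actual API:

# Minimal sketch of the failure mode, with illustrative names only (these
# are NOT SGLang's real classes). The pattern mirrors the traceback above:
# the quantization config is asked for a per-layer quant method, returns
# None for the fused-MoE layer, and the MoE constructor asserts it got one.

class LinearLayer:
    pass


class FusedMoELayer:
    pass


class Fp8LinearOnlyConfig:
    """Hypothetical config that quantizes linear layers but not fused MoE."""

    def get_quant_method(self, layer):
        if isinstance(layer, LinearLayer):
            return "fp8_linear_method"
        return None  # no handler registered for FusedMoE -> None


config = Fp8LinearOnlyConfig()
quant_method = config.get_quant_method(FusedMoELayer())
assert quant_method is not None  # AssertionError, as in the log above

If useful for debugging, the quantization metadata that hf_ptq.py exports alongside the weights can be inspected directly; the file name below follows ModelOpt's HF export convention (hf_quant_config.json), and the path assumes the --export_path used above:

import json

# Inspect the exported quantization config that SGLang will try to match
# against (adjust the path if your export layout differs).
with open("Llama-4-Scout-17B-16E-Instruct-FP8/hf_quant_config.json") as f:
    print(json.dumps(json.load(f), indent=2))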

Expected behavior

The FP8-quantized checkpoint loads successfully and the SGLang server starts serving.

System information

  • Container used (if applicable): ?
  • OS (e.g., Ubuntu 22.04, CentOS 7, Windows 10): ?
  • CPU architecture (x86_64, aarch64): ?
  • GPU name (e.g. H100, A100, L40S): ?
  • GPU memory size: ?
  • Number of GPUs: ?
  • Library versions (if applicable):
    • Python: ?
    • ModelOpt version or commit hash: ?
    • CUDA: ?
    • PyTorch: ?
    • Transformers: ?
    • TensorRT-LLM: ?
    • ONNXRuntime: ?
    • TensorRT: ?
  • Any other details that may help: ?
Python script to automatically collect system information:
import platform
import re
import subprocess


def get_nvidia_gpu_info():
    try:
        nvidia_smi = (
            subprocess.check_output(
                "nvidia-smi --query-gpu=name,memory.total,count --format=csv,noheader,nounits",
                shell=True,
            )
            .decode("utf-8")
            .strip()
            .split("\n")
        )
        if len(nvidia_smi) > 0:
            gpu_name = nvidia_smi[0].split(",")[0].strip()
            gpu_memory = round(float(nvidia_smi[0].split(",")[1].strip()) / 1024, 1)
            gpu_count = len(nvidia_smi)
            return gpu_name, f"{gpu_memory} GB", gpu_count
    except Exception:
        pass
    return "?", "?", "?"  # fall through on any failure or empty output


def get_cuda_version():
    try:
        nvcc_output = subprocess.check_output("nvcc --version", shell=True).decode("utf-8")
        match = re.search(r"release (\d+\.\d+)", nvcc_output)
        if match:
            return match.group(1)
    except Exception:
        pass
    return "?"  # nvcc missing or version string not found


def get_package_version(package):
    try:
        return getattr(__import__(package), "__version__", "?")
    except Exception:
        return "?"


# Get system info
os_info = f"{platform.system()} {platform.release()}"
if platform.system() == "Linux":
    try:
        os_info = (
            subprocess.check_output("cat /etc/os-release | grep PRETTY_NAME | cut -d= -f2", shell=True)
            .decode("utf-8")
            .strip()
            .strip('"')
        )
    except Exception:
        pass
elif platform.system() == "Windows":
    print("Please add the `windows` label to the issue.")

cpu_arch = platform.machine()
gpu_name, gpu_memory, gpu_count = get_nvidia_gpu_info()
cuda_version = get_cuda_version()

# Print system information in the format required for the issue template
print("=" * 70)
print("- Container used (if applicable): " + "?")
print("- OS (e.g., Ubuntu 22.04, CentOS 7, Windows 10): " + os_info)
print("- CPU architecture (x86_64, aarch64): " + cpu_arch)
print("- GPU name (e.g. H100, A100, L40S): " + gpu_name)
print("- GPU memory size: " + gpu_memory)
print("- Number of GPUs: " + str(gpu_count))
print("- Library versions (if applicable):")
print("  - Python: " + platform.python_version())
print("  - ModelOpt version or commit hash: " + get_package_version("modelopt"))
print("  - CUDA: " + cuda_version)
print("  - PyTorch: " + get_package_version("torch"))
print("  - Transformers: " + get_package_version("transformers"))
print("  - TensorRT-LLM: " + get_package_version("tensorrt_llm"))
print("  - ONNXRuntime: " + get_package_version("onnxruntime"))
print("  - TensorRT: " + get_package_version("tensorrt"))
print("- Any other details that may help: " + "?")
print("=" * 70)

Labels

bug (Something isn't working), roadmap
