Commit 2cc52d8

feat: Support MXFP4 quantized dense models on AMD CDNA2/CDNA3 GPUs (#19143)
1 parent f639425 commit 2cc52d8

8 files changed: 648 additions & 282 deletions
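A sketch of how the new method might be selected once this commit is installed. The offline `Engine` entry point, the `quantization` keyword, and the model path are assumptions for illustration; only the method name `petit_mxfp4` comes from this commit.

```python
# Hypothetical usage sketch, not part of this commit.
# Assumes sglang's offline Engine forwards `quantization` to ServerArgs;
# the model path below is a placeholder, not a checkpoint validated here.
import sglang as sgl

llm = sgl.Engine(
    model_path="path/to/mxfp4-quantized-model",  # placeholder
    quantization="petit_mxfp4",  # method name registered by this commit
)
print(llm.generate("Hello, world!", {"max_new_tokens": 16}))
```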


python/pyproject_other.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -93,7 +93,7 @@ tracing = [
 srt_hip = [
     "sglang[runtime_common]",
     "torch",
-    "petit_kernel==0.0.2",
+    "petit_kernel==0.0.3",
     "wave-lang==3.8.2",
 ]
```
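A minimal sanity check that an environment picked up the bumped kernel pin (a sketch; the version strings are taken from the diff above):

```python
# Verify the installed petit_kernel matches the new pin in pyproject_other.toml.
from importlib.metadata import version

assert version("petit_kernel") == "0.0.3"  # was 0.0.2 before this commit
```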

python/sglang/srt/configs/model_config.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -946,6 +946,7 @@ def _verify_quantization(self) -> None:
             "fbgemm_fp8",
             "w8a8_fp8",
             "petit_nvfp4",
+            "petit_mxfp4",
             "quark",
             "mxfp4",
             "auto-round",
@@ -970,6 +971,7 @@ def _verify_quantization(self) -> None:
             "qoq",
             "w4afp8",
             "petit_nvfp4",
+            "petit_mxfp4",
             "quark",
             "modelslim",
         ]
@@ -978,6 +980,7 @@ def _verify_quantization(self) -> None:
             "modelopt_fp4": ["modelopt"],
             "modelopt_mixed": ["modelopt"],
             "petit_nvfp4": ["modelopt"],
+            "petit_mxfp4": ["mxfp4", "quark"],
             "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
             "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
         }
```
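The third hunk extends an override-compatibility map: a user-requested method may take over a checkpoint whose serialized quantization format is one of the listed base formats. A minimal self-contained sketch of that check (the dict and helper names here are illustrative, not the repo's exact code):

```python
# Sketch of the override-compatibility check the map above feeds into:
# petit_mxfp4 may replace checkpoints serialized as mxfp4 or quark.
OVERRIDES = {
    "petit_nvfp4": ["modelopt"],
    "petit_mxfp4": ["mxfp4", "quark"],
}

def can_override(user_method: str, checkpoint_method: str) -> bool:
    return checkpoint_method in OVERRIDES.get(user_method, [])

assert can_override("petit_mxfp4", "quark")
assert can_override("petit_mxfp4", "mxfp4")
assert not can_override("petit_mxfp4", "modelopt")
```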

python/sglang/srt/layers/quantization/__init__.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -36,7 +36,8 @@ def override_quantization_method(self, *args, **kwargs):
 from sglang.srt.layers.quantization.modelslim.modelslim import ModelSlimConfig
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
-from sglang.srt.layers.quantization.petit import PetitNvFp4Config
+from sglang.srt.layers.quantization.petit_mxfp4 import PetitMxfp4Config
+from sglang.srt.layers.quantization.petit_nvfp4 import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
 from sglang.srt.layers.quantization.quark.quark import QuarkConfig
 from sglang.srt.layers.quantization.quark_int4fp8_moe import QuarkInt4Fp8Config
@@ -72,6 +73,7 @@ def override_quantization_method(self, *args, **kwargs):
     "qoq": QoQConfig,
     "w4afp8": W4AFp8Config,
     "petit_nvfp4": PetitNvFp4Config,
+    "petit_mxfp4": PetitMxfp4Config,
     "fbgemm_fp8": FBGEMMFp8Config,
     "quark": QuarkConfig,
     "auto-round": AutoRoundConfig,
```
python/sglang/srt/layers/quantization/petit.py

Lines changed: 7 additions & 249 deletions

```diff
@@ -1,253 +1,11 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
+"""Backward-compatible import shim.

+Use `petit_nvfp4.py` for implementation. Keep this module for existing imports.
+"""

-import logging
-from typing import Any, Dict, List, Optional
-
-import regex as re
-import torch
-from torch.nn.parameter import Parameter
-
-from sglang.srt.layers.linear import LinearBase
-from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
-from sglang.srt.layers.quantization.base_config import (
-    LinearMethodBase,
-    QuantizationConfig,
-    QuantizeMethodBase,
-)
-from sglang.srt.layers.quantization.petit_utils import (
-    apply_petit_nvfp4_linear,
-    prepare_nvfp4_layer_for_petit,
-    verify_petit_nvfp4_supported,
+from sglang.srt.layers.quantization.petit_nvfp4 import (  # noqa: F401
+    PetitNvFp4Config,
+    PetitNvFp4LinearMethod,
 )
-from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
-from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.utils import is_hip
-
-_is_hip = is_hip()
-
-# Initialize logger for the module
-logger = logging.getLogger(__name__)
-
-
-# Configuration class to support the NVFP4 quantized model generated by the ModelOpt quantization tool
-class PetitNvFp4Config(QuantizationConfig):
-    """Config class for Petit FP4."""
-
-    def __init__(
-        self,
-        is_checkpoint_nvfp4_serialized: bool = False,
-        kv_cache_quant_algo: str = None,
-        group_size: int = None,
-        exclude_modules: List[str] = None,
-    ) -> None:
-        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
-        if is_checkpoint_nvfp4_serialized:
-            logger.warning(
-                "Detected nvfp4 checkpoint. Please note that the "
-                "format is experimental and subject to change."
-            )
-        self.group_size = group_size
-        self.kv_cache_quant_algo = kv_cache_quant_algo
-        self.exclude_modules = exclude_modules
-
-    @classmethod
-    def get_name(cls) -> str:
-        return "petit_nvfp4"
-
-    @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
-        return [torch.bfloat16, torch.half]
-
-    @classmethod
-    def get_min_capability(cls) -> int:
-        # Petit supports the gfx90a and gfx942 GPUs
-        return 90
-
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["hf_quant_config.json"]
-
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "PetitNvFp4Config":
-        quant_config = cls.get_from_keys(config, ["quantization"])
-        quant_method = quant_config["quant_algo"]
-        group_size = quant_config.get("group_size", None)
-        verify_petit_nvfp4_supported(quant_method, group_size)
-
-        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
-        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
-        if not kv_cache_quant_algo:
-            kv_cache_quant_algo = "auto"
-        exclude_modules = quant_config.get("exclude_modules", None)
-        if not (group_size and kv_cache_quant_algo and (exclude_modules is not None)):
-            logger.warning(
-                f"group_size: {group_size},"
-                f"kv_cache_quant_algo: {kv_cache_quant_algo},"
-                f"exclude_modules: {exclude_modules}"
-            )
-            raise ValueError(
-                "NVFP4 quantization requires group size and "
-                "kv_cache_quant_algo specified in "
-                "hf_quant_config.json"
-            )
-        return cls(
-            is_checkpoint_nvfp4_serialized,
-            kv_cache_quant_algo,
-            group_size,
-            exclude_modules,
-        )
-
-    @classmethod
-    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
-        can_convert = cls.is_petit_nvfp4_compatible(hf_quant_cfg)
-        if can_convert:
-            return cls.get_name()
-        return None
-
-    @classmethod
-    def is_petit_nvfp4_compatible(cls, quant_config: Dict[str, Any]) -> bool:
-        quant_method = quant_config.get("quant_method", "").lower()
-        return _is_hip and quant_method == "modelopt"
-
-    def is_layer_excluded(self, prefix: str, exclude_modules: list):
-        for pattern in exclude_modules:
-            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
-            if re.fullmatch(regex_str, prefix):
-                return True
-        return False
-
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
-        if isinstance(layer, LinearBase):
-            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
-                prefix, self.exclude_modules
-            ):
-                return UnquantizedLinearMethod()
-            return PetitNvFp4LinearMethod(self)
-        return None
-
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
-
-class PetitNvFp4LinearMethod(LinearMethodBase):
-    """Linear method for NVFP4.
-    Supports loading NVFP4 checkpoints with the following structure:
-
-    |Tensor Name | datatype | shape |
-    |----------------------------------------------------|
-    |input_scale | torch.float32 | scalar |
-    |weight | NVFP4(SE2M1) | [1, X, y/2] |
-    |weight_scale | FP8-E4M3 | [X, Y] |
-    |weight_scale_2 | torch.float32 | scalar |
-
-    The weights are quantized per block of 16 elements.
-    Args: quant_config: The ModelOpt quantization config.
-    """
-
-    def __init__(self, quant_config: PetitNvFp4Config):
-        self.quant_config = quant_config
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        del input_size, output_size
-        if not self.quant_config.is_checkpoint_nvfp4_serialized:
-            raise ValueError(
-                "NVFP4 quantization was selected, "
-                " dynamic quantization is not supported."
-            )
-
-        output_size_per_partition = sum(output_partition_sizes)
-        weight_loader = extra_weight_attrs.get("weight_loader")
-
-        layer.logical_widths = output_partition_sizes
-
-        layer.input_size_per_partition = input_size_per_partition
-        layer.output_size_per_partition = output_size_per_partition
-        if input_size_per_partition % 16 != 0:
-            raise ValueError(
-                "Unsupported model when in features size is " "not multiple of 16"
-            )
-
-        weight_dtype = (
-            torch.float8_e4m3fn
-            if self.quant_config.is_checkpoint_nvfp4_serialized
-            else params_dtype
-        )
-
-        weight = ModelWeightParameter(
-            data=torch.empty(
-                # 2 fp4 data is packed in one uint8 in the input dimension
-                output_size_per_partition,
-                input_size_per_partition // 2,
-                dtype=torch.uint8,
-            ),
-            input_dim=1,
-            output_dim=0,
-            weight_loader=weight_loader,
-        )
-        layer.register_parameter("weight", weight)
-
-        input_scale = PerTensorScaleParameter(
-            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
-            weight_loader=weight_loader,
-        )
-
-        layer.register_parameter("input_scale", input_scale)
-
-        weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
-            weight_loader=weight_loader,
-        )
-        layer.register_parameter("weight_scale_2", weight_scale_2)
-
-        weight_scale = ModelWeightParameter(
-            data=torch.empty(
-                output_size_per_partition,
-                input_size_per_partition // self.quant_config.group_size,
-                dtype=weight_dtype,
-            ),
-            input_dim=1,
-            output_dim=0,
-            weight_loader=weight_loader,
-        )
-
-        layer.register_parameter("weight_scale", weight_scale)
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        input_scale_2 = layer.input_scale.max().to(torch.float32)
-        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
-        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
-        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
-        layer.alpha = Parameter(
-            layer.input_scale * layer.weight_scale_2, requires_grad=False
-        )
-
-        prepare_nvfp4_layer_for_petit(layer)
-        del layer.input_scale

-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        return apply_petit_nvfp4_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            weight_scale_2=layer.weight_scale_2,
-            size_n=layer.output_size_per_partition,
-            size_k=layer.input_size_per_partition,
-            bias=bias,
-        )
+__all__ = ["PetitNvFp4Config", "PetitNvFp4LinearMethod"]
```
