|
1 | | -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py |
| 1 | +"""Backward-compatible import shim. |
2 | 2 |
|
| 3 | +Use `petit_nvfp4.py` for implementation. Keep this module for existing imports. |
| 4 | +""" |
3 | 5 |
|
4 | | -import logging |
5 | | -from typing import Any, Dict, List, Optional |
6 | | - |
7 | | -import regex as re |
8 | | -import torch |
9 | | -from torch.nn.parameter import Parameter |
10 | | - |
11 | | -from sglang.srt.layers.linear import LinearBase |
12 | | -from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter |
13 | | -from sglang.srt.layers.quantization.base_config import ( |
14 | | - LinearMethodBase, |
15 | | - QuantizationConfig, |
16 | | - QuantizeMethodBase, |
17 | | -) |
18 | | -from sglang.srt.layers.quantization.petit_utils import ( |
19 | | - apply_petit_nvfp4_linear, |
20 | | - prepare_nvfp4_layer_for_petit, |
21 | | - verify_petit_nvfp4_supported, |
| 6 | +from sglang.srt.layers.quantization.petit_nvfp4 import ( # noqa: F401 |
| 7 | + PetitNvFp4Config, |
| 8 | + PetitNvFp4LinearMethod, |
22 | 9 | ) |
23 | | -from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod |
24 | | -from sglang.srt.layers.quantization.utils import is_layer_skipped |
25 | | -from sglang.srt.utils import is_hip |
26 | | - |
_is_hip = is_hip()

# Initialize logger for the module
logger = logging.getLogger(__name__)


# Configuration class to support the NVFP4 quantized model generated by the
# ModelOpt quantization tool.
class PetitNvFp4Config(QuantizationConfig):
    """Config class for Petit FP4.

    Attributes:
        is_checkpoint_nvfp4_serialized: True when the checkpoint was saved in
            the (experimental) NVFP4 serialized format.
        group_size: Number of weight elements sharing one block scale.
        kv_cache_quant_algo: KV-cache quantization algorithm; falls back to
            "auto" when the checkpoint does not specify one.
        exclude_modules: Glob-style patterns of module prefixes that must stay
            unquantized.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool = False,
        kv_cache_quant_algo: Optional[str] = None,
        group_size: Optional[int] = None,
        exclude_modules: Optional[List[str]] = None,
    ) -> None:
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected nvfp4 checkpoint. Please note that the "
                "format is experimental and subject to change."
            )
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> str:
        return "petit_nvfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Petit supports the gfx90a and gfx942 GPUs
        return 90

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "PetitNvFp4Config":
        """Build a config from the parsed `hf_quant_config.json` contents.

        Raises:
            ValueError: when the required NVFP4 fields (group size,
                kv_cache_quant_algo, exclude_modules) are missing.
        """
        quant_config = cls.get_from_keys(config, ["quantization"])
        quant_method = quant_config["quant_algo"]
        group_size = quant_config.get("group_size", None)
        verify_petit_nvfp4_supported(quant_method, group_size)

        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
        if not kv_cache_quant_algo:
            kv_cache_quant_algo = "auto"
        exclude_modules = quant_config.get("exclude_modules", None)
        if not (group_size and kv_cache_quant_algo and (exclude_modules is not None)):
            # Lazy %-args: message is only built if the record is emitted.
            logger.warning(
                "group_size: %s,kv_cache_quant_algo: %s,exclude_modules: %s",
                group_size,
                kv_cache_quant_algo,
                exclude_modules,
            )
            raise ValueError(
                "NVFP4 quantization requires group size and "
                "kv_cache_quant_algo specified in "
                "hf_quant_config.json"
            )
        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            group_size,
            exclude_modules,
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        # Redirect compatible ModelOpt checkpoints to the Petit backend.
        can_convert = cls.is_petit_nvfp4_compatible(hf_quant_cfg)
        if can_convert:
            return cls.get_name()
        return None

    @classmethod
    def is_petit_nvfp4_compatible(cls, quant_config: Dict[str, Any]) -> bool:
        # Petit is a ROCm path: only take over "modelopt" checkpoints on HIP.
        quant_method = quant_config.get("quant_method", "").lower()
        return _is_hip and quant_method == "modelopt"

    def is_layer_excluded(self, prefix: str, exclude_modules: List[str]) -> bool:
        """Return True when `prefix` matches any glob-style exclude pattern."""
        for pattern in exclude_modules:
            # Translate the glob-style pattern ("*" wildcard) to a regex.
            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
            if re.fullmatch(regex_str, prefix):
                return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
                prefix, self.exclude_modules
            ):
                return UnquantizedLinearMethod()
            return PetitNvFp4LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
134 | | - |
class PetitNvFp4LinearMethod(LinearMethodBase):
    """Linear method for NVFP4 checkpoints served through the Petit kernels.

    Supports loading NVFP4 checkpoints with the following structure:

        | Tensor Name    | datatype      | shape       |
        |----------------|---------------|-------------|
        | input_scale    | torch.float32 | scalar      |
        | weight         | NVFP4(SE2M1)  | [1, X, y/2] |
        | weight_scale   | FP8-E4M3      | [X, Y]      |
        | weight_scale_2 | torch.float32 | scalar      |

    The weights are quantized per block of 16 elements.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed-FP4 weight and its scale parameters on `layer`."""
        # Only the per-partition sizes determine the parameter shapes.
        del input_size, output_size

        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        out_features = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is " "not multiple of 16"
            )

        # Serialized NVFP4 checkpoints store block scales as FP8-E4M3.
        if self.quant_config.is_checkpoint_nvfp4_serialized:
            scale_dtype = torch.float8_e4m3fn
        else:
            scale_dtype = params_dtype

        # Two FP4 values are packed into each uint8 along the input dimension.
        packed_weight = ModelWeightParameter(
            data=torch.empty(
                out_features,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight", packed_weight)

        per_tensor_input = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=loader,
        )
        layer.register_parameter("input_scale", per_tensor_input)

        per_tensor_weight2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=loader,
        )
        layer.register_parameter("weight_scale_2", per_tensor_weight2)

        # One block scale per `group_size` input elements.
        block_scales = ModelWeightParameter(
            data=torch.empty(
                out_features,
                input_size_per_partition // self.quant_config.group_size,
                dtype=scale_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight_scale", block_scales)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse per-shard scales, precompute alpha, and repack for Petit."""
        # Reduce the per-shard scale vectors to single scalars.
        collapsed_input = layer.input_scale.max().to(torch.float32)
        collapsed_weight2 = layer.weight_scale_2.max().to(torch.float32)
        layer.input_scale = Parameter(collapsed_input, requires_grad=False)
        layer.weight_scale_2 = Parameter(collapsed_weight2, requires_grad=False)
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        prepare_nvfp4_layer_for_petit(layer)
        # alpha already incorporates input_scale, so drop it after repacking.
        del layer.input_scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the NVFP4 linear kernel on `x` using this layer's weights."""
        return apply_petit_nvfp4_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
| 11 | +__all__ = ["PetitNvFp4Config", "PetitNvFp4LinearMethod"] |
0 commit comments