1010import gc
1111import json
1212import os
13+ import re
1314import shutil
1415
1516import safetensors
3738)
3839
3940
40- def should_quantize (name : str , weight : torch .Tensor ) -> bool :
41+ def should_quantize (
42+ name : str ,
43+ weight : torch .Tensor ,
44+ skip_weight_substrings = SKIP_WEIGHT_SUBSTRINGS ,
45+ ) -> bool :
4146 if not name .endswith (".weight" ):
4247 return False
43- if any (substr in name for substr in SKIP_WEIGHT_SUBSTRINGS ):
48+ if any (substr in name for substr in skip_weight_substrings ):
4449 return False
4550 if weight .dtype not in (torch .float16 , torch .bfloat16 , torch .float32 ):
4651 return False
@@ -94,6 +99,8 @@ def process_file(
9499 filename : str ,
95100 result_collector : ConversionResult ,
96101 device : str ,
102+ num_hidden_layers : int ,
103+ num_layers_at_end_in_bf16 : int ,
97104) -> None :
98105 if not filename .endswith (".safetensors" ):
99106 return
@@ -106,8 +113,23 @@ def process_file(
106113 weights [key ] = f .get_tensor (key )
107114
108115 modules_to_not_convert : list [str ] = []
116+ tail_start_idx = max (0 , num_hidden_layers - num_layers_at_end_in_bf16 )
117+ num_maybe_mtp_layers = 1
118+ dynamic_skip_layer_prefixes : set [str ] = {
119+ f"model.layers.{ i } ." for i in range (tail_start_idx , num_hidden_layers + num_maybe_mtp_layers )
120+ }
121+
122+ dynamic_skip_substrings = (
123+ * SKIP_WEIGHT_SUBSTRINGS ,
124+ * sorted (dynamic_skip_layer_prefixes ),
125+ )
126+
109127 for key , tensor in weights .items ():
110- if should_quantize (key , tensor ):
128+ if should_quantize (
129+ key ,
130+ tensor ,
131+ skip_weight_substrings = dynamic_skip_substrings ,
132+ ):
111133 qweight , scale = quantize_mxfp8 (tensor )
112134 q_weights [key ] = qweight
113135 q_weights [key .replace (".weight" , ".weight_scale_inv" )] = scale
@@ -120,10 +142,19 @@ def process_file(
120142 result_collector .add_result (filename , q_weights , modules_to_not_convert )
121143
122144
123- def convert_mxfp8 (model_dir : str , save_dir : str , device : str ) -> None :
145+ def convert_mxfp8 (
146+ model_dir : str ,
147+ save_dir : str ,
148+ device : str ,
149+ num_layers_at_end_in_bf16 : int = 0 ,
150+ ) -> None :
124151 input_path = os .path .abspath (model_dir )
125152 output_path = os .path .abspath (save_dir )
126153 os .makedirs (output_path , exist_ok = True )
154+ config_path = os .path .join (input_path , "config.json" )
155+ with open (config_path ) as f :
156+ cfg = json .load (f )
157+ num_hidden_layers = int (cfg ["num_hidden_layers" ])
127158
128159 for filename in os .listdir (input_path ):
129160 if not filename .endswith (".safetensors" ) and not os .path .isdir (os .path .join (input_path , filename )):
@@ -133,7 +164,15 @@ def convert_mxfp8(model_dir: str, save_dir: str, device: str) -> None:
133164
134165 result_collector = ConversionResult ()
135166 for filename in tqdm (safetensors_files , desc = "Processing files" ):
136- process_file (input_path , output_path , filename , result_collector , device )
167+ process_file (
168+ input_path ,
169+ output_path ,
170+ filename ,
171+ result_collector ,
172+ device ,
173+ num_hidden_layers ,
174+ num_layers_at_end_in_bf16 ,
175+ )
137176 gc .collect ()
138177 if torch .cuda .is_available ():
139178 torch .cuda .empty_cache ()
@@ -146,7 +185,13 @@ def convert_mxfp8(model_dir: str, save_dir: str, device: str) -> None:
146185 "scale_fmt" : "ue8m0" ,
147186 }
148187 if len (result_collector .modules_to_not_convert ) > 0 :
149- quantization_config ["modules_to_not_convert" ] = list (set (result_collector .modules_to_not_convert ))
188+
def natural_key(s):
    """Return a sort key that orders embedded integers numerically.

    The string is partitioned into alternating digit and non-digit runs;
    digit runs are converted to ``int`` so e.g. ``"layers.2."`` sorts
    before ``"layers.10."`` (plain lexicographic order would not).
    """
    # re.split with a capturing group keeps the digit runs in the result;
    # the filter drops the empty strings produced at digit boundaries.
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", s) if tok]
191+
192+ quantization_config ["modules_to_not_convert" ] = sorted (
193+ list (set (result_collector .modules_to_not_convert )), key = natural_key
194+ )
150195
151196 config_path = os .path .join (input_path , "config.json" )
152197 if os .path .exists (config_path ):
@@ -175,6 +220,12 @@ def main() -> None:
175220 default = "cuda" ,
176221 help = "Torch device to run quantization on (default: cuda)." ,
177222 )
223+ parser .add_argument (
224+ "--num-layers-at-end-in-bf16" ,
225+ type = int ,
226+ default = 0 ,
227+ help = "Keep last N decoder layers in BF16 and do not quantize them." ,
228+ )
178229 args = parser .parse_args ()
179230
180231 if not torch .cuda .is_available ():
@@ -198,7 +249,12 @@ def main() -> None:
198249 elif not os .path .isdir (args .save_dir ):
199250 raise ValueError ("The save_dir should be a directory." )
200251
201- convert_mxfp8 (args .model_dir , args .save_dir , str (device ))
252+ convert_mxfp8 (
253+ args .model_dir ,
254+ args .save_dir ,
255+ str (device ),
256+ num_layers_at_end_in_bf16 = args .num_layers_at_end_in_bf16 ,
257+ )
202258
203259
204260if __name__ == "__main__" :
0 commit comments