Commit 83ba755

Add initial num_layers_at_end_in_bf16

1 parent 3ac8a67 commit 83ba755

3 files changed: 85 additions & 8 deletions

miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_mxfp8.py

Lines changed: 10 additions & 0 deletions
@@ -20,6 +20,16 @@ def quantize_params_mxfp8(args, megatron_name, converted_named_params, quantizat
     else:
         layer_idx, rest = match.groups()

+        # Skip quantization for BF16 tail of main decoder layers.
+        if getattr(args, "first_last_layers_bf16", False):
+            assert hasattr(args, "num_layers"), "Expected args.num_layers to be set for BF16 tail."
+            num_layers = int(args.num_layers)
+            assert num_layers > 0, "Expected args.num_layers > 0 for BF16 tail."
+            num_layers_at_end_in_bf16 = int(getattr(args, "num_layers_at_end_in_bf16", 0))
+            tail_start_idx = max(0, num_layers - num_layers_at_end_in_bf16)
+            if int(layer_idx) >= tail_start_idx:
+                return converted_named_params
+
     # experts
     expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)"
     match = re.match(expert_pattern, rest)
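
The added guard reduces to simple index arithmetic: layers with index >= num_layers - num_layers_at_end_in_bf16 are returned unquantized. A minimal standalone sketch of the decision (the SimpleNamespace args here is a hypothetical stand-in for Megatron's parsed arguments):

from types import SimpleNamespace

def in_bf16_tail(layer_idx: int, args) -> bool:
    # Mirrors the added check: the last num_layers_at_end_in_bf16
    # decoder layers stay in BF16 and skip MXFP8 quantization.
    if not getattr(args, "first_last_layers_bf16", False):
        return False
    num_layers = int(args.num_layers)
    num_layers_at_end_in_bf16 = int(getattr(args, "num_layers_at_end_in_bf16", 0))
    tail_start_idx = max(0, num_layers - num_layers_at_end_in_bf16)
    return layer_idx >= tail_start_idx

args = SimpleNamespace(first_last_layers_bf16=True, num_layers=48, num_layers_at_end_in_bf16=2)
assert not in_bf16_tail(45, args)  # still quantized
assert in_bf16_tail(46, args)      # BF16 tail
assert in_bf16_tail(47, args)      # BF16 tail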

scripts/run_qwen3_30b_a3b.py

Lines changed: 12 additions & 1 deletion
@@ -22,6 +22,7 @@ class ScriptArgs(U.ExecuteTrainConfig):
     train_fp8: bool = False
     enable_megatron_bridge: bool = False
     enable_mis: bool = False
+    num_layers_at_end_in_bf16: int = 0
     # TODO improve, should be able to override more easily
     tis_use_rs: bool = True


@@ -46,7 +47,10 @@ def prepare(args: ScriptArgs):
     mxfp8_path = f"/root/models/{args.model_name}-MXFP8"
     if not os.path.isdir(mxfp8_path):
         U.exec_command(
-            f"python tools/convert_hf_to_mxfp8.py --model-dir /root/models/{args.model_name} --save-dir {mxfp8_path}"
+            "python tools/convert_hf_to_mxfp8.py "
+            f"--model-dir /root/models/{args.model_name} "
+            f"--save-dir {mxfp8_path} "
+            f"--num-layers-at-end-in-bf16 {args.num_layers_at_end_in_bf16}"
         )

     if not args.enable_megatron_bridge:
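
For reference, Python's adjacent string literals concatenate, so with a hypothetical model_name of Qwen3-30B-A3B and num_layers_at_end_in_bf16 = 2 the string passed to U.exec_command assembles to a single command (the trailing space on each fragment keeps the flags separated):

python tools/convert_hf_to_mxfp8.py --model-dir /root/models/Qwen3-30B-A3B --save-dir /root/models/Qwen3-30B-A3B-MXFP8 --num-layers-at-end-in-bf16 2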
@@ -184,6 +188,13 @@ def execute(args: ScriptArgs):
         "NVTE_FP8_BLOCK_SCALING_FP32_SCALES": "1",
     }

+    if args.train_fp8 and args.num_layers_at_end_in_bf16 > 0:
+        misc_args += (
+            "--first-last-layers-bf16 "
+            "--num-layers-at-start-in-bf16 0 "
+            f"--num-layers-at-end-in-bf16 {args.num_layers_at_end_in_bf16} "
+        )
+
     if args.enable_megatron_bridge:
         misc_args += "--megatron-to-hf-mode bridge "
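
The Megatron-style --first-last-layers-bf16 flag enables the BF16 first/last-layer path; the start count is pinned to 0 so that only the tail is affected, matching what the converter skips. With num_layers_at_end_in_bf16 = 2, the block above appends:

--first-last-layers-bf16 --num-layers-at-start-in-bf16 0 --num-layers-at-end-in-bf16 2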

tools/convert_hf_to_mxfp8.py

Lines changed: 63 additions & 7 deletions
@@ -10,6 +10,7 @@
 import gc
 import json
 import os
+import re
 import shutil

 import safetensors
@@ -37,10 +38,14 @@
 )


-def should_quantize(name: str, weight: torch.Tensor) -> bool:
+def should_quantize(
+    name: str,
+    weight: torch.Tensor,
+    skip_weight_substrings=SKIP_WEIGHT_SUBSTRINGS,
+) -> bool:
     if not name.endswith(".weight"):
         return False
-    if any(substr in name for substr in SKIP_WEIGHT_SUBSTRINGS):
+    if any(substr in name for substr in skip_weight_substrings):
         return False
     if weight.dtype not in (torch.float16, torch.bfloat16, torch.float32):
         return False
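
A quick sketch of the refactored predicate in isolation. SKIP_WEIGHT_SUBSTRINGS below is a hypothetical stand-in for the tuple defined earlier in the module, and the trailing return True stands in for whatever checks follow the excerpt:

import torch

SKIP_WEIGHT_SUBSTRINGS = ("embed_tokens", "lm_head", "norm")  # illustrative only

def should_quantize(name, weight, skip_weight_substrings=SKIP_WEIGHT_SUBSTRINGS):
    if not name.endswith(".weight"):
        return False
    if any(substr in name for substr in skip_weight_substrings):
        return False
    if weight.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        return False
    return True  # the real function may apply further checks

w = torch.zeros(4, 4, dtype=torch.bfloat16)
assert should_quantize("model.layers.0.mlp.gate_proj.weight", w)
assert not should_quantize("model.embed_tokens.weight", w)          # skipped by substring
assert not should_quantize("model.layers.0.mlp.gate_proj.bias", w)  # not a .weight

Because the default argument still points at SKIP_WEIGHT_SUBSTRINGS, existing call sites keep their old behavior; only process_file passes the extended tuple.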
@@ -94,6 +99,8 @@ def process_file(
     filename: str,
     result_collector: ConversionResult,
     device: str,
+    num_hidden_layers: int,
+    num_layers_at_end_in_bf16: int,
 ) -> None:
     if not filename.endswith(".safetensors"):
         return
@@ -106,8 +113,23 @@ def process_file(
             weights[key] = f.get_tensor(key)

     modules_to_not_convert: list[str] = []
+    tail_start_idx = max(0, num_hidden_layers - num_layers_at_end_in_bf16)
+    num_maybe_mtp_layers = 1
+    dynamic_skip_layer_prefixes: set[str] = {
+        f"model.layers.{i}." for i in range(tail_start_idx, num_hidden_layers + num_maybe_mtp_layers)
+    }
+
+    dynamic_skip_substrings = (
+        *SKIP_WEIGHT_SUBSTRINGS,
+        *sorted(dynamic_skip_layer_prefixes),
+    )
+
     for key, tensor in weights.items():
-        if should_quantize(key, tensor):
+        if should_quantize(
+            key,
+            tensor,
+            skip_weight_substrings=dynamic_skip_substrings,
+        ):
             qweight, scale = quantize_mxfp8(tensor)
             q_weights[key] = qweight
             q_weights[key.replace(".weight", ".weight_scale_inv")] = scale
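
The prefix set covers the BF16 tail plus one extra index (num_maybe_mtp_layers) for a possible multi-token-prediction layer stored past num_hidden_layers. A worked example with assumed sizes:

num_hidden_layers = 48          # assumed model depth
num_layers_at_end_in_bf16 = 2   # keep the last two decoder layers in BF16
num_maybe_mtp_layers = 1        # reserve one index past the decoder stack

tail_start_idx = max(0, num_hidden_layers - num_layers_at_end_in_bf16)
prefixes = {
    f"model.layers.{i}." for i in range(tail_start_idx, num_hidden_layers + num_maybe_mtp_layers)
}
print(sorted(prefixes))  # ['model.layers.46.', 'model.layers.47.', 'model.layers.48.']

The trailing dot in each prefix matters: "model.layers.4." matches only layer 4 as a substring, not layers 40 through 49.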
@@ -120,10 +142,19 @@
     result_collector.add_result(filename, q_weights, modules_to_not_convert)


-def convert_mxfp8(model_dir: str, save_dir: str, device: str) -> None:
+def convert_mxfp8(
+    model_dir: str,
+    save_dir: str,
+    device: str,
+    num_layers_at_end_in_bf16: int = 0,
+) -> None:
     input_path = os.path.abspath(model_dir)
     output_path = os.path.abspath(save_dir)
     os.makedirs(output_path, exist_ok=True)
+    config_path = os.path.join(input_path, "config.json")
+    with open(config_path) as f:
+        cfg = json.load(f)
+    num_hidden_layers = int(cfg["num_hidden_layers"])

     for filename in os.listdir(input_path):
         if not filename.endswith(".safetensors") and not os.path.isdir(os.path.join(input_path, filename)):
@@ -133,7 +164,15 @@ def convert_mxfp8(model_dir: str, save_dir: str, device: str) -> None:

     result_collector = ConversionResult()
     for filename in tqdm(safetensors_files, desc="Processing files"):
-        process_file(input_path, output_path, filename, result_collector, device)
+        process_file(
+            input_path,
+            output_path,
+            filename,
+            result_collector,
+            device,
+            num_hidden_layers,
+            num_layers_at_end_in_bf16,
+        )
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -146,7 +185,13 @@ def convert_mxfp8(model_dir: str, save_dir: str, device: str) -> None:
         "scale_fmt": "ue8m0",
     }
     if len(result_collector.modules_to_not_convert) > 0:
-        quantization_config["modules_to_not_convert"] = list(set(result_collector.modules_to_not_convert))
+
+        def natural_key(s):
+            return [int(t) if t.isdigit() else t for t in re.findall(r"\d+|\D+", s)]
+
+        quantization_config["modules_to_not_convert"] = sorted(
+            list(set(result_collector.modules_to_not_convert)), key=natural_key
+        )

     config_path = os.path.join(input_path, "config.json")
     if os.path.exists(config_path):
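
natural_key splits a string into digit and non-digit runs so numeric segments compare as integers; a plain lexicographic sort would put model.layers.10 before model.layers.2. A quick demonstration with hypothetical module names:

import re

def natural_key(s):
    return [int(t) if t.isdigit() else t for t in re.findall(r"\d+|\D+", s)]

names = ["model.layers.10.mlp", "model.layers.2.mlp", "model.layers.1.mlp"]
print(sorted(names))
# ['model.layers.1.mlp', 'model.layers.10.mlp', 'model.layers.2.mlp']
print(sorted(names, key=natural_key))
# ['model.layers.1.mlp', 'model.layers.2.mlp', 'model.layers.10.mlp']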
@@ -175,6 +220,12 @@ def main() -> None:
         default="cuda",
         help="Torch device to run quantization on (default: cuda).",
     )
+    parser.add_argument(
+        "--num-layers-at-end-in-bf16",
+        type=int,
+        default=0,
+        help="Keep last N decoder layers in BF16 and do not quantize them.",
+    )
     args = parser.parse_args()

     if not torch.cuda.is_available():
@@ -198,7 +249,12 @@
     elif not os.path.isdir(args.save_dir):
         raise ValueError("The save_dir should be a directory.")

-    convert_mxfp8(args.model_dir, args.save_dir, str(device))
+    convert_mxfp8(
+        args.model_dir,
+        args.save_dir,
+        str(device),
+        num_layers_at_end_in_bf16=args.num_layers_at_end_in_bf16,
+    )


 if __name__ == "__main__":
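
With the new keyword argument, the conversion is also callable from Python rather than via the CLI; a minimal sketch, assuming the repo root is on sys.path and using hypothetical paths:

from tools.convert_hf_to_mxfp8 import convert_mxfp8

convert_mxfp8(
    "/root/models/Qwen3-30B-A3B",        # hypothetical source HF checkpoint
    "/root/models/Qwen3-30B-A3B-MXFP8",  # hypothetical output directory
    "cuda",
    num_layers_at_end_in_bf16=2,         # keep the last two decoder layers in BF16
)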
