Implement nvfp4 #546
Conversation
Summary of Changes
Hello @zianglih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly expands the model quantization capabilities by integrating NVIDIA's NVFP4 and MXFP8 formats. It provides the necessary processing logic, conversion tools for Hugging Face models, and updates to the training and inference pipelines to leverage these new, more efficient data types, particularly for MoE architectures and Blackwell hardware. The changes aim to improve performance and memory footprint for large language models.
Highlights
Code Review
This pull request introduces support for NVFP4 and MXFP8 quantization, adding new quantization processors and conversion scripts and updating the run scripts. However, it also introduces a critical command injection vulnerability in several scripts used for training and conversion. These scripts interpolate user-controlled input, specifically model_name, directly into shell command strings without proper sanitization, which could lead to arbitrary code execution. Beyond this security concern, there is significant code duplication of the NVFP4 quantization logic between the processor module and the conversion script, which should be refactored. Opportunities also exist to improve maintainability by refactoring repeated logic in the MXFP8 quantizer and the run scripts, and parts of the NVFP4 implementation could be optimized for better performance.
```python
for converted_name, param in converted_named_params:
    base, role = _split_gated_pair_name(converted_name)
    if base is None or role is None:
        continue
    if _should_quantize_param(converted_name, param, group_size):
        gated_candidates.setdefault(base, {})[role] = param

for base, roles in gated_candidates.items():
    if "gate" in roles and "up" in roles and _should_share_gated_pair_amax(args, base):
        gate_amax = roles["gate"].abs().max().to(torch.float32)
        up_amax = roles["up"].abs().max().to(torch.float32)
        shared_global_amax[base] = torch.max(gate_amax, up_amax)

quantize_named_params = []
for converted_name, param in converted_named_params:
    if not _should_quantize_param(converted_name, param, group_size):
        quantize_named_params.append((converted_name, param))
        continue
    base, _role = _split_gated_pair_name(converted_name)
    global_amax = shared_global_amax.get(base) if base else None
    qweight, block_scale, weight_scale_2 = quantize_nvfp4(param, global_amax=global_amax, group_size=group_size)
    quantize_named_params.append((converted_name, qweight))
    quantize_named_params.append((converted_name.replace(".weight", ".weight_scale"), block_scale))
    quantize_named_params.append((converted_name.replace(".weight", ".weight_scale_2"), weight_scale_2))
    quantize_named_params.append(
        (converted_name.replace(".weight", ".input_scale"), torch.ones_like(weight_scale_2, dtype=torch.float32))
    )

return quantize_named_params


def _should_quantize_param(name, weight, group_size):
    if not name.endswith(".weight"):
        return False
    if weight.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        return False
    if weight.dim() < 2:
        return False
    if weight.shape[-1] % group_size != 0:
        raise ValueError(f"Last dim {weight.shape[-1]} must be divisible by {group_size} for NVFP4 ({name}).")
    return True


def _split_gated_pair_name(name: str):
    for suffix, role in GATED_PAIR_SUFFIXES.items():
        if name.endswith(suffix):
            return name[: -len(suffix)], role
    return None, None


def cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
    """Quantize a tensor to FP4 E2M1 and pack two values per byte."""
    result = torch.zeros_like(x, dtype=torch.uint8)
    result[(x >= 0.0) & (x <= 0.25)] = 0
    result[(x > 0.25) & (x < 0.75)] = 1
    result[(x >= 0.75) & (x <= 1.25)] = 2
    result[(x > 1.25) & (x < 1.75)] = 3
    result[(x >= 1.75) & (x <= 2.5)] = 4
    result[(x > 2.5) & (x < 3.5)] = 5
    result[(x >= 3.5) & (x <= 5.0)] = 6
    result[x > 5.0] = 7

    result[(x >= -0.25) & (x < -0.0)] = 8
    result[(x < -0.25) & (x > -0.75)] = 9
    result[(x <= -0.75) & (x >= -1.25)] = 10
    result[(x < -1.25) & (x > -1.75)] = 11
    result[(x <= -1.75) & (x >= -2.5)] = 12
    result[(x < -2.5) & (x > -3.5)] = 13
    result[(x <= -3.5) & (x >= -5.0)] = 14
    result[x < -5.0] = 15

    return result[:, ::2] + result[:, 1::2] * 16


def _quantize_nvfp4_1d(
    weight: torch.Tensor,
    global_amax: torch.Tensor | None = None,
    group_size: int = NVFP4_GROUP_SIZE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    NVFP4 1D quantization (tile shape = 1x16), adapted from
    TransformerEngine NVFP4QuantizerRef._quantize_blockwise_reference.

    Returns:
        qweight: uint8 packed fp4, shape (M, K // 2)
        block_scale: float8_e4m3fn, shape (M, K // group_size)
        global_scale: float32 scalar tensor
    """
    weight = weight.contiguous()
    m, n = weight.shape
    if n % group_size != 0:
        raise ValueError(f"NVFP4 requires K divisible by {group_size}, got {n}.")

    weight_f = weight.to(torch.float32)
    if global_amax is None:
        global_amax = torch.max(torch.abs(weight_f))
    else:
        global_amax = global_amax.to(device=weight.device, dtype=torch.float32)
    if global_amax.item() == 0.0:
        qweight = torch.zeros((m, n // 2), dtype=torch.uint8, device=weight.device)
        block_scale = torch.zeros(
            (m, n // group_size),
            dtype=torch.float8_e4m3fn,
            device=weight.device,
        )
        global_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32)
        return qweight, block_scale, global_scale
```
The functions cast_to_fp4x2, _quantize_nvfp4_1d, and quantize_nvfp4 are identical to those in the new tools/convert_hf_to_nvfp4.py script. This significant code duplication will make future maintenance difficult and error-prone, as any change would need to be manually synchronized between both files.
Please extract this shared quantization logic into a common utility module (e.g., within miles/utils/) and import it in both this file and the conversion script.
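For instance, a minimal sketch of the shared module; the path miles/utils/nvfp4_quant.py and the module layout are assumptions, and the function bodies would move over verbatim from either copy:

```python
# Hypothetical shared module, e.g. miles/utils/nvfp4_quant.py (path assumed).
import torch

NVFP4_GROUP_SIZE = 16
FP4_E2M1_MAX = 6.0       # largest FP4 E2M1 magnitude
FP8_E4M3_MAX = 448.0     # largest float8_e4m3fn magnitude


def cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
    ...  # body moved unchanged from the PR


def quantize_nvfp4(
    weight: torch.Tensor,
    global_amax: torch.Tensor | None = None,
    group_size: int = NVFP4_GROUP_SIZE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    ...  # body moved unchanged from the PR
```

Both quantizer_nvfp4.py and tools/convert_hf_to_nvfp4.py would then share a single import, e.g. from miles.utils.nvfp4_quant import cast_to_fp4x2, quantize_nvfp4.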
```python
def cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
    """Quantize a tensor to FP4 E2M1 and pack two values per byte."""
    result = torch.zeros_like(x, dtype=torch.uint8)
    result[(x >= 0.0) & (x <= 0.25)] = 0
    result[(x > 0.25) & (x < 0.75)] = 1
    result[(x >= 0.75) & (x <= 1.25)] = 2
    result[(x > 1.25) & (x < 1.75)] = 3
    result[(x >= 1.75) & (x <= 2.5)] = 4
    result[(x > 2.5) & (x < 3.5)] = 5
    result[(x >= 3.5) & (x <= 5.0)] = 6
    result[x > 5.0] = 7

    result[(x >= -0.25) & (x < -0.0)] = 8
    result[(x < -0.25) & (x > -0.75)] = 9
    result[(x <= -0.75) & (x >= -1.25)] = 10
    result[(x < -1.25) & (x > -1.75)] = 11
    result[(x <= -1.75) & (x >= -2.5)] = 12
    result[(x < -2.5) & (x > -3.5)] = 13
    result[(x <= -3.5) & (x >= -5.0)] = 14
    result[x < -5.0] = 15

    return result[:, ::2] + result[:, 1::2] * 16


def _quantize_nvfp4_1d(
    weight: torch.Tensor,
    global_amax: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    NVFP4 1D quantization (tile shape = 1x16), adapted from
    TransformerEngine NVFP4QuantizerRef._quantize_blockwise_reference.

    Returns:
        qweight: uint8 packed fp4, shape (M, K // 2)
        block_scale: float8_e4m3fn, shape (M, K // 16)
        global_scale: float32 scalar tensor
    """
    weight = weight.contiguous()
    m, n = weight.shape
    if n % NVFP4_GROUP_SIZE != 0:
        raise ValueError(f"NVFP4 requires K divisible by {NVFP4_GROUP_SIZE}, got {n}.")

    weight_f = weight.to(torch.float32)
    if global_amax is None:
        global_amax = torch.max(torch.abs(weight_f))
    else:
        global_amax = global_amax.to(device=weight.device, dtype=torch.float32)
    if global_amax.item() == 0.0:
        qweight = torch.zeros((m, n // 2), dtype=torch.uint8, device=weight.device)
        block_scale = torch.zeros(
            (m, n // NVFP4_GROUP_SIZE),
            dtype=torch.float8_e4m3fn,
            device=weight.device,
        )
        global_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32)
        return qweight, block_scale, global_scale

    fp4_max = torch.tensor(FP4_E2M1_MAX, device=weight.device, dtype=torch.float32)
    fp8_max = torch.tensor(FP8_E4M3_MAX, device=weight.device, dtype=torch.float32)

    global_encode_scale = torch.div(fp8_max * fp4_max, global_amax)
    global_encode_scale = torch.min(
        global_encode_scale,
        torch.tensor(torch.finfo(torch.float32).max, device=weight.device, dtype=torch.float32),
    )
    if global_encode_scale.item() == 0.0:
        global_encode_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32)
    global_decode_scale = torch.div(1.0, global_encode_scale)

    weight_blocks = weight_f.view(m, n // NVFP4_GROUP_SIZE, NVFP4_GROUP_SIZE)
    vec_max = torch.amax(torch.abs(weight_blocks), dim=-1, keepdim=True)
    decode_scale = torch.div(vec_max, fp4_max) * global_encode_scale
    decode_scale = torch.clamp(decode_scale, min=-fp8_max, max=fp8_max).to(torch.float8_e4m3fn)

    encode_scale = torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale)
    scaled = weight_blocks * encode_scale
    clipped = torch.clamp(scaled, -fp4_max, fp4_max).reshape(m, n)

    qweight = cast_to_fp4x2(clipped)
    block_scale = decode_scale.squeeze(-1)
    return qweight, block_scale, global_decode_scale


def quantize_nvfp4(
    weight: torch.Tensor,
    global_amax: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if weight.dim() == 2:
        return _quantize_nvfp4_1d(weight, global_amax=global_amax)
    if weight.dim() == 3:
        if global_amax is not None:
            raise ValueError("global_amax override is only supported for 2D weights.")
        qweights = []
        block_scales = []
        global_scales = []
        for idx in range(weight.shape[0]):
            qweight, block_scale, global_scale = _quantize_nvfp4_1d(weight[idx])
            qweights.append(qweight)
            block_scales.append(block_scale)
            global_scales.append(global_scale)
        return (
            torch.stack(qweights, dim=0),
            torch.stack(block_scales, dim=0),
            torch.stack(global_scales, dim=0),
        )
    raise ValueError(f"Unsupported weight rank {weight.dim()} for NVFP4 quantization.")
```
The functions cast_to_fp4x2, _quantize_nvfp4_1d, and quantize_nvfp4 are identical to those in miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_nvfp4.py. This significant code duplication is problematic for maintenance.
Please extract this shared quantization logic into a common utility module to ensure consistency and simplify future updates.
```python
ckpt_args = (
    f"--hf-checkpoint /root/models/{args.model_name}{'-FP8' if args.rollout_fp8 else ''} "
    f"--load {load_save_path} "
    f"--hf-checkpoint {hf_checkpoint}/ "
```
```python
U.exec_command(
    f"python tools/convert_hf_to_mxfp8.py --model-dir /root/models/{args.model_name} --save-dir {mxfp8_path}"
)

if use_nvfp4:
    nvfp4_path = f"/root/models/{args.model_name}-NVFP4"
    if not os.path.isdir(nvfp4_path):
        U.exec_command(
            f"python tools/convert_hf_to_nvfp4.py --model-dir /root/models/{args.model_name} --save-dir {nvfp4_path}"
        )
```
The script is vulnerable to command injection because it constructs shell commands using the unsanitized user-controlled input args.model_name. An attacker could provide a malicious model name containing shell metacharacters (e.g., ; rm -rf /) to execute arbitrary commands in the environment where this script is run. This input flows into U.exec_command on lines 56 and 63, and into the training command via hf_checkpoint on line 95.
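A minimal hardening sketch under the structure shown in the diff; checked_model_name and convert_to_nvfp4 are hypothetical helpers, and subprocess.run is shown as a generic stand-in for U.exec_command:

```python
import re
import shlex
import subprocess

_MODEL_NAME_RE = re.compile(r"^[A-Za-z0-9._/-]+$")


def checked_model_name(name: str) -> str:
    """Reject names with shell metacharacters before any interpolation."""
    if not _MODEL_NAME_RE.fullmatch(name):
        raise ValueError(f"Unsafe model name: {name!r}")
    return name


def convert_to_nvfp4(model_name: str, save_dir: str) -> None:
    model_name = checked_model_name(model_name)
    # Preferred fix: pass an argument vector so no shell parsing happens.
    subprocess.run(
        [
            "python", "tools/convert_hf_to_nvfp4.py",
            "--model-dir", f"/root/models/{model_name}",
            "--save-dir", save_dir,
        ],
        check=True,
    )

# If the shell-string API must be kept, quote each interpolated value:
# U.exec_command(f"python tools/convert_hf_to_nvfp4.py "
#                f"--model-dir {shlex.quote('/root/models/' + model_name)} "
#                f"--save-dir {shlex.quote(save_dir)}")
```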
```python
U.exec_command(
    f"python tools/convert_hf_to_mxfp8.py --model-dir /root/models/{args.model_name} --save-dir {mxfp8_path}"
)
```
```python
# experts
expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)"
match = re.match(expert_pattern, rest)
if match:
    rest, expert_idx = match.groups()
    if rest in [
        "linear_fc1",
        "linear_fc2",
    ]:
        return _quantize_moe_params(args, converted_named_params, group_size)

# shared expert
shared_expert_pattern = r"mlp.shared_experts\.(.+)"
match = re.match(shared_expert_pattern, rest)
if match:
    rest = match.groups()[0]
    if rest in [
        "linear_fc1.weight",
        "linear_fc2.weight",
```
The logic to identify which parameters to quantize is spread across multiple conditional blocks based on regex matching. This makes the control flow complex and hard to follow.
To improve readability and maintainability, consider consolidating the decision logic. You could define the patterns for quantizable layers and then have a single block that performs the quantization if a match is found. This would also make it easier to add new quantizable layers in the future.
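A sketch of that consolidation, reusing rest, group_size, and _quantize_moe_params from the diff; the pattern table, dispatch function, and the _quantize_dense_params handler are hypothetical names:

```python
import re

# Hypothetical pattern table: compiled regex -> handler for that layer family.
_QUANTIZABLE_PATTERNS = [
    # MoE expert weights, e.g. "mlp.experts.linear_fc1.weight0"
    (re.compile(r"mlp\.experts\.(linear_fc1|linear_fc2)\.weight\d+"),
     lambda args, params, gs: _quantize_moe_params(args, params, gs)),
    # Shared-expert weights, e.g. "mlp.shared_experts.linear_fc1.weight"
    (re.compile(r"mlp\.shared_experts\.(linear_fc1|linear_fc2)\.weight"),
     lambda args, params, gs: _quantize_dense_params(args, params, gs)),
]


def dispatch_quantization(rest, args, converted_named_params, group_size):
    # Single decision point: first matching pattern wins.
    for pattern, handler in _QUANTIZABLE_PATTERNS:
        if pattern.fullmatch(rest):
            return handler(args, converted_named_params, group_size)
    return converted_named_params  # not a quantizable layer
```

Adding a new quantizable layer then becomes a one-line table entry instead of another regex-and-if block.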
```python
use_blackwell_fp8 = args.hardware in ("GB200", "GB300") and (args.rollout_fp8 or args.train_fp8)
use_nvfp4 = args.rollout_nvfp4
if use_nvfp4:
    hf_checkpoint = f"/root/models/{args.model_name}-NVFP4"
elif use_blackwell_fp8:
    hf_checkpoint = f"/root/models/{args.model_name}-MXFP8"
elif args.rollout_fp8:
    hf_checkpoint = f"/root/models/{args.model_name}-FP8"
else:
    hf_checkpoint = f"/root/models/{args.model_name}"
```
The logic to determine the Hugging Face checkpoint path based on quantization flags is verbose. The helper variables use_blackwell_fp8 and use_nvfp4 are also recalculated in both the prepare and execute functions.
To improve clarity and reduce repetition, you could move the calculation of these flags into the __post_init__ method of ScriptArgs and add a helper method to the class to resolve the checkpoint path.
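A sketch of that refactor; the flag and path conventions come from the diff, while hf_checkpoint_path and the exact ScriptArgs field layout are assumptions:

```python
from dataclasses import dataclass, field


@dataclass
class ScriptArgs:
    model_name: str
    hardware: str = "GB200"
    rollout_fp8: bool = False
    train_fp8: bool = False
    rollout_nvfp4: bool = False
    use_blackwell_fp8: bool = field(init=False)

    def __post_init__(self):
        # Computed once here instead of in both prepare() and execute().
        self.use_blackwell_fp8 = self.hardware in ("GB200", "GB300") and (
            self.rollout_fp8 or self.train_fp8
        )

    def hf_checkpoint_path(self) -> str:
        """Resolve the checkpoint directory from the quantization flags."""
        if self.rollout_nvfp4:
            suffix = "-NVFP4"
        elif self.use_blackwell_fp8:
            suffix = "-MXFP8"
        elif self.rollout_fp8:
            suffix = "-FP8"
        else:
            suffix = ""
        return f"/root/models/{self.model_name}{suffix}"
```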
```python
use_blackwell_fp8 = args.hardware in ("GB200", "GB300") and (args.rollout_fp8 or args.train_fp8)
if use_blackwell_fp8:
    hf_checkpoint = f"/root/models/{args.model_name}-MXFP8"
elif args.rollout_fp8:
    hf_checkpoint = f"/root/models/{args.model_name}-FP8"
else:
    hf_checkpoint = f"/root/models/{args.model_name}"
```
```python
def cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
    """Quantize a tensor to FP4 E2M1 and pack two values per byte."""
    result = torch.zeros_like(x, dtype=torch.uint8)
    result[(x >= 0.0) & (x <= 0.25)] = 0
    result[(x > 0.25) & (x < 0.75)] = 1
    result[(x >= 0.75) & (x <= 1.25)] = 2
    result[(x > 1.25) & (x < 1.75)] = 3
    result[(x >= 1.75) & (x <= 2.5)] = 4
    result[(x > 2.5) & (x < 3.5)] = 5
    result[(x >= 3.5) & (x <= 5.0)] = 6
    result[x > 5.0] = 7

    result[(x >= -0.25) & (x < -0.0)] = 8
    result[(x < -0.25) & (x > -0.75)] = 9
    result[(x <= -0.75) & (x >= -1.25)] = 10
    result[(x < -1.25) & (x > -1.75)] = 11
    result[(x <= -1.75) & (x >= -2.5)] = 12
    result[(x < -2.5) & (x > -3.5)] = 13
    result[(x <= -3.5) & (x >= -5.0)] = 14
    result[x < -5.0] = 15

    return result[:, ::2] + result[:, 1::2] * 16
```
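The eight positive thresholds above are the midpoints between adjacent E2M1 magnitudes (0, 0.5, 1, 1.5, 2, 3, 4, 6). A hypothetical vectorized variant, sketched below, replaces the sixteen boolean-mask writes with one sorted search via torch.bucketize; rounding at the exact midpoints may differ slightly from the reference:

```python
import torch

# Midpoints between adjacent FP4 E2M1 magnitudes (0, 0.5, 1, 1.5, 2, 3, 4, 6).
_E2M1_BOUNDS = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0])


def cast_to_fp4x2_bucketize(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical vectorized equivalent of cast_to_fp4x2 (sketch only)."""
    # One sorted search maps |x| to E2M1 code points 0..7.
    codes = torch.bucketize(x.abs().float(), _E2M1_BOUNDS.to(x.device)).to(torch.uint8)
    # Codes 8..15 are the negative encodings: set the sign bit.
    codes = torch.where(x < 0, codes + 8, codes)
    # Pack two FP4 values per byte: even columns in the low nibble.
    return codes[:, ::2] + codes[:, 1::2] * 16
```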
```python
fp4_max = torch.tensor(FP4_E2M1_MAX, device=weight.device, dtype=torch.float32)
fp8_max = torch.tensor(FP8_E4M3_MAX, device=weight.device, dtype=torch.float32)

global_encode_scale = torch.div(fp8_max * fp4_max, global_amax)
```
My current NVFP4 dev workflow:
All commands:
@HumansAnd
WIP.
NVFP4 RL requires the following SGLang PR:
NVFP4 QAT requires this TE PR:
NVTE_BACKWARD_OVERRIDE=high_precision|dequantized (NVIDIA/TransformerEngine#2644)
NVFP4 expert-only training requires this Megatron commit (#567):