[Draft] [WIP] Support --fp8-param-gather for mxfp8 #919

zianglih wants to merge 1 commit into radixark:main

Conversation
Code Review
This pull request introduces logic to reconcile Transformer Engine (TE) precision overrides by dequantizing primary weights for modules forced into high-precision compute. It also adds support for dequantizing FP8 and MXFP8 tensors during weight export and updates regex patterns for decoder layers. Feedback highlights that removing the fp8_model_init context manager may prevent weights from initializing in FP8, rendering the reconciliation logic ineffective. Additionally, suggestions were made to optimize weight export by gathering quantized tensors before dequantization to reduce communication overhead and to consolidate redundant module iterations during model initialization.
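To make the summary concrete, here is a minimal sketch of what such a reconciliation pass could look like, assuming the helper names visible in the diff hunks below (`_get_active_te_recipe`, `_is_high_precision_te_recipe`, `is_float8tensor`); the dequantize-and-replace step is an illustrative assumption, not the PR's exact implementation.

```python
import torch

def _reconcile_te_precision_overrides_for_fp8_param_gather(model, config):
    # Sketch: for modules whose active TE recipe forces high-precision
    # compute, replace FP8 primary weights with dequantized copies so the
    # forward pass sees the dtype the recipe expects.
    for module_name, module in model.named_modules():
        if not _is_high_precision_te_recipe(_get_active_te_recipe(module)):
            continue
        for name, p in list(module.named_parameters(recurse=False)):
            if p is not None and is_float8tensor(p):
                module._parameters[name] = torch.nn.Parameter(
                    p.dequantize(), requires_grad=p.requires_grad
                )
```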
```python
model = GPTModel(**kwargs)

if args.fp8_param_gather:
    _reconcile_te_precision_overrides_for_fp8_param_gather(model, config)
```
The removal of the fp8_model_init context manager is a significant change. This context manager is typically required by TransformerEngine to ensure that weights are initialized in FP8 when --fp8-param-gather is enabled. Without it, TE modules will likely initialize their primary weights in high precision (e.g., BF16), which would make the subsequent _reconcile_te_precision_overrides_for_fp8_param_gather call ineffective as it won't find any FP8 weights to dequantize. Please verify if the initialization path still correctly produces quantized weights when expected.
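For reference, a minimal sketch of the initialization pattern this comment refers to, assuming the `args`/`kwargs` names from the diff above; `fp8_model_init` is TransformerEngine's context manager for allocating module weights directly in FP8.

```python
import transformer_engine.pytorch as te

if args.fp8_param_gather:
    # Under fp8_model_init, TE modules allocate their primary weights as
    # FP8 tensors, giving the reconciliation pass something to dequantize.
    with te.fp8_model_init(enabled=True):
        model = GPTModel(**kwargs)
else:
    model = GPTModel(**kwargs)
```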
```diff
+ export_param = _dequantize_for_export(info.name, param)
  # Prepare async all_gather
  if "expert_bias" in info.name:
-     gather_tasks.append((info, param, None, None, None, None))
+     gather_tasks.append((info, export_param, None, None, None, None))
      handles.append(None)
  elif not param.tensor_model_parallel or getattr(param, "parallel_mode", None) == "duplicated":
-     gather_tasks.append((info, param.data, None, None, None, None))
+     gather_tasks.append((info, export_param, None, None, None, None))
      handles.append(None)
  else:
```
Dequantizing parameters before the all_gather operation increases the communication volume by 2x (e.g., from 1 byte in FP8 to 2 bytes in BF16/FP16). For large models, this can significantly slow down weight transfer or export and increase peak memory usage on the source rank. It would be more efficient to perform the all_gather on the quantized tensors (and their associated metadata) and then dequantize the gathered result on the destination rank.
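A rough sketch of the suggested ordering: all_gather the quantized payload, then dequantize the assembled result. `_dequantize_for_export` is the helper from this diff; the function name, the `dim=0` sharding, and the plain `all_gather` over `param.data` are illustrative assumptions (a real FP8/MXFP8 tensor would also need its scale metadata gathered alongside the payload).

```python
import torch
import torch.distributed as dist

def gather_then_dequantize(name, param, group=None):
    # Move the raw quantized storage (1 byte/element for FP8) instead of a
    # dequantized BF16 copy (2 bytes/element), halving communication volume.
    world_size = dist.get_world_size(group)
    shards = [torch.empty_like(param.data) for _ in range(world_size)]
    dist.all_gather(shards, param.data, group=group)
    full = torch.cat(shards, dim=0)  # reassemble the sharded weight
    # Dequantize once, after communication, on the rank doing the export.
    return _dequantize_for_export(name, full)
```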
```python
for module_name, module in model.named_modules():
    if not _is_high_precision_te_recipe(_get_active_te_recipe(module)):
        continue

    has_fp8_primary = any(
        is_float8tensor(p) or is_float8tensor(getattr(p, "data", p))
        for p in module.parameters(recurse=False)
        if p is not None
    )
    if has_fp8_primary:
        incompatible_modules.append(module_name)
```