
[Draft] [WIP] Support --fp8-param-gather for mxfp8 #919

Draft
zianglih wants to merge 1 commit into radixark:main from zianglih:mxfp8-param-gather

Conversation


@zianglih (Contributor) commented Apr 6, 2026


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces logic to reconcile Transformer Engine (TE) precision overrides by dequantizing primary weights for modules forced into high-precision compute. It also adds support for dequantizing FP8 and MXFP8 tensors during weight export and updates regex patterns for decoder layers. Feedback highlights that removing the fp8_model_init context manager may prevent weights from initializing in FP8, rendering the reconciliation logic ineffective. Additionally, suggestions were made to optimize weight export by gathering quantized tensors before dequantization to reduce communication overhead and to consolidate redundant module iterations during model initialization.

Comment on lines +379 to +382
model = GPTModel(**kwargs)

if args.fp8_param_gather:
_reconcile_te_precision_overrides_for_fp8_param_gather(model, config)

Severity: high

The removal of the fp8_model_init context manager is a significant change. This context manager is typically required by TransformerEngine to ensure that weights are initialized in FP8 when --fp8-param-gather is enabled. Without it, TE modules will likely initialize their primary weights in high precision (e.g., BF16), which would make the subsequent _reconcile_te_precision_overrides_for_fp8_param_gather call ineffective as it won't find any FP8 weights to dequantize. Please verify if the initialization path still correctly produces quantized weights when expected.
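
For reference, the construction pattern the comment refers to is the usual TransformerEngine idiom. A minimal sketch, assuming the GPTModel/kwargs from the snippet above (depending on the TE version, a matching recipe may also need to be passed to fp8_model_init):

```python
# Sketch only: wrapping construction in fp8_model_init makes TE modules
# allocate their primary weights as Float8Tensors instead of BF16/FP32,
# which is what the reconcile helper added in this PR expects to find.
import transformer_engine.pytorch as te

with te.fp8_model_init(enabled=args.fp8_param_gather):
    model = GPTModel(**kwargs)

if args.fp8_param_gather:
    _reconcile_te_precision_overrides_for_fp8_param_gather(model, config)
```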

Comment on lines +122 to 130
+ export_param = _dequantize_for_export(info.name, param)
  # Prepare async all_gather
  if "expert_bias" in info.name:
-     gather_tasks.append((info, param, None, None, None, None))
+     gather_tasks.append((info, export_param, None, None, None, None))
      handles.append(None)
  elif not param.tensor_model_parallel or getattr(param, "parallel_mode", None) == "duplicated":
-     gather_tasks.append((info, param.data, None, None, None, None))
+     gather_tasks.append((info, export_param, None, None, None, None))
      handles.append(None)
  else:

Severity: medium

Dequantizing parameters before the all_gather operation increases the communication volume by 2x (e.g., from 1 byte in FP8 to 2 bytes in BF16/FP16). For large models, this can significantly slow down weight transfer or export and increase peak memory usage on the source rank. It would be more efficient to perform the all_gather on the quantized tensors (and their associated metadata) and then dequantize the gathered result on the destination rank.
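
To make the suggestion concrete, here is a rough sketch of gathering first and dequantizing afterwards. It assumes per-tensor FP8 scaling for simplicity (MXFP8 block scales would be applied per block), and the names (gather_then_dequantize, scale_inv) are illustrative, not part of this PR:

```python
import torch
import torch.distributed as dist

def gather_then_dequantize(raw_fp8: torch.Tensor, scale_inv: torch.Tensor,
                           group, partition_dim: int = 0) -> torch.Tensor:
    """raw_fp8: this rank's quantized shard viewed as bytes; scale_inv: its
    dequantization scale (assumed identical across ranks for simplicity)."""
    shards = [torch.empty_like(raw_fp8) for _ in range(dist.get_world_size(group))]
    # 1 byte per element on the wire, half of what BF16 would move.
    dist.all_gather(shards, raw_fp8.contiguous(), group=group)
    full = torch.cat(shards, dim=partition_dim)  # still quantized
    # Dequantize once on the assembled tensor instead of per-shard before the gather.
    deq = full.view(torch.float8_e4m3fn).to(torch.float32) * scale_inv
    return deq.to(torch.bfloat16)
```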

Comment on lines +144 to +154
for module_name, module in model.named_modules():
if not _is_high_precision_te_recipe(_get_active_te_recipe(module)):
continue

has_fp8_primary = any(
is_float8tensor(p) or is_float8tensor(getattr(p, "data", p))
for p in module.parameters(recurse=False)
if p is not None
)
if has_fp8_primary:
incompatible_modules.append(module_name)

Severity: medium

This second pass over all modules to check for remaining FP8 primary weights is redundant and adds overhead during model initialization. This safety check could be integrated into the first loop (lines 129-141) to avoid iterating over the entire module tree twice.
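
A compact sketch of that consolidation, with _dequantize_primary_weights standing in (hypothetically) for the body of the first loop (lines 129-141):

```python
incompatible_modules = []
for module_name, module in model.named_modules():
    if not _is_high_precision_te_recipe(_get_active_te_recipe(module)):
        continue
    # Work previously done in the first loop (hypothetical helper name).
    _dequantize_primary_weights(module)
    # Safety check folded into the same pass: anything still FP8 is flagged.
    if any(is_float8tensor(p) or is_float8tensor(getattr(p, "data", p))
           for p in module.parameters(recurse=False) if p is not None):
        incompatible_modules.append(module_name)
```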

@zianglih zianglih changed the title Support --fp8-param-gather for mxfp8 [WIP] Support --fp8-param-gather for mxfp8 Apr 7, 2026
@zianglih zianglih changed the title [WIP] Support --fp8-param-gather for mxfp8 [Draft] [WIP] Support --fp8-param-gather for mxfp8 Apr 9, 2026
