[feat] Support quantization before weight resharding#2737
Conversation
📝 WalkthroughWalkthroughThese changes introduce a quantized export pathway for converting Megatron model weights to HuggingFace format. New methods are added across three files: AutoBridge exposes a public API entry point, MegatronModelBridge implements the core streaming logic with dispatch registration, and ParamMapping classes handle quantization-aware weight conversion including scale tensor management. Changes
Sequence DiagramsequenceDiagram
actor User
participant AutoBridge
participant MegatronModelBridge
participant ParamMapping
participant DispatchSystem
User->>AutoBridge: export_hf_weights_quant(model, hf_pretrained, should_quantize, quant_fn, ...)
AutoBridge->>DispatchSystem: Create dispatch instance
AutoBridge->>MegatronModelBridge: stream_weights_megatron_to_hf_quant(dispatch_instance, ...)
MegatronModelBridge->>MegatronModelBridge: Initialize bridge state<br/>(setup megatron_to_hf_quant mapping)
loop For each parameter
MegatronModelBridge->>ParamMapping: megatron_to_hf_quant(megatron_weights,<br/>should_quantize, quant_fn, quant_block_size)
ParamMapping->>ParamMapping: should_quantize(param)?
alt Quantization needed
ParamMapping->>ParamMapping: Apply quant_fn(weights)
ParamMapping->>ParamMapping: Gather scales and handle<br/>parallel distributions (TP, PP, EP)
ParamMapping->>ParamMapping: Build output dict with weights<br/>and scale_inv tensors
else Skip quantization
ParamMapping->>ParamMapping: Return weights unchanged
end
ParamMapping-->>MegatronModelBridge: Dict[HF_name → torch.Tensor]
MegatronModelBridge->>MegatronModelBridge: Yield HFWeightTuple
end
MegatronModelBridge-->>AutoBridge: Iterable[HFWeightTuple]
AutoBridge-->>User: Stream of quantized HF weights
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes 🚥 Pre-merge checks | ✅ 3 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/megatron/bridge/models/conversion/model_bridge.py`:
- Around line 789-795: When emitting the tied lm_head weight for the quantized
path, also emit the matching lm_head scale key so FP8 tied-embedding checkpoints
remain complete: after the block that yields HFWeightTuple("lm_head.weight",
final_tensor.clone().detach()) (inside the embeddings_are_tied && hf_name ==
"model.embed_tokens.weight" branch), locate the corresponding scale tensor by
looking up the source key that mirrors hf_name with ".weight" replaced by
".weight_scale_inv" (or otherwise fetch the scale from
hf_pretrained/state/source), then yield
HFWeightTuple("lm_head.weight_scale_inv",
corresponding_scale_tensor.clone().detach()) so both the weight and its scale
are produced when tying embeddings.
- Around line 775-781: The current flow merges adapters after quantization,
which mixes BF16 adapter deltas into already-quantized tensors; to fix, either:
(A) perform adapter merging before any quantization by invoking
materialize_adapter_weights(adapter_tasks) and passing its result into
_merge_lora_adapter_weights(megatron_model, converted_weights_dict,
adapter_weights) prior to the quantization step that produces *_scale_inv
entries, or (B) explicitly reject merging on quantized exports by checking
merge_adapter_weights and the target/quantized state and raising an error (or
skipping) if merging is requested post-quantization; update the logic around
converted_weights_dict, materialize_adapter_weights, _merge_lora_adapter_weights
and the merge_adapter_weights flag accordingly.
In `@src/megatron/bridge/models/conversion/param_mapping.py`:
- Around line 953-954: The code calls should_quantize using HF-side names
(str(self.hf_param)) which violates the public API and causes incorrect
quantization decisions and errors when all sub-weights are unselected; change
these checks to evaluate should_quantize against the fused Megatron parameter
name (use self.megatron_param or the fused megatron weight identifier) and if
the predicate returns False for the fused tensor, immediately call
megatron_to_hf(megatron_weights, megatron_module) as a fallback instead of
raising; update the same pattern in the QKVMapping and GatedMLPMapping logic so
that when all sub-components are unselected (all-false), the code falls back to
megatron_to_hf rather than throwing.
- Around line 980-987: The code path in the is_expert && not is_adapter branch
can leave s_dict unassigned when full_scale is a scalar
(len(full_scale.shape)==0), causing an UnboundLocalError; update the block
around gather_from_ep_ranks_scale so that if full_scale is scalar you either
initialize s_dict = {} or broadcast the scalar into a per-expert mapping before
the for k,v in s_dict.items() loop; specifically change the logic around
full_scale, gather_from_ep_ranks_scale, and s_dict (referencing full_scale,
gather_from_ep_ranks_scale, s_dict, q_weight_dict, gather_from_ep_ranks,
full_q_weight) so s_dict is always defined (or contains per-expert entries)
prior to iterating and adding scale_name entries to q_weight_dict.
- Around line 1653-1657: This branch indexes quant_block_size[0] and [1] without
checking it; add a guard that validates quant_block_size is not None and has at
least two elements before using it (or raise a clear ValueError), e.g. check
quant_block_size is not None and len(quant_block_size) >= 2 before the
assertions that reference quant_block_size[0] and quant_block_size[1]; keep the
existing checks for head_size and hidden_size (which use
config.num_attention_heads and config.kv_channels) but ensure any early failure
reports include quant_block_size in the error message.
- Around line 944-959: The quantizers are not receiving the new quant_block_size
argument — update all megatron_to_hf_quant implementations to forward
quant_block_size into the quant_fn call (e.g., replace calls like
quant_fn(megatron_weights) with quant_fn(megatron_weights, quant_block_size)) in
ColumnParallelMapping.megatron_to_hf_quant,
RowParallelMapping.megatron_to_hf_quant, and
GatedMLPMapping.megatron_to_hf_quant so the block-size-aware quantizers receive
the parameter; ensure calls that currently check/assert quant_block_size (e.g.,
in GatedMLPMapping) pass the same quant_block_size to quant_fn and adjust any
callers expecting a 2-tuple return (q_weight, scale) accordingly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: b3bd309c-693f-4dec-8659-38dfdd2887b9
📒 Files selected for processing (3)
src/megatron/bridge/models/conversion/auto_bridge.pysrc/megatron/bridge/models/conversion/model_bridge.pysrc/megatron/bridge/models/conversion/param_mapping.py
| if merge_adapter_weights and adapter_tasks: | ||
| adapter_weights = self.materialize_adapter_weights(adapter_tasks) | ||
| converted_weights_dict = self._merge_lora_adapter_weights( | ||
| megatron_model, | ||
| converted_weights_dict, | ||
| adapter_weights, | ||
| ) |
There was a problem hiding this comment.
Merge adapters before quantizing, not after.
At this point converted_weights_dict already contains quantized payloads plus *_scale_inv entries, but materialize_adapter_weights() still gathers adapters through the full-precision path and _merge_lora_adapter_weights() iterates every dict entry. With merge_adapter_weights=True, quantized exports will silently mix BF16 deltas into FP8 tensors and even try to merge into the scale tensors. Please either reject adapter merging on the quantized path or merge in BF16 first and quantize the merged result afterward. See src/megatron/bridge/models/conversion/peft_bridge.py:560-596 and src/megatron/bridge/models/conversion/peft_bridge.py:739-800.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/models/conversion/model_bridge.py` around lines 775 -
781, The current flow merges adapters after quantization, which mixes BF16
adapter deltas into already-quantized tensors; to fix, either: (A) perform
adapter merging before any quantization by invoking
materialize_adapter_weights(adapter_tasks) and passing its result into
_merge_lora_adapter_weights(megatron_model, converted_weights_dict,
adapter_weights) prior to the quantization step that produces *_scale_inv
entries, or (B) explicitly reject merging on quantized exports by checking
merge_adapter_weights and the target/quantized state and raising an error (or
skipping) if merging is requested post-quantization; update the logic around
converted_weights_dict, materialize_adapter_weights, _merge_lora_adapter_weights
and the merge_adapter_weights flag accordingly.
There was a problem hiding this comment.
Currently, our design does not support quantization + LoRA. We will remove the part on LoRA and add an assertion to ensure merge_adapter_weights is False.
There was a problem hiding this comment.
@hy2826, that's the right approach — an early guard is clean and explicit. Here's a suggested assertion at the top of stream_weights_megatron_to_hf_quant:
assert not merge_adapter_weights, (
"Quantized weight export (stream_weights_megatron_to_hf_quant) does not support "
"LoRA adapter merging. Call with merge_adapter_weights=False."
)Placing it right after the megatron_model = [megatron_model] normalization and before any task building will surface the misconfiguration immediately with a clear message.
✏️ Learnings added
Learnt from: hy2826
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 2737
File: src/megatron/bridge/models/conversion/model_bridge.py:775-781
Timestamp: 2026-03-16T03:51:30.826Z
Learning: In `src/megatron/bridge/models/conversion/model_bridge.py`, the quantized export path `stream_weights_megatron_to_hf_quant` intentionally does not support LoRA adapter merging (quantization + LoRA is unsupported by design). The method enforces `merge_adapter_weights=False` via an assertion at the start of the function. Callers must always pass `merge_adapter_weights=False` when using the quantized export path.
Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.
🧠 Learnings used
Learnt from: hy2826
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 2737
File: src/megatron/bridge/models/conversion/model_bridge.py:789-795
Timestamp: 2026-03-16T03:45:49.440Z
Learning: In `src/megatron/bridge/models/conversion/model_bridge.py`, embedding layers (e.g., `model.embed_tokens.weight`) and `lm_head` weights are intentionally excluded from FP8 quantization via the `should_quantize` predicate. Quantizing these layers significantly harms model performance. As a result, the tied-embedding branch in `stream_weights_megatron_to_hf_quant` does not need to emit a `lm_head.weight_scale_inv` key alongside the tied `lm_head.weight`, because these weights are never quantized and will always be in BF16.
Learnt from: CR
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2026-03-13T21:31:52.293Z
Learning: Applies to src/megatron/bridge/models/**/*.py : Handle tensor parallel and pipeline parallel distribution correctly in weight conversion
Learnt from: CR
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2026-03-13T21:31:52.293Z
Learning: Applies to src/megatron/bridge/models/**/*.py : Always validate tensor shapes before copying weights in model bridges
Learnt from: cuichenx
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 2711
File: src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py:378-431
Timestamp: 2026-03-11T21:23:07.051Z
Learning: Guideline: For bridge implementations under src/megatron/bridge/models/ (and similar bridge files), rely on PyTorch model.state_dict() instead of Megatron's sharded_state_dict() or dist_checkpointing paths. Ensure MTP parameter mapping keys correspond to the actual nn.Module attribute names (e.g., mtp_model_layer) and do not expect a transformer_layer prefix, which is only used by Megatron's distributed checkpointing. This pattern applies broadly to all bridge Python files in this directory.
Learnt from: HollowMan6
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 2387
File: src/megatron/bridge/models/mimo/mimo_bridge.py:49-113
Timestamp: 2026-02-14T16:32:37.992Z
Learning: In `src/megatron/bridge/models/mimo/mimo_bridge.py`, the mapping_registry method intentionally iterates over both "transformer_layer" and "mtp_model_layer" prefixes for MTP layer mappings. This dual-prefix support ensures compatibility because Megatron-Core may expose the same MTP submodule under either naming convention, allowing HF MTP weights to load correctly without requiring configuration overrides.
Learnt from: aroshanghias-nvd
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 2007
File: src/megatron/bridge/data/mimo/dataset.py:158-166
Timestamp: 2026-01-28T20:41:39.210Z
Learning: In Bridge Megatron-Core datasets, use torch.ones_like(input_ids) for attention masks instead of masking padding tokens. The Megatron-Core attention kernel handles causal masking internally during the forward pass, not at the dataset level. This pattern should apply broadly to Bridge datasets: SFT Dataset uses get_attention_mask_from_fusion=True (setting attention_mask=None), VLM Providers use skip_getting_attention_mask_from_dataset=True, and Qwen VL uses torch.ones_like(input_ids). Ensure dataset code consistently generates input_ids with a corresponding attention_mask of ones (or None where the kernel derives it) and verify through the dataset and forward-pass configuration that the model relies on the kernel's internal masking.
Learnt from: yaoyu-33
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 2342
File: src/megatron/bridge/models/qwen_omni/thinker_model.py:326-345
Timestamp: 2026-02-13T00:22:54.549Z
Learning: In Python implementations (e.g., src/megatron/bridge/models/qwen_omni/thinker_model.py), when a feature is not supported (such as audio embeddings), raise an explicit error (e.g., NotImplementedError) instead of silently ignoring the input to fail fast with a clear message. Use a descriptive exception message that identifies the unsupported feature and the expected behavior.
If you found this review helpful, would you consider giving us a shout-out on X?
Thank you for using CodeRabbit!
| def megatron_to_hf_quant( | ||
| self, | ||
| megatron_weights: Optional[torch.Tensor], | ||
| megatron_module: Optional[nn.Module], | ||
| should_quantize: callable, | ||
| quant_fn: callable, | ||
| quant_block_size: Optional[Tuple[int, int]] = None, | ||
| ) -> Dict[str, torch.Tensor]: | ||
| """Gather from all TP ranks and concatenate with quantization before PP broadcast.""" | ||
| if not should_quantize(str(self.hf_param)): | ||
| return self.megatron_to_hf(megatron_weights, megatron_module) | ||
|
|
||
| q_weight, scale = None, None | ||
| if megatron_weights is not None: | ||
| megatron_weights = self.maybe_dequantize(megatron_weights) | ||
| q_weight, scale = quant_fn(megatron_weights) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
wc -l src/megatron/bridge/models/conversion/param_mapping.pyRepository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 130
🏁 Script executed:
# Check the actual code at the specified lines
sed -n '944,959p' src/megatron/bridge/models/conversion/param_mapping.pyRepository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 794
🏁 Script executed:
# Check line ranges 1092-1108
sed -n '1092,1108p' src/megatron/bridge/models/conversion/param_mapping.pyRepository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 807
🏁 Script executed:
# Check line ranges 2361-2391
sed -n '2361,2391p' src/megatron/bridge/models/conversion/param_mapping.pyRepository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 1663
🏁 Script executed:
# Look for the base megatron_to_hf_quant method definition
sed -n '293,310p' src/megatron/bridge/models/conversion/param_mapping.pyRepository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 900
🏁 Script executed:
# Search for all calls to quant_fn in the file
rg "quant_fn\(" src/megatron/bridge/models/conversion/param_mapping.py -A 2Repository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 573
🏁 Script executed:
# Check if there are any implementations that correctly pass quant_block_size to quant_fn
rg "quant_fn\(" src/megatron/bridge/models/conversion/param_mapping.py -B 2 -A 3 | grep -E "(quant_block_size|quant_fn)" | head -30Repository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 403
🏁 Script executed:
# Look for any other method definitions that might show expected usage pattern
ast-grep --pattern 'def megatron_to_hf_quant' src/megatron/bridge/models/conversion/param_mapping.pyRepository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 35784
Pass quant_block_size through to quant_fn.
The new API accepts quant_block_size as a parameter, but implementations do not pass it to quant_fn. This breaks quantizers that require block size information. Additionally, GatedMLPMapping asserts that quant_block_size is not None and validates tensor shapes against it, but then calls quant_fn without passing it, causing runtime failures.
Affected locations
- Line 959 (ColumnParallelMapping.megatron_to_hf_quant)
- Line 1107 (RowParallelMapping.megatron_to_hf_quant)
- Line 2391 (GatedMLPMapping.megatron_to_hf_quant)
Suggested fix
- q_weight, scale = quant_fn(megatron_weights)
+ q_weight, scale = (
+ quant_fn(megatron_weights, quant_block_size)
+ if quant_block_size is not None
+ else quant_fn(megatron_weights)
+ )🧰 Tools
🪛 Ruff (0.15.5)
[warning] 950-950: Unused method argument: quant_block_size
(ARG002)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/models/conversion/param_mapping.py` around lines 944 -
959, The quantizers are not receiving the new quant_block_size argument — update
all megatron_to_hf_quant implementations to forward quant_block_size into the
quant_fn call (e.g., replace calls like quant_fn(megatron_weights) with
quant_fn(megatron_weights, quant_block_size)) in
ColumnParallelMapping.megatron_to_hf_quant,
RowParallelMapping.megatron_to_hf_quant, and
GatedMLPMapping.megatron_to_hf_quant so the block-size-aware quantizers receive
the parameter; ensure calls that currently check/assert quant_block_size (e.g.,
in GatedMLPMapping) pass the same quant_block_size to quant_fn and adjust any
callers expecting a 2-tuple return (q_weight, scale) accordingly.
| if not should_quantize(str(self.hf_param)): | ||
| return self.megatron_to_hf(megatron_weights, megatron_module) |
There was a problem hiding this comment.
Evaluate should_quantize on the fused Megatron param, and treat all-false as fallback.
The public API describes should_quantize as a predicate on Megatron weight names, but these checks use HF-side names instead. That makes selective quantization decisions drift from the caller's intent. Also, QKVMapping and GatedMLPMapping currently raise even when all sub-weights are unselected, so a valid “do not quantize this fused tensor” decision errors instead of falling back to megatron_to_hf().
Also applies to: 1101-1102, 1638-1643, 2370-2375
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/models/conversion/param_mapping.py` around lines 953 -
954, The code calls should_quantize using HF-side names (str(self.hf_param))
which violates the public API and causes incorrect quantization decisions and
errors when all sub-weights are unselected; change these checks to evaluate
should_quantize against the fused Megatron parameter name (use
self.megatron_param or the fused megatron weight identifier) and if the
predicate returns False for the fused tensor, immediately call
megatron_to_hf(megatron_weights, megatron_module) as a fallback instead of
raising; update the same pattern in the QKVMapping and GatedMLPMapping logic so
that when all sub-components are unselected (all-false), the code falls back to
megatron_to_hf rather than throwing.
| self, | ||
| megatron_model: Union[MegatronModel, List[MegatronModel]], | ||
| hf_pretrained: HFPreTrained, | ||
| should_quantize: callable, |
There was a problem hiding this comment.
why should quantize is a callable?
There was a problem hiding this comment.
Our design makes should_quantize a callable so that we can pass the module name to the should_quantize function and it can decide whether we should quantize the corresponding model weight. Usually, weights in linear modules or MoE modules are quantized while weights in layernorm, embedding layers, etc are not. Users can customize which weights to quantize.
There was a problem hiding this comment.
understand, rename to quantize checker maybe, should quantize sounds like a bool.
| @@ -1449,6 +1227,35 @@ def _megatron_to_hf_registered_impl( | |||
| merge_adapter_weights=merge_adapter_weights, | |||
| ) | |||
|
|
|||
| @stream_weights_megatron_to_hf_quant.impl((source, target)) | |||
There was a problem hiding this comment.
possible to make a quant bridge? Model bridge can inherit quant bridge similar to peft bridge. Easier to maintain code. You can move all quant methods there.
There was a problem hiding this comment.
Thanks for the suggestion. We have moved the stream_weights_megatron_to_hf_quant function to a new file called quant_bridge.py mimicing peft bridge.
|
@hy2826 please check comment |
|
@yaoyu-33 Thanks a lot for the comments! We will make modification to the codes and push the changes soon. |
|
Hi @hy2826 , kindly follow up here if there is any update to the PR, thanks for the contribution! |
|
Thanks @hy2826 for the update! Feel free to ping us again once it is ready for review. Thanks! |
|
@hy2826 can you check Claude review, sry. Overall lgtm. Already approved, but want to improve a bit since Claude commented. |
Code Review: [feat] Support quantization before weight resharding - Issues Found: (1) Bug: gather_from_ep_ranks_scale can crash with TypeError (param_mapping.py:767-771) - local_expert_number stays None if param has no .weight/.bias suffix, causing TypeError on int(). Use extract_expert_number_from_param() instead. (2) Bug: lowercase callable used as type annotation in all changed files - not a valid type hint for mypy. Use Callable from collections.abc. (3) Typo: Biaes -> Biases (param_mapping.py:1722). (4) Copyright year: quant_bridge.py uses 2025 not 2026. - Test Coverage Gaps: No EP path tests, no gather_from_ep_ranks_scale test, no split_qkv_weights_scale with attention_output_gate test, no quantization_checker=False fallback test, no bias ValueError test for QKVMapping, no stream_weights_megatron_to_hf_quant unit test, TestQKVMappingQuant requires CUDA. - Suggested test cases: No perf tests impacted. - See inline comments for specific code suggestions. |
|
Hi @yaoyu-33 , thank you very much for the further feedback! We have revised our PR according to Claude's review. |
Signed-off-by: hy2826 <hy2826@outlook.com>
|
/ok to test b0cb7fa |
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
|
/ok to test 0a2ffd2 |
Signed-off-by: hy2826 <hy2826@outlook.com> Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com> Co-authored-by: wucong25 <412916467@qq.com> Co-authored-by: yaoyu-33 <yaoyu.094@gmail.com> Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
What does this PR do ?
Support quantization before weight resharding.
Changelog
GitHub Actions CI
See the CI sectionin the Contributing doc for how to trigger the CI. A Nvidia developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open "Draft" PR.
Additional Information
Motivation
In many current RL framework implementation such as VeRL, if we want to use low-bit quantization in rollout, the synchronization of weights between the Trainer worker and the Rollout worker follows a "Gather-then-Quantize" pattern.
The Current Workflow:$BF16/FP16$ ) are collected from different TP/PP ranks via all_gather or broadcast.
Gather: High-precision weights (
Quantize: The target inference rank receives the full high-precision weight and then performs quantization locally.
This may create communication overhead: for example, transferring$BF16$ data requires $2\times$ the bandwidth compared to $FP8$ . As model sizes grow, this sync becomes a significant latency floor for each rollout iteration.
Our Design
For blockwise quantization, we propose shifting the quantization responsibility to the source ranks. By quantizing the weight shards locally before they enter the communication collective, we can reduce the data volume by$50%$ for FP8 quantization.
Technical Workflow$i$ takes its local shard $W_i$ (in $BF16$ ) and computes:
$W_{i, fp8}$ : The quantized shard.
$S_i$ : The corresponding scaling factor (Scale).$FP8$ tensors and the associated Scales. Since $FP8$ is 1-byte, the communication volume is halved.$FP8$ shards. Since the quantization was done per-shard, the metadata (scales) are managed alongside the data to ensure numerical consistency.
Take FP8 quantization as an example, the proposed "Quantize-then-Gather" approach involves three steps:
Local Quantization: Each Rank
Low-Precision Communication: Perform all_gather on the
Shard Assembly: The target rank directly concatenates the
Mathematical Representation
Instead of:
$$
W_{global} = \text{Gather}(W_0, W_1, \dots, W_n) \implies W_{fp8} = \text{Quantize}(W_{global})
$$
$$
[W_{i, fp8}, S_i] = \text{Quantize}(W_i) \implies W_{fp8} = \text{Gather}(W_{0, fp8}, \dots, W_{n, fp8})
$$
We implement:
Thus, the result of our design and the original pipeline should be bitwise-equal.
Usage
We introduce a new function named export_hf_weights_quant in auto_bridge.py and modify subsequent dependencies to support this new function. Compared with export_hf_weights, it takes three additional arguments:
Experiment Results: reduced wall-clock time
For Qwen3-30B-A3B with FP8 quantization, on 8 * H100(80GB), after one step of warm-up, our design reduced the wall-clock time from 22.103s to 14.429s, achieving a reduction of 34.7% compared with the original "gather-then-quantize" pipeline.
Summary by CodeRabbit