
[feat] Support quantization before weight resharding #2737

Merged
yaoyu-33 merged 14 commits into NVIDIA-NeMo:main from wucong25:low_precision_resharding
May 31, 2026
Conversation

@hy2826

@hy2826 hy2826 commented Mar 11, 2026

Contributor

What does this PR do ?

Support quantization before weight resharding.

Changelog

  • Add specific line by line info of high level changes in this PR.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • [ No ] Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

Motivation

In many current RL frameworks, such as VeRL, using low-bit quantization for rollout means that weight synchronization between the Trainer worker and the Rollout worker follows a "Gather-then-Quantize" pattern.

The Current Workflow:

  1. Gather: High-precision weights ($BF16/FP16$) are collected from different TP/PP ranks via all_gather or broadcast.
  2. Quantize: The target inference rank receives the full high-precision weight and then performs quantization locally.


This creates communication overhead: $BF16$ uses 2 bytes per element versus 1 byte for $FP8$, so transferring $BF16$ weights requires $2\times$ the bandwidth. As model sizes grow, this synchronization sets a significant latency floor for each rollout iteration.

Our Design

For blockwise quantization, we propose shifting the quantization responsibility to the source ranks. By quantizing the weight shards locally before they enter the communication collective, we can reduce the data volume by roughly $50\%$ for FP8 quantization (the scale tensors add a small overhead on top of the 1-byte weights).
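
As a rough sanity check on the 50% figure, the per-element payload can be computed with the scale tensors included. The sketch below assumes one FP32 scale (4 bytes) per quantization block; that storage format is an assumption, not something stated in this PR, and the function name is hypothetical.

```python
def bytes_per_element(weight_bytes, block_size=None, scale_bytes=4):
    """Average bytes sent per weight element, including per-block scales.

    Assumes one scale of `scale_bytes` bytes per quantization block
    (FP32 scales are an assumption, not taken from the PR).
    """
    if block_size is None:
        return float(weight_bytes)  # no scales, e.g. plain BF16
    rows, cols = block_size
    return weight_bytes + scale_bytes / (rows * cols)


bf16 = bytes_per_element(2)                 # gather-then-quantize payload
fp8_128 = bytes_per_element(1, (128, 128))  # FP8 with 128x128 blocks
fp8_1x32 = bytes_per_element(1, (1, 32))    # FP8 with 1x32 blocks

for name, cost in [("128x128", fp8_128), ("1x32", fp8_1x32)]:
    saving = 1.0 - cost / bf16
    print(f"FP8 {name}: {cost:.4f} B/elem, {saving:.1%} less than BF16")
```

Under these assumptions, large blocks keep the saving very close to 50%, while small blocks such as (1, 32) spend a larger share of the payload on scales.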

Technical Workflow

Taking FP8 quantization as an example, the proposed "Quantize-then-Gather" approach involves three steps:

  1. Local Quantization: Each rank $i$ takes its local shard $W_i$ (in $BF16$) and computes:
    • $W_{i, fp8}$: the quantized shard.
    • $S_i$: the corresponding scaling factor (scale).
  2. Low-Precision Communication: Perform all_gather on the $FP8$ tensors and the associated scales. Since $FP8$ uses 1 byte per element, the communication volume is halved.
  3. Shard Assembly: The target rank directly concatenates the $FP8$ shards. Since quantization was done per shard, the metadata (scales) is managed alongside the data to ensure numerical consistency.


Mathematical Representation

Instead of:
$$
W_{global} = \text{Gather}(W_0, W_1, \dots, W_n) \implies W_{fp8} = \text{Quantize}(W_{global})
$$
We implement:
$$
[W_{i, fp8}, S_i] = \text{Quantize}(W_i) \implies W_{fp8} = \text{Gather}(W_{0, fp8}, \dots, W_{n, fp8})
$$

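Because each block's scale depends only on the values within that block, the two pipelines should produce bitwise-equal results as long as shard boundaries align with quantization block boundaries.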
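
The equivalence can be illustrated with a small pure-Python simulation. This is only a sketch: the "ranks" are plain lists, and the quantizer is an absmax integer stand-in for FP8 (an assumption made so the example runs without a GPU or PyTorch). The helper name is hypothetical.

```python
def quantize_blockwise(values, block):
    """Quantize a flat list in blocks of `block` elements.

    Stand-in for FP8: each block is scaled by its absolute maximum and
    rounded to integers in [-127, 127]. Returns (quantized, scales).
    """
    assert len(values) % block == 0, "length must be a multiple of the block size"
    quantized, scales = [], []
    for start in range(0, len(values), block):
        chunk = values[start:start + block]
        amax = max(abs(v) for v in chunk)
        scale = amax / 127.0 if amax > 0 else 1.0
        quantized.extend(round(v / scale) for v in chunk)
        scales.append(scale)
    return quantized, scales


BLOCK = 4
# Two "ranks", each holding a shard whose length is a multiple of BLOCK.
shards = [
    [0.5, -1.25, 3.0, 0.0, 2.0, 2.5, -0.75, 1.0],
    [10.0, -4.0, 0.125, 6.5, -0.5, 0.25, 0.75, -0.125],
]

# Gather-then-quantize: concatenate high-precision shards, then quantize.
gathered = [v for shard in shards for v in shard]
ref_q, ref_s = quantize_blockwise(gathered, BLOCK)

# Quantize-then-gather: quantize each shard locally, then concatenate.
new_q, new_s = [], []
for shard in shards:
    q, s = quantize_blockwise(shard, BLOCK)
    new_q.extend(q)
    new_s.extend(s)

print(new_q == ref_q and new_s == ref_s)
```

The two paths agree because every block lies entirely inside one shard. If a block straddled two shards, the local absolute maximum could differ from the global one and the results would no longer be guaranteed to match.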

Usage

We introduce a new function named export_hf_weights_quant in auto_bridge.py and modify subsequent dependencies to support this new function. Compared with export_hf_weights, it takes three additional arguments:

  1. quant_block_size: a tuple giving the quantization block size, such as (1, 32).
  2. should_quantize: a function that takes the Megatron weight name as a string and decides whether to quantize the corresponding weight.
  3. quant_fn: a function that quantizes a weight; it takes the quantization block size as an additional optional parameter.
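
To make the quant_fn call shape concrete, here is a toy quantizer that accepts a weight and an optional block size and returns a (quantized weight, scale) pair. It is a sketch under stated assumptions: it works on nested Python lists instead of torch tensors, uses integer rounding as a stand-in for FP8, and the name toy_quant_fn is hypothetical. A real quant_fn passed to export_hf_weights_quant would operate on tensors.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def toy_quant_fn(weight: Matrix, block_size: Optional[Tuple[int, int]] = None):
    """Toy quantizer with the (weight, optional block size) call shape.

    Returns (q_weight, scale): integers in [-127, 127] as an FP8 stand-in,
    plus one scale per block. With block_size=None the whole matrix is one block.
    """
    rows, cols = len(weight), len(weight[0])
    block_rows, block_cols = block_size if block_size is not None else (rows, cols)
    if rows % block_rows or cols % block_cols:
        raise ValueError(f"shape {(rows, cols)} is not divisible by block {block_size}")

    q_weight = [[0] * cols for _ in range(rows)]
    scale = [[1.0] * (cols // block_cols) for _ in range(rows // block_rows)]
    for bi in range(rows // block_rows):
        for bj in range(cols // block_cols):
            r0, c0 = bi * block_rows, bj * block_cols
            # Absolute maximum over the current block determines its scale.
            amax = max(
                abs(weight[r][c])
                for r in range(r0, r0 + block_rows)
                for c in range(c0, c0 + block_cols)
            )
            s = amax / 127.0 if amax > 0 else 1.0
            scale[bi][bj] = s
            for r in range(r0, r0 + block_rows):
                for c in range(c0, c0 + block_cols):
                    q_weight[r][c] = round(weight[r][c] / s)
    return q_weight, scale


w = [[1.0, -2.0, 0.5, 4.0], [0.25, 8.0, -1.0, 0.0]]
q, s = toy_quant_fn(w, (1, 2))
print(q)
print(s)
```

With a (1, 2) block on a 2x4 matrix, the scale output has shape 2x2: one scale per block.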

Experiment Results: reduced wall-clock time

For Qwen3-30B-A3B with FP8 quantization on 8 × H100 (80 GB), after one warm-up step, our design reduced the wall-clock time from 22.103 s to 14.429 s, a 34.7% reduction compared with the original "gather-then-quantize" pipeline.

Summary by CodeRabbit

  • New Features
    • Added quantized weight export workflow from Megatron to HuggingFace format, including support for custom quantization functions and block size configuration.
    • Extended export capabilities with quantization parameters while maintaining backward compatibility with existing non-quantized export paths.

@copy-pr-bot

copy-pr-bot Bot commented Mar 11, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Mar 11, 2026

Contributor

Walkthrough

These changes introduce a quantized export pathway for converting Megatron model weights to HuggingFace format. New methods are added across three files: AutoBridge exposes a public API entry point, MegatronModelBridge implements the core streaming logic with dispatch registration, and ParamMapping classes handle quantization-aware weight conversion including scale tensor management.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| AutoBridge Public API (src/megatron/bridge/models/conversion/auto_bridge.py) | Added export_hf_weights_quant method that mirrors the existing export path but delegates to quantized streaming with quantization parameters (should_quantize, quant_fn, quant_block_size). |
| MegatronModelBridge Core Logic (src/megatron/bridge/models/conversion/model_bridge.py) | Implemented stream_weights_megatron_to_hf_quant method with quantization-aware weight mapping and adapter merging. Extended dispatch system to route quantized export through registered implementations, including dispatch-wrapped entry points and bridge construction logic. |
| ParamMapping Quantization Support (src/megatron/bridge/models/conversion/param_mapping.py) | Added megatron_to_hf_quant method to MegatronParamMapping base class and all concrete mapping implementations (ColumnParallelMapping, RowParallelMapping, ReplicatedMapping, QKVMapping, GatedMLPMapping, etc.). Includes quantization decision-making, per-tensor quantization, scale tensor handling, and expert-parallel (EP) gathering. Added helper methods for EP-level scale gathering and QKV weight splitting. |

Sequence Diagram

sequenceDiagram
    actor User
    participant AutoBridge
    participant MegatronModelBridge
    participant ParamMapping
    participant DispatchSystem

    User->>AutoBridge: export_hf_weights_quant(model, hf_pretrained, should_quantize, quant_fn, ...)
    AutoBridge->>DispatchSystem: Create dispatch instance
    AutoBridge->>MegatronModelBridge: stream_weights_megatron_to_hf_quant(dispatch_instance, ...)
    
    MegatronModelBridge->>MegatronModelBridge: Initialize bridge state<br/>(setup megatron_to_hf_quant mapping)
    
    loop For each parameter
        MegatronModelBridge->>ParamMapping: megatron_to_hf_quant(megatron_weights,<br/>should_quantize, quant_fn, quant_block_size)
        
        ParamMapping->>ParamMapping: should_quantize(param)?
        alt Quantization needed
            ParamMapping->>ParamMapping: Apply quant_fn(weights)
            ParamMapping->>ParamMapping: Gather scales and handle<br/>parallel distributions (TP, PP, EP)
            ParamMapping->>ParamMapping: Build output dict with weights<br/>and scale_inv tensors
        else Skip quantization
            ParamMapping->>ParamMapping: Return weights unchanged
        end
        
        ParamMapping-->>MegatronModelBridge: Dict[HF_name → torch.Tensor]
        MegatronModelBridge->>MegatronModelBridge: Yield HFWeightTuple
    end
    
    MegatronModelBridge-->>AutoBridge: Iterable[HFWeightTuple]
    AutoBridge-->>User: Stream of quantized HF weights

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Test Results For Major Changes | ⚠️ Warning | PR introduces major quantization feature with claimed 34.7% performance improvement but provides no documented test results for new export methods or formal performance benchmarks. | Add comprehensive tests for export_hf_weights_quant and stream_weights_megatron_to_hf_quant, provide formal performance benchmarks, resolve identified review bugs, and document test results. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title '[feat] Support quantization before weight resharding' directly and clearly describes the main objective of the PR: adding support for performing quantization before weight resharding to reduce communication volume. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 84.62% which is sufficient. The required threshold is 80.00%. |


@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 6

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/megatron/bridge/models/conversion/model_bridge.py`:
- Around line 789-795: When emitting the tied lm_head weight for the quantized
path, also emit the matching lm_head scale key so FP8 tied-embedding checkpoints
remain complete: after the block that yields HFWeightTuple("lm_head.weight",
final_tensor.clone().detach()) (inside the embeddings_are_tied && hf_name ==
"model.embed_tokens.weight" branch), locate the corresponding scale tensor by
looking up the source key that mirrors hf_name with ".weight" replaced by
".weight_scale_inv" (or otherwise fetch the scale from
hf_pretrained/state/source), then yield
HFWeightTuple("lm_head.weight_scale_inv",
corresponding_scale_tensor.clone().detach()) so both the weight and its scale
are produced when tying embeddings.
- Around line 775-781: The current flow merges adapters after quantization,
which mixes BF16 adapter deltas into already-quantized tensors; to fix, either:
(A) perform adapter merging before any quantization by invoking
materialize_adapter_weights(adapter_tasks) and passing its result into
_merge_lora_adapter_weights(megatron_model, converted_weights_dict,
adapter_weights) prior to the quantization step that produces *_scale_inv
entries, or (B) explicitly reject merging on quantized exports by checking
merge_adapter_weights and the target/quantized state and raising an error (or
skipping) if merging is requested post-quantization; update the logic around
converted_weights_dict, materialize_adapter_weights, _merge_lora_adapter_weights
and the merge_adapter_weights flag accordingly.

In `@src/megatron/bridge/models/conversion/param_mapping.py`:
- Around line 953-954: The code calls should_quantize using HF-side names
(str(self.hf_param)) which violates the public API and causes incorrect
quantization decisions and errors when all sub-weights are unselected; change
these checks to evaluate should_quantize against the fused Megatron parameter
name (use self.megatron_param or the fused megatron weight identifier) and if
the predicate returns False for the fused tensor, immediately call
megatron_to_hf(megatron_weights, megatron_module) as a fallback instead of
raising; update the same pattern in the QKVMapping and GatedMLPMapping logic so
that when all sub-components are unselected (all-false), the code falls back to
megatron_to_hf rather than throwing.
- Around line 980-987: The code path in the is_expert && not is_adapter branch
can leave s_dict unassigned when full_scale is a scalar
(len(full_scale.shape)==0), causing an UnboundLocalError; update the block
around gather_from_ep_ranks_scale so that if full_scale is scalar you either
initialize s_dict = {} or broadcast the scalar into a per-expert mapping before
the for k,v in s_dict.items() loop; specifically change the logic around
full_scale, gather_from_ep_ranks_scale, and s_dict (referencing full_scale,
gather_from_ep_ranks_scale, s_dict, q_weight_dict, gather_from_ep_ranks,
full_q_weight) so s_dict is always defined (or contains per-expert entries)
prior to iterating and adding scale_name entries to q_weight_dict.
- Around line 1653-1657: This branch indexes quant_block_size[0] and [1] without
checking it; add a guard that validates quant_block_size is not None and has at
least two elements before using it (or raise a clear ValueError), e.g. check
quant_block_size is not None and len(quant_block_size) >= 2 before the
assertions that reference quant_block_size[0] and quant_block_size[1]; keep the
existing checks for head_size and hidden_size (which use
config.num_attention_heads and config.kv_channels) but ensure any early failure
reports include quant_block_size in the error message.
- Around line 944-959: The quantizers are not receiving the new quant_block_size
argument — update all megatron_to_hf_quant implementations to forward
quant_block_size into the quant_fn call (e.g., replace calls like
quant_fn(megatron_weights) with quant_fn(megatron_weights, quant_block_size)) in
ColumnParallelMapping.megatron_to_hf_quant,
RowParallelMapping.megatron_to_hf_quant, and
GatedMLPMapping.megatron_to_hf_quant so the block-size-aware quantizers receive
the parameter; ensure calls that currently check/assert quant_block_size (e.g.,
in GatedMLPMapping) pass the same quant_block_size to quant_fn and adjust any
callers expecting a 2-tuple return (q_weight, scale) accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b3bd309c-693f-4dec-8659-38dfdd2887b9

📥 Commits

Reviewing files that changed from the base of the PR and between de93536 and 8bc5c8c.

📒 Files selected for processing (3)
  • src/megatron/bridge/models/conversion/auto_bridge.py
  • src/megatron/bridge/models/conversion/model_bridge.py
  • src/megatron/bridge/models/conversion/param_mapping.py

Comment on lines +775 to +781
if merge_adapter_weights and adapter_tasks:
    adapter_weights = self.materialize_adapter_weights(adapter_tasks)
    converted_weights_dict = self._merge_lora_adapter_weights(
        megatron_model,
        converted_weights_dict,
        adapter_weights,
    )

@coderabbitai coderabbitai Bot Mar 11, 2026

Contributor

⚠️ Potential issue | 🔴 Critical

Merge adapters before quantizing, not after.

At this point converted_weights_dict already contains quantized payloads plus *_scale_inv entries, but materialize_adapter_weights() still gathers adapters through the full-precision path and _merge_lora_adapter_weights() iterates every dict entry. With merge_adapter_weights=True, quantized exports will silently mix BF16 deltas into FP8 tensors and even try to merge into the scale tensors. Please either reject adapter merging on the quantized path or merge in BF16 first and quantize the merged result afterward. See src/megatron/bridge/models/conversion/peft_bridge.py:560-596 and src/megatron/bridge/models/conversion/peft_bridge.py:739-800.

Contributor Author

Currently, our design does not support quantization + LoRA. We will remove the part on LoRA and add an assertion to ensure merge_adapter_weights is False.

Contributor

@hy2826, that's the right approach — an early guard is clean and explicit. Here's a suggested assertion at the top of stream_weights_megatron_to_hf_quant:

assert not merge_adapter_weights, (
    "Quantized weight export (stream_weights_megatron_to_hf_quant) does not support "
    "LoRA adapter merging. Call with merge_adapter_weights=False."
)

Placing it right after the megatron_model = [megatron_model] normalization and before any task building will surface the misconfiguration immediately with a clear message.
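
One caveat with a plain assert is that Python skips it when run with the -O flag. A possible alternative is an explicit exception, sketched below. The helper name check_quant_export_args is hypothetical and not part of this PR; only the merge_adapter_weights flag and the error message come from the discussion above.

```python
def check_quant_export_args(merge_adapter_weights: bool) -> None:
    """Fail fast if LoRA merging is requested on the quantized export path.

    Hypothetical helper; unlike `assert`, this check still runs under `python -O`.
    """
    if merge_adapter_weights:
        raise NotImplementedError(
            "Quantized weight export does not support LoRA adapter merging. "
            "Call with merge_adapter_weights=False."
        )


check_quant_export_args(merge_adapter_weights=False)  # passes silently

try:
    check_quant_export_args(merge_adapter_weights=True)
except NotImplementedError as err:
    print(f"rejected: {err}")
```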


✏️ Learnings added
Learnt from: hy2826
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 2737
File: src/megatron/bridge/models/conversion/model_bridge.py:775-781
Timestamp: 2026-03-16T03:51:30.826Z
Learning: In `src/megatron/bridge/models/conversion/model_bridge.py`, the quantized export path `stream_weights_megatron_to_hf_quant` intentionally does not support LoRA adapter merging (quantization + LoRA is unsupported by design). The method enforces `merge_adapter_weights=False` via an assertion at the start of the function. Callers must always pass `merge_adapter_weights=False` when using the quantized export path.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.


🧠 Learnings used
Learnt from: hy2826
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 2737
File: src/megatron/bridge/models/conversion/model_bridge.py:789-795
Timestamp: 2026-03-16T03:45:49.440Z
Learning: In `src/megatron/bridge/models/conversion/model_bridge.py`, embedding layers (e.g., `model.embed_tokens.weight`) and `lm_head` weights are intentionally excluded from FP8 quantization via the `should_quantize` predicate. Quantizing these layers significantly harms model performance. As a result, the tied-embedding branch in `stream_weights_megatron_to_hf_quant` does not need to emit a `lm_head.weight_scale_inv` key alongside the tied `lm_head.weight`, because these weights are never quantized and will always be in BF16.

Learnt from: CR
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2026-03-13T21:31:52.293Z
Learning: Applies to src/megatron/bridge/models/**/*.py : Handle tensor parallel and pipeline parallel distribution correctly in weight conversion

Learnt from: CR
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2026-03-13T21:31:52.293Z
Learning: Applies to src/megatron/bridge/models/**/*.py : Always validate tensor shapes before copying weights in model bridges

Learnt from: cuichenx
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 2711
File: src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py:378-431
Timestamp: 2026-03-11T21:23:07.051Z
Learning: Guideline: For bridge implementations under src/megatron/bridge/models/ (and similar bridge files), rely on PyTorch model.state_dict() instead of Megatron's sharded_state_dict() or dist_checkpointing paths. Ensure MTP parameter mapping keys correspond to the actual nn.Module attribute names (e.g., mtp_model_layer) and do not expect a transformer_layer prefix, which is only used by Megatron's distributed checkpointing. This pattern applies broadly to all bridge Python files in this directory.

Learnt from: HollowMan6
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 2387
File: src/megatron/bridge/models/mimo/mimo_bridge.py:49-113
Timestamp: 2026-02-14T16:32:37.992Z
Learning: In `src/megatron/bridge/models/mimo/mimo_bridge.py`, the mapping_registry method intentionally iterates over both "transformer_layer" and "mtp_model_layer" prefixes for MTP layer mappings. This dual-prefix support ensures compatibility because Megatron-Core may expose the same MTP submodule under either naming convention, allowing HF MTP weights to load correctly without requiring configuration overrides.

Learnt from: aroshanghias-nvd
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 2007
File: src/megatron/bridge/data/mimo/dataset.py:158-166
Timestamp: 2026-01-28T20:41:39.210Z
Learning: In Bridge Megatron-Core datasets, use torch.ones_like(input_ids) for attention masks instead of masking padding tokens. The Megatron-Core attention kernel handles causal masking internally during the forward pass, not at the dataset level. This pattern should apply broadly to Bridge datasets: SFT Dataset uses get_attention_mask_from_fusion=True (setting attention_mask=None), VLM Providers use skip_getting_attention_mask_from_dataset=True, and Qwen VL uses torch.ones_like(input_ids). Ensure dataset code consistently generates input_ids with a corresponding attention_mask of ones (or None where the kernel derives it) and verify through the dataset and forward-pass configuration that the model relies on the kernel's internal masking.

Learnt from: yaoyu-33
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 2342
File: src/megatron/bridge/models/qwen_omni/thinker_model.py:326-345
Timestamp: 2026-02-13T00:22:54.549Z
Learning: In Python implementations (e.g., src/megatron/bridge/models/qwen_omni/thinker_model.py), when a feature is not supported (such as audio embeddings), raise an explicit error (e.g., NotImplementedError) instead of silently ignoring the input to fail fast with a clear message. Use a descriptive exception message that identifies the unsupported feature and the expected behavior.


Comment thread src/megatron/bridge/models/conversion/model_bridge.py Outdated
Comment on lines +944 to +959
def megatron_to_hf_quant(
    self,
    megatron_weights: Optional[torch.Tensor],
    megatron_module: Optional[nn.Module],
    should_quantize: callable,
    quant_fn: callable,
    quant_block_size: Optional[Tuple[int, int]] = None,
) -> Dict[str, torch.Tensor]:
    """Gather from all TP ranks and concatenate with quantization before PP broadcast."""
    if not should_quantize(str(self.hf_param)):
        return self.megatron_to_hf(megatron_weights, megatron_module)

    q_weight, scale = None, None
    if megatron_weights is not None:
        megatron_weights = self.maybe_dequantize(megatron_weights)
        q_weight, scale = quant_fn(megatron_weights)

Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

wc -l src/megatron/bridge/models/conversion/param_mapping.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 130


🏁 Script executed:

# Check the actual code at the specified lines
sed -n '944,959p' src/megatron/bridge/models/conversion/param_mapping.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 794


🏁 Script executed:

# Check line ranges 1092-1108
sed -n '1092,1108p' src/megatron/bridge/models/conversion/param_mapping.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 807


🏁 Script executed:

# Check line ranges 2361-2391
sed -n '2361,2391p' src/megatron/bridge/models/conversion/param_mapping.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1663


🏁 Script executed:

# Look for the base megatron_to_hf_quant method definition
sed -n '293,310p' src/megatron/bridge/models/conversion/param_mapping.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 900


🏁 Script executed:

# Search for all calls to quant_fn in the file
rg "quant_fn\(" src/megatron/bridge/models/conversion/param_mapping.py -A 2

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 573


🏁 Script executed:

# Check if there are any implementations that correctly pass quant_block_size to quant_fn
rg "quant_fn\(" src/megatron/bridge/models/conversion/param_mapping.py -B 2 -A 3 | grep -E "(quant_block_size|quant_fn)" | head -30

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 403


🏁 Script executed:

# Look for any other method definitions that might show expected usage pattern
ast-grep --pattern 'def megatron_to_hf_quant' src/megatron/bridge/models/conversion/param_mapping.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 35784


Pass quant_block_size through to quant_fn.

The new API accepts quant_block_size as a parameter, but implementations do not pass it to quant_fn. This breaks quantizers that require block size information. Additionally, GatedMLPMapping asserts that quant_block_size is not None and validates tensor shapes against it, but then calls quant_fn without passing it, causing runtime failures.

Affected locations
  • Line 959 (ColumnParallelMapping.megatron_to_hf_quant)
  • Line 1107 (RowParallelMapping.megatron_to_hf_quant)
  • Line 2391 (GatedMLPMapping.megatron_to_hf_quant)
Suggested fix
-            q_weight, scale = quant_fn(megatron_weights)
+            q_weight, scale = (
+                quant_fn(megatron_weights, quant_block_size)
+                if quant_block_size is not None
+                else quant_fn(megatron_weights)
+            )
🧰 Tools
🪛 Ruff (0.15.5)

[warning] 950-950: Unused method argument: quant_block_size

(ARG002)

Comment on lines +953 to +954
if not should_quantize(str(self.hf_param)):
    return self.megatron_to_hf(megatron_weights, megatron_module)

Contributor

⚠️ Potential issue | 🟠 Major

Evaluate should_quantize on the fused Megatron param, and treat all-false as fallback.

The public API describes should_quantize as a predicate on Megatron weight names, but these checks use HF-side names instead. That makes selective quantization decisions drift from the caller's intent. Also, QKVMapping and GatedMLPMapping currently raise even when all sub-weights are unselected, so a valid “do not quantize this fused tensor” decision errors instead of falling back to megatron_to_hf().

Also applies to: 1101-1102, 1638-1643, 2370-2375
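
The "all-false means fallback" behavior requested above can be sketched as a three-way decision over the sub-weights of a fused tensor. This is a hypothetical helper, not code from the PR. Raising on a mixed selection is an assumption made for this sketch, and it does not address the separate question of which name (HF-side or Megatron-side) is passed to the predicate.

```python
from typing import Callable, Sequence


def fused_quant_decision(sub_names: Sequence[str], should_quantize: Callable[[str], bool]) -> str:
    """Decide how to export a fused tensor (e.g. QKV) from its sub-weight names.

    Returns "quantize" if every sub-weight is selected and "fallback" if none is.
    A mixed selection raises (an assumption of this sketch).
    """
    decisions = [should_quantize(name) for name in sub_names]
    if all(decisions):
        return "quantize"
    if not any(decisions):
        return "fallback"  # use the unquantized megatron_to_hf path
    raise ValueError(
        f"Mixed quantization selection for fused tensor: {list(zip(sub_names, decisions))}"
    )


qkv = ["q", "k", "v"]  # placeholder sub-weight names
print(fused_quant_decision(qkv, lambda name: True))
print(fused_quant_decision(qkv, lambda name: False))
```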

Comment thread src/megatron/bridge/models/conversion/param_mapping.py
Comment thread src/megatron/bridge/models/conversion/param_mapping.py Outdated
@yaoyu-33 yaoyu-33 added feature New capabilities, enhancements, or enablement work area:ckpt Checkpoint conversion, loading, export, and save paths needs-review PR is ready for code review and waiting on a reviewer and removed community-request feature New capabilities, enhancements, or enablement work labels Mar 11, 2026
Comment thread src/megatron/bridge/models/conversion/auto_bridge.py
self,
megatron_model: Union[MegatronModel, List[MegatronModel]],
hf_pretrained: HFPreTrained,
should_quantize: callable,

Contributor

Why is should_quantize a callable?

Contributor Author

Our design makes should_quantize a callable so that we can pass it the module name and let it decide whether to quantize the corresponding model weight. Usually, weights in linear or MoE modules are quantized, while weights in layernorm, embedding layers, etc. are not. Users can customize which weights to quantize.
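
A minimal sketch of such a predicate is shown below. The substring markers and the example parameter names are illustrative assumptions; real names depend on the model and should be checked against the actual state dict.

```python
def should_quantize(weight_name: str) -> bool:
    """Example predicate: quantize linear weights, skip norms, embeddings and biases.

    The markers below are illustrative; real parameter names depend on the model.
    """
    if not weight_name.endswith(".weight"):
        return False  # e.g. biases stay in high precision
    skip_markers = ("norm", "embedding", "output_layer")
    return not any(marker in weight_name for marker in skip_markers)


# Illustrative names only.
for name in [
    "decoder.layers.0.self_attention.linear_qkv.weight",
    "decoder.layers.0.self_attention.linear_qkv.bias",
    "decoder.final_layernorm.weight",
    "embedding.word_embeddings.weight",
]:
    print(f"{name}: {should_quantize(name)}")
```

Because the decision is an ordinary function of the name, callers can swap in any policy (a regex, an allow-list, a per-layer rule) without changing the export code.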

Contributor

Understood. Maybe rename it to something like quantize checker; should_quantize sounds like a bool.

Comment thread src/megatron/bridge/models/conversion/model_bridge.py Outdated
@@ -1449,6 +1227,35 @@ def _megatron_to_hf_registered_impl(
merge_adapter_weights=merge_adapter_weights,
)

@stream_weights_megatron_to_hf_quant.impl((source, target))

Contributor

possible to make a quant bridge? Model bridge can inherit quant bridge similar to peft bridge. Easier to maintain code. You can move all quant methods there.

Contributor Author

Thanks for the suggestion. We have moved the stream_weights_megatron_to_hf_quant function to a new file called quant_bridge.py, mimicking the peft bridge.

@yaoyu-33

Contributor

@hy2826 please check comment

@hy2826

hy2826 commented Mar 16, 2026

Contributor Author

@yaoyu-33 Thanks a lot for the comments! We will modify the code and push the changes soon.

@yaoyu-33 yaoyu-33 added needs-follow-up and removed needs-review PR is ready for code review and waiting on a reviewer labels Mar 23, 2026
@suiyoubi

Contributor

Hi @hy2826 , kindly follow up here if there is any update to the PR, thanks for the contribution!

@hy2826

hy2826 commented Mar 31, 2026

Contributor Author

> Hi @hy2826 , kindly follow up here if there is any update to the PR, thanks for the contribution!

Hi @suiyoubi , thanks for the reminder! The code changes are finished. We are looking for machines to test them on. Sorry for the wait!

@suiyoubi

Contributor

Thanks @hy2826 for the update! Feel free to ping us again once it is ready for review. Thanks!

@yaoyu-33 yaoyu-33 added waiting-on-customer Waiting on the original author to respond and removed ready-to-merge PR is approved, current, and only waiting for CI to pass before merge labels May 13, 2026
@yaoyu-33

Contributor

@hy2826 can you check Claude review, sry. Overall lgtm. Already approved, but want to improve a bit since Claude commented.

@claude

claude Bot commented May 13, 2026

Contributor

Code Review: [feat] Support quantization before weight resharding

Issues Found:

  1. Bug: gather_from_ep_ranks_scale can crash with TypeError (param_mapping.py:767-771). local_expert_number stays None if param has no .weight/.bias suffix, causing TypeError on int(). Use extract_expert_number_from_param() instead.
  2. Bug: lowercase callable used as type annotation in all changed files. It is not a valid type hint for mypy. Use Callable from collections.abc.
  3. Typo: Biaes -> Biases (param_mapping.py:1722).
  4. Copyright year: quant_bridge.py uses 2025 not 2026.

Test Coverage Gaps:

  • No EP path tests
  • No gather_from_ep_ranks_scale test
  • No split_qkv_weights_scale with attention_output_gate test
  • No quantization_checker=False fallback test
  • No bias ValueError test for QKVMapping
  • No stream_weights_megatron_to_hf_quant unit test
  • TestQKVMappingQuant requires CUDA

Suggested test cases: No perf tests impacted.

See inline comments for specific code suggestions.
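
For the type-annotation point, a minimal sketch of the suggested style is shown below. The alias QuantPredicate and the function select_weights are hypothetical names used only for illustration; subscripting collections.abc.Callable requires Python 3.9 or newer.

```python
from collections.abc import Callable
from typing import List

# A predicate from parameter name to "quantize or not".
QuantPredicate = Callable[[str], bool]


def select_weights(names: List[str], should_quantize: QuantPredicate) -> List[str]:
    """Return the names that the predicate selects for quantization."""
    return [name for name in names if should_quantize(name)]


selected = select_weights(
    ["layer.0.linear.weight", "layer.0.norm.weight"],
    lambda name: "norm" not in name,
)
print(selected)
```

Unlike the builtin callable, the subscripted form records the expected argument and return types, so a type checker can flag a predicate with the wrong signature.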

@hy2826

hy2826 commented May 19, 2026

Contributor Author

Hi @yaoyu-33 , thank you very much for the further feedback! We have revised our PR according to Claude's review.

Signed-off-by: hy2826 <hy2826@outlook.com>
@yaoyu-33 yaoyu-33 added ready-to-merge PR is approved, current, and only waiting for CI to pass before merge high-priority and removed waiting-on-customer Waiting on the original author to respond labels May 28, 2026
@yaoyu-33

Contributor

/ok to test b0cb7fa

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33

Contributor

/ok to test 0a2ffd2

@yaoyu-33 yaoyu-33 disabled auto-merge May 31, 2026 22:35
@yaoyu-33 yaoyu-33 merged commit 6e522fd into NVIDIA-NeMo:main May 31, 2026
72 of 73 checks passed
vasunvidia pushed a commit to vasunvidia/Megatron-Bridge that referenced this pull request Jun 10, 2026
Signed-off-by: hy2826 <hy2826@outlook.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Co-authored-by: wucong25 <412916467@qq.com>
Co-authored-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

Labels

area:ckpt Checkpoint conversion, loading, export, and save paths community-request feature New capabilities, enhancements, or enablement work high-priority ready-to-merge PR is approved, current, and only waiting for CI to pass before merge

6 participants