
[GDN] Keep A_log and out_norm in FP32 for numerical stability#3634

Closed
cuichenx wants to merge 1 commit into NVIDIA:main from cuichenx:chcui/qwen3.5_precision

Conversation

@cuichenx
Contributor

What does this PR do ?

Summary

Qwen3.5 (the successor to Qwen3-Next) stores the GDN A_log and out_norm (output layernorm) parameters in FP32 in its released checkpoints, while the rest of the model is BF16. The current GatedDeltaNet implementation allocates these in config.params_dtype, causing dtype mismatches during checkpoint conversion. This PR aligns the Megatron implementation with the upstream precision convention.

  • Store A_log parameter in FP32 instead of config.params_dtype
  • Cast out_norm (output layernorm) to FP32 after construction
  • Keep A_log initialization in FP32 in reset_parameters
  • No changes to dt_bias (remains at config.params_dtype)

Motivation

The Qwen3.5 family (both the 27B dense and three MoE variants) releases A_log and linear_attn.norm.weight in FP32, even though the rest of the checkpoint is BF16. This is a deliberate precision choice by the model authors for numerical stability.

The GDN forward pass already computes the decay gate g in FP32:

g = -self.A_log.exp() * F.softplus(alpha.float() + self.dt_bias)  # In fp32

However, the A_log parameter and out_norm weights were previously allocated in config.params_dtype (typically BF16). This creates two problems:

  1. Precision loss in A_log: A_log stores log(A) where A is a per-head decay rate. The exponential A_log.exp() amplifies quantization error from BF16 storage, even though the compute is done in FP32.

  2. Precision loss in out_norm: The output RMSNorm applied after the recurrence operates on per-head-dim scale factors. Storing these in BF16 introduces unnecessary rounding.

Both parameters are small (per-head or per-head-dim vectors), so the FP32 memory overhead is negligible.
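The amplification described in problem 1 can be seen with a small, self-contained sketch that emulates bf16 rounding in pure Python (the `to_bf16` helper and the sample value 3.2 are illustrative, not from the PR — bf16 keeps only the top 16 bits of the fp32 encoding):

```python
import math
import struct

def to_bf16(x: float) -> float:
    """Round a float to bfloat16 precision (simplified round-to-nearest)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x8000) & 0xFFFF0000  # keep the top 16 bits of the fp32 encoding
    return struct.unpack("<f", struct.pack("<I", bits))[0]

a_log = 3.2                      # hypothetical per-head log-decay value
a_log_bf16 = to_bf16(a_log)      # 3.203125 after bf16 rounding
err_log = abs(a_log_bf16 - a_log)
err_exp = abs(math.exp(a_log_bf16) - math.exp(a_log))
# exp() scales the storage error by roughly a factor of exp(a_log),
# so err_exp is much larger than err_log even though compute is exact here
```

The same rounding error that is harmless in log space becomes a visibly larger error after the exponential, which is why storing `A_log` itself in FP32 matters even when the forward compute is already FP32.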

Changes

megatron/core/ssm/gated_delta_net.py:

  • __init__: A_log allocated with dtype=torch.float32 (was config.params_dtype)
  • __init__: out_norm cast to FP32 via self.out_norm.to(torch.float32) after build_module
  • reset_parameters: A_log initialization uses dtype=torch.float32 (was config.params_dtype)
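A minimal sketch of the allocation pattern the three bullets describe (the class name, shapes, and the use of `nn.LayerNorm` as a stand-in for the norm built via `build_module` are illustrative assumptions, not the actual gated_delta_net.py code):

```python
import torch
import torch.nn as nn

class GDNSketch(nn.Module):
    """Illustrative stand-in for GatedDeltaNet's parameter allocation."""
    def __init__(self, nheads: int = 8, head_dim: int = 64):
        super().__init__()
        # A_log allocated in fp32 instead of config.params_dtype
        self.A_log = nn.Parameter(torch.empty(nheads, dtype=torch.float32))
        # out_norm built normally, then cast to fp32 after construction
        # (LayerNorm is a stand-in for the real RMSNorm here)
        self.out_norm = nn.LayerNorm(head_dim).to(torch.float32)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # initialization also stays in fp32, mirroring the PR's reset_parameters
        with torch.no_grad():
            A = torch.empty_like(self.A_log).uniform_(1.0, 16.0)
            self.A_log.copy_(torch.log(A))

m = GDNSketch()
assert m.A_log.dtype == torch.float32
assert m.out_norm.weight.dtype == torch.float32
```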

Precedent

This follows the existing convention in MambaMixer, where A_log is already unconditionally stored in FP32:

# mamba_mixer.py, line 336-341
A = torch.empty(self.nheads_local_tp, dtype=torch.float32, ...)
A_log = torch.log(A)  # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)

Alternative designs considered

Option A: Configurable via TransformerConfig flags

Add two boolean fields (linear_A_log_in_fp32, linear_out_norm_in_fp32) to TransformerConfig, defaulting to True, and have GatedDeltaNet read them.

  • Pro: Models that need BF16 for these params (e.g., Qwen3-Next) can set the flags to False for bitwise-identical checkpoint roundtrips.
  • Con: TransformerConfig already has ~197 fields. Adding per-parameter precision knobs for tiny vectors (one scalar per head) is hard to justify given the negligible memory difference. Every GDN-based model bridge must also remember to set these flags correctly.

Why we chose hardcoded FP32 instead:

  • The memory overhead is negligible (these are per-head vectors, not weight matrices).
  • The forward pass already computes in FP32 regardless.
  • MambaMixer already hardcodes A_log to FP32 — this aligns the two SSM-family modules.
  • For models that previously stored these in BF16 (e.g., Qwen3-Next), the BF16-to-FP32 upcast is lossless (all BF16 values are exactly representable in FP32), so checkpoint loading is unaffected. Export will produce FP32 values that are mathematically identical to the original BF16 values.
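The losslessness claim follows from the formats themselves: bf16 is exactly the top 16 bits of the fp32 encoding (same sign bit and 8-bit exponent, with bf16's 7-bit mantissa a prefix of fp32's 23 bits), so widening just zero-pads the mantissa. A small pure-Python sketch (the helper names are illustrative):

```python
import struct

def bf16_bits(x: float) -> int:
    """Top 16 bits of the fp32 encoding — this is the bf16 bit pattern (truncation)."""
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16

def bf16_to_fp32(h: int) -> float:
    """Widen a bf16 bit pattern to fp32 by zero-padding the mantissa."""
    return struct.unpack("<f", struct.pack("<I", h << 16))[0]

h = bf16_bits(0.15625)   # 0.15625 = 2^-3 + 2^-5, exactly representable in bf16
x = bf16_to_fp32(h)      # upcast: pad mantissa with zeros
# re-extracting the bf16 bits gives the identical pattern,
# so the bf16 -> fp32 upcast loses no information
assert bf16_bits(x) == h
```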
  • Qwen3-Next is superseded by Qwen3.5, which already uses FP32 for these params.

⚠️ For major changes (either in lines of code or in impact), please first share a design doc with the team. If you're unsure of the best way to do so, contact @mcore-oncall.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see the Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review may be declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

Signed-off-by: Chen Cui <chcui@nvidia.com>
@cuichenx cuichenx requested review from a team as code owners February 27, 2026 00:51
@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Feb 27, 2026
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team February 27, 2026 00:52
@yuzhongw-nvidia
Contributor

yuzhongw-nvidia commented Feb 27, 2026

Hi @cuichenx, thanks for your work; this is a very good catch. But I'm afraid your change may not work, because the model will be wrapped with Float16Module here, so the params will be converted back to bfloat16.

Please refer to the moe expert bias (impl, usage) to refine this PR. Please also double-check whether the two dtypes are correct in an E2E training loop. Thanks!
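The reviewer's concern can be reproduced with plain PyTorch, independent of Megatron: a wrapper that casts the whole module (which is effectively what a Float16Module-style wrapper does) will downcast an fp32-allocated parameter too. `TinyGDN` below is a hypothetical stand-in, not Megatron code:

```python
import torch
import torch.nn as nn

class TinyGDN(nn.Module):
    """Hypothetical stand-in for GatedDeltaNet's parameter layout."""
    def __init__(self, nheads: int = 4):
        super().__init__()
        # allocated in fp32, as in this PR
        self.A_log = nn.Parameter(torch.zeros(nheads, dtype=torch.float32))

m = TinyGDN()
assert m.A_log.dtype == torch.float32   # fp32 right after construction
m.bfloat16()  # a bf16 wrapper typically casts the whole module like this
assert m.A_log.dtype == torch.bfloat16  # the fp32 allocation does not survive
```

This is why an exclusion mechanism (as used for the MoE expert bias) is needed rather than just changing the allocation dtype.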

@yuzhongw-nvidia
Contributor

Also, please help double-check whether the two changes are correct. My observation is that, although the tensor dtypes of the GDN out_norm and A_log are fp32 in the checkpoint file, they are still constructed and maintained as bf16 tensors in the HF modeling code (code).

I ran an example to verify this, shown below.

>>> model = transformers.AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-35B-A3B")
>>> model.model.layers[0].linear_attn.A_log.dtype
torch.bfloat16
>>> model.model.layers[0].linear_attn.norm.weight.dtype
torch.bfloat16

I'm not sure which one is correct. Could you please help find out which data type Qwen3.5 actually uses here? Thanks!

@cuichenx
Contributor Author

Thanks @yuzhongw-nvidia, marking this as draft for now

@cuichenx cuichenx marked this pull request as draft February 27, 2026 17:27
@copy-pr-bot

copy-pr-bot bot commented Feb 27, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.
