[GDN] Keep A_log and out_norm in FP32 for numerical stability#3634
cuichenx wants to merge 1 commit into NVIDIA:main

Conversation
Signed-off-by: Chen Cui <chcui@nvidia.com>
Hi @cuichenx, thanks for your work, it is a very good catch. But I'm afraid that your change may not work because the model will be wrapped. Please refer to the MoE expert bias handling (impl, usage) to refine this PR, and please also help double-check whether the two dtypes are correct in an E2E training loop. Thanks!
Also, please help double-check whether the two changes are correct. My observation is that, although the tensor dtypes of the GDN out norm and A_log are FP32 in the checkpoint file, they are still constructed and maintained as BF16 tensors in the HF code (code). I ran an example to verify this below. I'm not sure which one is correct. Could you please help find out which data type Qwen3.5 actually uses here? Thanks!
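A minimal local analogue of the behavior described above, using a plain `nn.LayerNorm` as a stand-in for the actual HF Qwen3-Next modules (which the original comment's elided example presumably used):

```python
import torch
import torch.nn as nn

# Sketch: even when a checkpoint stores a tensor in FP32,
# load_state_dict copies it into the already-constructed parameter,
# casting to that parameter's dtype. A module constructed in BF16
# therefore keeps BF16 weights after loading FP32 checkpoint tensors.
norm = nn.LayerNorm(4).to(torch.bfloat16)  # constructed/maintained in bf16
fp32_state = {
    "weight": torch.ones(4, dtype=torch.float32),
    "bias": torch.zeros(4, dtype=torch.float32),
}
norm.load_state_dict(fp32_state)
print(norm.weight.dtype)  # torch.bfloat16
```

This is why the checkpoint dtype alone does not settle the question: the dtype used at runtime depends on how the module is constructed, not on how the tensor is stored.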
Thanks @yuzhongw-nvidia, marking this as draft for now.
What does this PR do?
Summary
Qwen3.5 (the successor to Qwen3-Next) stores the GDN `A_log` and `out_norm` (output layernorm) parameters in FP32 in its released checkpoints, while the rest of the model is BF16. The current GatedDeltaNet implementation allocates these in `config.params_dtype`, causing dtype mismatches during checkpoint conversion. This PR aligns the Megatron implementation with the upstream precision convention.

- Allocate the `A_log` parameter in FP32 instead of `config.params_dtype`
- Cast `out_norm` (output layernorm) to FP32 after construction
- Perform `A_log` initialization in FP32 in `reset_parameters`
- No change to `dt_bias` (remains at `config.params_dtype`)

Motivation
The Qwen3.5 family (both the 27B dense and three MoE variants) releases `A_log` and `linear_attn.norm.weight` in FP32, even though the rest of the checkpoint is BF16. This is a deliberate precision choice by the model authors for numerical stability.

The GDN forward pass already computes the decay gate `g` in FP32. However, the `A_log` parameter and `out_norm` weights were previously allocated in `config.params_dtype` (typically BF16). This creates two problems:

- Precision loss in `A_log`: `A_log` stores `log(A)` where `A` is a per-head decay rate. The exponential `A_log.exp()` amplifies quantization error from BF16 storage, even though the compute is done in FP32.
- Precision loss in `out_norm`: The output RMSNorm applied after the recurrence operates on per-head-dim scale factors. Storing these in BF16 introduces unnecessary rounding.

Both parameters are small (per-head or per-head-dim vectors), so the FP32 memory overhead is negligible.
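To make the `A_log` point concrete, here is a small illustration with assumed decay values (the real per-head rates are learned, not these):

```python
import torch

# Hypothetical per-head decay rates A in (0, 1), as a GDN layer might learn.
A = torch.linspace(0.90, 0.999, steps=8, dtype=torch.float32)
A_log = A.log()

# Same FP32 compute path in both cases; the only difference is whether
# A_log was *stored* in BF16 (the situation before this PR) or in FP32.
A_from_bf16 = A_log.to(torch.bfloat16).float().exp()
A_from_fp32 = A_log.exp()

rel_err = ((A_from_bf16 - A_from_fp32).abs() / A_from_fp32).max()
print(f"max relative error from BF16 storage: {rel_err:.2e}")
```

The per-step error is small, but decay rates are applied multiplicatively at every step of the recurrence, so any bias in `A` compounds over long sequences.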
Changes
`megatron/core/ssm/gated_delta_net.py`:

- `__init__`: `A_log` allocated with `dtype=torch.float32` (was `config.params_dtype`)
- `__init__`: `out_norm` cast to FP32 via `self.out_norm.to(torch.float32)` after `build_module`
- `reset_parameters`: `A_log` initialization uses `dtype=torch.float32` (was `config.params_dtype`)

Precedent
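For concreteness, a hedged sketch of the resulting allocation pattern (illustrative names and shapes, with `nn.LayerNorm` standing in for the RMSNorm that `build_module` constructs; not the exact Megatron source):

```python
import torch
import torch.nn as nn

class GDNSketch(nn.Module):
    """Illustrative stand-in for the GatedDeltaNet allocation changes."""

    def __init__(self, num_heads: int, head_dim: int,
                 params_dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        # A_log is unconditionally FP32 (previously dtype=params_dtype).
        self.A_log = nn.Parameter(torch.empty(num_heads, dtype=torch.float32))
        # out_norm built normally, then cast to FP32 after construction.
        self.out_norm = nn.LayerNorm(head_dim, dtype=params_dtype).to(torch.float32)
        # dt_bias intentionally stays in params_dtype.
        self.dt_bias = nn.Parameter(torch.zeros(num_heads, dtype=params_dtype))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        with torch.no_grad():
            # Initialization also happens in FP32.
            A = torch.empty_like(self.A_log).uniform_(1.0, 16.0)
            self.A_log.copy_(A.log())
```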
This follows the existing convention in `MambaMixer`, where `A_log` is already unconditionally stored in FP32.

Alternative designs considered
Option A: Configurable via TransformerConfig flags
Add two boolean fields (`linear_A_log_in_fp32`, `linear_out_norm_in_fp32`) to `TransformerConfig`, defaulting to `True`, and have `GatedDeltaNet` read them.

- Set to `False` for bitwise-identical checkpoint roundtrips.
- `TransformerConfig` already has ~197 fields. Adding per-parameter precision knobs for tiny vectors (one scalar per head) is hard to justify given the negligible memory difference. Every GDN-based model bridge must also remember to set these flags correctly.

Why we chose hardcoded FP32 instead:
`MambaMixer` already hardcodes `A_log` to FP32; this aligns the two SSM-family modules.

Contribution process
flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks
Core 0.8)Code review
The following process is enforced via the CODEOWNERS file for changes into `megatron/core`. For changes outside of `megatron/core`, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch
Feel free to message or comment @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
(Step 1): Add PR label `Expert Review`

(Step 2): Collect the expert reviewers' reviews

- Add the `Expert Review` label when your PR is ready for review.
- Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

- Add the `Final Review` label.

(Optional Step 4): Cherry-pick into release branch
If this PR also needs to be merged into `core_r*` release branches, after this PR has been merged, select `Cherry-pick` to open a new PR into the release branch.

For MRs into `dev` branch
The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR
Any member of core-adlr and core-nemo will be able to merge your PR.