[fix] Add detach() to fp32 param shard for leaf-tensor consistency #23

Merged: guapisolo merged 1 commit into miles-main (Apr 14, 2026)
Conversation
# fp32 params.
elif model_param.type() == 'torch.cuda.FloatTensor':
    shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
    # Keep shard tensors as leaf tensors for torch Optimizer.
Why do this?

Author: The reason is attached to the PR description. BF16 was already fixed, but fp32 was not.
yueming-yuan approved these changes on Apr 14, 2026
Context
The float16/bf16 branch (line 372) already uses detach().
This change makes the fp32 branch consistent. The fp32 branch becomes relevant when parameters are intentionally kept in fp32 (e.g., Qwen3.5's A_log via enforce_marked_param_dtypes).
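To make the leaf-tensor requirement concrete, here is a minimal standalone sketch in plain PyTorch (not Megatron code; the shapes and the SGD usage are illustrative assumptions): a sliced view of a Parameter is rejected by torch.optim unless it is detached into a leaf tensor.

```python
import torch

# Stand-in for a flat fp32 model param and a shard slice of it (sizes are made up).
model_param = torch.nn.Parameter(torch.zeros(8))

# Without detach(): the slice is a non-leaf view, which torch.optim refuses.
non_leaf_shard = model_param.view(-1)[0:4]
try:
    torch.optim.SGD([non_leaf_shard], lr=0.1)
except ValueError as e:
    print(e)  # expected: "can't optimize a non-leaf Tensor"

# With detach(): the slice is a leaf tensor and is accepted, mirroring what the
# float16/bf16 branch already does for its shards.
leaf_shard = model_param.detach().view(-1)[0:4]
torch.optim.SGD([leaf_shard], lr=0.1)
```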
Why it won't break anything
detach() only disconnects the shard from the autograd graph (making it a leaf tensor). It does not copy data: the shard still shares the same underlying storage as model_param. Every downstream consumer of shard_fp32_groups uses manual data operations, never autograd (a short demonstration follows the summary below):
- Gradient flow (_copy_model_grads_to_main_grads): reads from model_param.main_grad, slices it, and assigns the slice to shard.grad via direct attribute assignment; autograd is not involved.
- Parameter writeback (copy_main_params_to_model_params): re-slices from the param buffer and calls shard_model_param.data.copy_(); autograd is not involved.
- Checkpoint load (copy_model_params_to_main_params): re-slices from model_param.view(-1)[...] and calls shard_main_param.data.copy_(); autograd is not involved.
- zero_grad: clears the .grad attribute or sets it to None; autograd is not involved.
- Optimizer step: the PyTorch optimizer operates on shard.data and shard.grad directly; in fact, it prefers leaf tensors and may warn or error on non-leaf params.
In short: Megatron's DistributedOptimizer completely bypasses PyTorch autograd for all gradient and parameter movement. detach() is a no-op in terms of data and behavior — it only satisfies PyTorch's expectation that optimizer params are leaf tensors.
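As a minimal sketch of these claims in plain PyTorch (not Megatron code; the sizes, values, and the SGD usage are illustrative assumptions):

```python
import torch

model_param = torch.nn.Parameter(torch.zeros(8))

# detach() disconnects from autograd but does not copy: the shard is still a
# view into the same storage as model_param.
shard = model_param.detach().view(-1)[0:4]
assert shard.data_ptr() == model_param.data_ptr()

# Manual data writeback (the copy_* paths above) is visible through model_param.
shard.data.copy_(torch.ones(4))
assert model_param[0].item() == 1.0

# Manual .grad assignment plus an optimizer step (the grad-flow and step paths
# above) work without autograd ever being involved.
shard.grad = torch.full((4,), 0.5)
torch.optim.SGD([shard], lr=1.0).step()
assert model_param[0].item() == 0.5  # 1.0 - 1.0 * 0.5
```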