
[fix] Add detach() to fp32 param shard for leaf-tensor consistency #23

Merged
guapisolo merged 1 commit into miles-main from fix/fp32 on Apr 14, 2026

Conversation


guapisolo commented on Apr 14, 2026

Context

The float16/bf16 branch (line 372) already uses detach().

shard_model_param = model_param.detach().view(-1)[
    param_range.start : param_range.end
]

This change makes the fp32 branch consistent. The fp32 branch becomes relevant when parameters are intentionally kept in fp32 (e.g., Qwen3.5's A_log via enforce_marked_param_dtypes).
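
For reference, a sketch of how the fp32 branch looks with this change applied, assuming it simply mirrors the bf16 branch above (the exact surrounding code lives in the distributed optimizer's param-shard setup):

# fp32 params.
elif model_param.type() == 'torch.cuda.FloatTensor':
    # Keep shard tensors as leaf tensors for torch Optimizer.
    shard_model_param = model_param.detach().view(-1)[
        param_range.start : param_range.end
    ]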

Why it won't break anything

detach() only disconnects the shard from the autograd graph (making it a leaf tensor). It does not copy data — the shard still shares the same underlying storage as model_param. Every downstream consumer of shard_fp32_groups uses manual data operations, never autograd:
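
As a quick standalone illustration of both properties (plain PyTorch, not the Megatron code; the tensor names are made up):

import torch

model_param = torch.randn(8, requires_grad=True)    # stands in for a model parameter
shard = model_param.detach().view(-1)[2:6]          # same pattern as the bf16 branch

assert shard.is_leaf                                # no grad_fn, valid as an optimizer param
shard.fill_(0.0)                                    # in-place write through the shard...
assert model_param.detach()[2:6].abs().sum() == 0   # ...lands in model_param's storage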

Gradient flow (_copy_model_grads_to_main_grads): Reads from model_param.main_grad, slices it, and assigns to shard.grad via direct attribute assignment — autograd is not involved.

Parameter writeback (copy_main_params_to_model_params): Re-slices from the param buffer and calls shard_model_param.data.copy_() — autograd is not involved.

Checkpoint load (copy_model_params_to_main_params): Re-slices from model_param.view(-1)[...] and calls shard_main_param.data.copy_() — autograd is not involved.

zero_grad: Clears .grad attribute or sets it to None — autograd is not involved.

Optimizer step: PyTorch optimizer operates on shard.data and shard.grad directly — in fact, it prefers leaf tensors and may warn or error on non-leaf params.

In short: Megatron's DistributedOptimizer completely bypasses PyTorch autograd for all gradient and parameter movement. detach() is a no-op in terms of data and behavior — it only satisfies PyTorch's expectation that optimizer params are leaf tensors.
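
A simplified, self-contained sketch of that pattern (hypothetical tensors and values; the steps only loosely follow the DistributedOptimizer methods listed above):

import torch

# Hypothetical fp32 setup: a model param plus a Megatron-style main_grad buffer attribute.
model_param = torch.randn(8)
model_param.main_grad = torch.randn(8)
start, end = 2, 6

shard_model_param = model_param.detach().view(-1)[start:end]  # leaf view into the model buffer
shard_main_param = shard_model_param.clone()                   # optimizer-owned main copy

# Grad flow: slice main_grad and assign it via the .grad attribute (no autograd involved).
shard_main_param.grad = model_param.main_grad.view(-1)[start:end].clone()

# Optimizer step: operates purely on .data/.grad of leaf tensors.
torch.optim.SGD([shard_main_param], lr=0.1).step()

# Parameter writeback: raw .data copy back into the model buffer (no autograd involved).
shard_model_param.data.copy_(shard_main_param)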

# fp32 params.
elif model_param.type() == 'torch.cuda.FloatTensor':
    shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
    # Keep shard tensors as leaf tensors for torch Optimizer.

Why do this?

guapisolo (Author) replied on Apr 14, 2026


> Why do this?

Attached the reasoning to the PR description. BF16 already has this fix, but fp32 did not.

guapisolo changed the title from "fix small issue when pytorch has bug as optimizer param" to "[fix] fp32 gradient flow warning" on Apr 14, 2026
guapisolo changed the title from "[fix] fp32 gradient flow warning" to "[fix] Add detach() to fp32 param shard for leaf-tensor consistency" on Apr 14, 2026
guapisolo merged commit 32dbe9f into miles-main on Apr 14, 2026
3 checks passed