fix(megatron-fsdp): compute SWiGLU/GDN split in item coordinates for non-DTensor optimizer states #4423
xuwchen wants to merge 2 commits into
assert data.shape[swiglu_shard_axis] % 2 == 0, (
    f"SWiGLU weights must have an even size along the shard axis {swiglu_shard_axis}, "
    f"got {data.shape[swiglu_shard_axis]}"
# Use dist_param (always a DTensor) for global shape/numel,
# as data may be a regular Tensor (e.g., optimizer states).
global_shape = dist_param.shape
if isinstance(data, DTensor):
    assert data.shape == global_shape, (
        f"DTensor shape mismatch: data.shape={data.shape} vs "
        f"dist_param.shape={global_shape}"
    )
assert global_shape[swiglu_shard_axis] % 2 == 0, (
    f"SWiGLU FC1 must have even global size along axis {swiglu_shard_axis}, "
    f"got {global_shape[swiglu_shard_axis]} (global_shape={list(global_shape)})"
)
So basically, if data is a DTensor, we should ensure it matches the state DTensor. And then, we always use the global shape. (I've hit this error before when FusedAdam was broken, which I fixed here: NVIDIA/TransformerEngine#2795)
When would data not be a DTensor in the context of this function? 👀 Is it a bug or just so this function can be used outside of MLM / MBridge checkpointing?
Good point! Your fix in NVIDIA/TransformerEngine#2795 would indeed resolve this issue by making FusedAdam optimizer states (exp_avg, exp_avg_sq) DTensors, so data would always be a DTensor and the original code would work correctly.
The case we've hit so far is specifically FusedAdam producing plain Tensor optimizer states. I'm not 100% sure whether other code paths could also pass plain Tensors to this function, but since split_swiglu_linear_fc1 doesn't enforce any type constraint on data, I think it's safer to keep this fix on the MCore side as defensive programming. The DTensor case is the "easy" case: we just assert shape consistency there. The real fix ensures that even if data is a plain Tensor, the W/V split boundaries are computed in the correct (global) coordinate system.
I've attached a visualization based on a DeepSeek proxy model experiment to help illustrate this. Hope this clarifies the motivation behind the fix.
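The coordinate mismatch can also be sketched with plain integer ranges. This is a simplified illustration, not the actual Megatron-FSDP code: it ignores tensor parallelism, assumes the flattened buffer stores W in the first half and V in the second, and the helpers `intersect` and `swiglu_split` are hypothetical names introduced only for this example.

```python
def intersect(a, b):
    """Intersect two half-open ranges (start, stop); return None if empty."""
    start, stop = max(a[0], b[0]), min(a[1], b[1])
    return (start, stop) if start < stop else None


def swiglu_split(total_numel, fsdp_slice):
    """Split a flattened [W; V] buffer of total_numel items into W and V
    ranges, then clip each range to the items owned by this rank."""
    half = total_numel // 2
    w_slice = (0, half)
    v_slice = (half, total_numel)
    return intersect(w_slice, fsdp_slice), intersect(v_slice, fsdp_slice)


global_numel = 16  # stands in for dist_param.numel() (global count)

# Last rank owns items 12..15, which all belong to V.
last_slice = (12, 16)
last_local_numel = last_slice[1] - last_slice[0]  # what a plain Tensor reports
correct_last = swiglu_split(global_numel, last_slice)
buggy_last = swiglu_split(last_local_numel, last_slice)

# First rank owns items 0..3, which all belong to W.
first_slice = (0, 4)
first_local_numel = first_slice[1] - first_slice[0]
correct_first = swiglu_split(global_numel, first_slice)
buggy_first = swiglu_split(first_local_numel, first_slice)

print(correct_last)   # (None, (12, 16)): the rank holds only V items
print(buggy_last)     # (None, None): the V items are lost
print(correct_first)  # ((0, 4), None): the rank holds only W items
print(buggy_first)    # ((0, 2), (2, 4)): half of W is mislabeled as V
```

With the global element count, the W/V boundary sits at item 8 and each rank's range is classified correctly. With the local count, the boundary is placed at item 2, so the result is either empty or assigns items to the wrong component.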
/ok to test e6aa004
What does this PR do ?
`split_swiglu_linear_fc1` and `split_gdn_fused` originally used `data.numel() // tp_mesh.mesh.numel()` to compute component slices (`w_slice`/`v_slice` for SWiGLU, `comp_slice` for GDN), then intersected them with `fsdp_slice` to determine each rank's portion. When `data` is a DTensor, `data.numel()` returns the global element count (= `param.numel()`), so the component slices and `fsdp_slice` are both in item coordinates and the intersection is valid. However, when `data` is a plain Tensor (optimizer states created by FusedAdam), `data.numel()` returns the local shard size on this rank, so the component slices cover a range far smaller than the item coordinate space and the intersection mixes incompatible coordinate systems.
The fix replaces `data.numel()` with `dist_param.numel()` (`dist_param` is always a DTensor and therefore reflects the global element count), ensuring `fsdp_slice` and all component slices are in item coordinates.
Contribution process
Pre-checks
Code review
Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.
Step 1: Mark PR as "Ready for Review"
`.github/CODEOWNERS`.
Final Review might get declined if these requirements are not fulfilled.
Step 2: Final Review
For PRs that change `megatron/core`, once all expert reviewers have approved, the `Final Review` label is applied automatically and final reviewers are assigned.
For PRs outside `megatron/core`, this step is skipped.
Step 3: Approved
Once all required reviewers have approved, the `Approved` label is applied automatically.
Merge
Any member of mcore-engineers will be able to merge your PR.
For MRs into `dev` branch
The proposed review process for `dev` branch is under active discussion.
MRs are mergeable after one approval by either `eharper@nvidia.com` or `zijiey@nvidia.com`.