Fix Megatron-FSDP optimizer CPU offload and checkpointing#4623
Conversation
|
/ok to test |
@wplf, there was an error processing your request: See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/1/ |
|
/ok to test 5da7d67 |
Handle Megatron-FSDP DTensor parameters and gradients by operating on local shards before CPU optimizer offload copies. This avoids dispatching pin_memory/is_pinned through DTensor and lets pin_cpu_params control CPU parameter pinning.
eff7bfb to
cdaac89
Compare
cdaac89 to
2f2442b
Compare
2f2442b to
21ba0b9
Compare
cspades
left a comment
There was a problem hiding this comment.
Some important nits, basically if we can move most of this code into MFSDP source it would be much cleaner and better.
INFO: Could you also add more commentary explaining the lifecycle of the offloaded DTensors here? Just want to make sure if this is the right way to implement this.
| if isinstance(self.optimizer, HybridDeviceOptimizer): | ||
| packed_state = self._pack_hybrid_optimizer_fsdp_state_dict() | ||
| else: | ||
| packed_state = { | ||
| (self._param_name(k) if isinstance(k, torch.Tensor) else k): v | ||
| for k, v in self.state.items() | ||
| } |
There was a problem hiding this comment.
This optimizer isn't just used for FSDP right? We should check that the parameters are FSDP parameters before doing this.
| def _pack_hybrid_optimizer_fsdp_state( | ||
| self, param: torch.nn.Parameter, param_name: str, state: dict[str, Any] | ||
| ) -> dict[str, Any]: | ||
| """Convert HybridDeviceOptimizer FSDP-local tensor state to DTensors.""" |
There was a problem hiding this comment.
Why not just use preprocess_state_dict_for_uneven_dtensor or update_uneven_dtensor_chunk_metadata instead of re-implementing it here? (Also, we should have all of these types of utilities inside the Megatron-FSDP source directory unless it absolutely needs to be in the training loop / optimizer code!)
| local_tensor = data.to_local() | ||
| else: | ||
| assert data.numel() == dist_param.numel(), ( | ||
| f"DTensor shape mismatch: data.shape={data.shape} vs " | ||
| f"dist_param.shape={global_shape}" | ||
| ) | ||
| local_tensor = data.to_local().view(-1) |
There was a problem hiding this comment.
These can also be DTensor._local_tensor BTW! Plus more instances of this in this same file.
| value = make_fsdp_dtensor( | ||
| value.data.view(-1), | ||
| flat_param, | ||
| dist_index=param.megatron_fsdp_dist_index, | ||
| is_expert_param=is_expert_param, | ||
| run_check=False, | ||
| update_uneven_dtensor_chunk_meta=False, | ||
| ) | ||
| self._set_flat_fsdp_dtensor_chunk_metadata( | ||
| value, param.megatron_fsdp_slice | ||
| ) |
There was a problem hiding this comment.
So TL;DR I assume the packed_state is our offloaded state, and we want them to be DTensors as well?
| def _set_flat_fsdp_dtensor_chunk_metadata(tensor: DTensor, fsdp_slice: slice) -> None: | ||
| """Attach DCP chunk metadata for a flat FSDP-local shard without collectives.""" |
There was a problem hiding this comment.
We can just re-factor the existing function in uneven_dtensor.py to skip the metadata AG update! Less code outside MFSDP the better!
Summary
HybridDeviceOptimizerstate forfsdp_dtensorcheckpoints as deterministic FSDP DTensors, skipping duplicatemaster_paramentries and attaching local DCP chunk metadata without metadata collectives.Fixes #4910.
Tests
uv run isort megatron/core/optimizer/distrib_optimizer.pypython -m py_compile megatron/core/optimizer/distrib_optimizer.py megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py megatron/core/transformer/fsdp_dtensor_checkpoint.pyPYTHONDONTWRITEBYTECODE=1 PYTHONPATH=. uv run python -m pytest tests/unit_tests/test_optimizer_cpu_offloading.py -q(72 passed)optimizer_cpu_offload=True,optimizer_offload_fraction=1,overlap_cpu_optimizer_d2h_h2d=True,use_precision_aware_optimizer=True, andfsdp_dtensoroptimizer checkpointing: checkpoint saves at iterations 10 and 12 completed successfully.