Is your feature request related to a problem? Please describe.
The current implementation of LayerWiseOptimizer shards by param after sorting the params by numel. In typical settings, where the embedding and output layers are far larger than any other param, this means that the last two ranks each store the entirety of the embedding or output layer, on top of their previously assigned layers. In many cases, especially with high PP, the memory overhead will be so large that it negates most of the benefit of distributing the optimizer states, because two ranks have to store a large fraction of the optimizer states for their PP group.
Describe the solution you'd like
The ideal solution would be to have Muon params sharded by param, as is currently the case, but keep Adam params distributed "the old way", e.g., dim-0 sharding. Though likely optimal, I understand that the current structure of Megatron optimizers makes this very difficult. Nonetheless, I think it is essential to find a way to shard the gigantic embedding/output states if we want to truly benefit from optimizer state distribution, especially since they are not affected by the Muon computation, so there is no real reason to keep each matrix whole on a single rank.
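To make the "dim-0 sharding" idea concrete, here is a minimal sketch of how the rows of an Adam-updated matrix (embedding or output) could be split into contiguous, near-equal ranges, so that each rank only holds the Adam states for its own rows. The function name `dim0_shard_ranges` and the toy sizes are hypothetical; this is not Megatron code and makes no claim about how the existing distributed optimizer lays out its buffers.

```python
def dim0_shard_ranges(num_rows, world_size):
    """Split rows [0, num_rows) into world_size contiguous, near-equal ranges.

    Hypothetical helper for illustration only: rank r would keep the
    optimizer states for rows ranges[r][0] .. ranges[r][1] - 1.
    """
    base, extra = divmod(num_rows, world_size)
    ranges = []
    start = 0
    for rank in range(world_size):
        # The first `extra` ranks take one additional row each.
        size = base + (1 if rank < extra else 0)
        ranges.append((start, start + size))
        start += size
    return ranges


# Toy example: a 10-row matrix split across 4 ranks.
print(dim0_shard_ranges(10, 4))  # [(0, 3), (3, 6), (6, 8), (8, 10)]
```

With a split like this, the state memory for the embedding/output layers would shrink roughly by the number of ranks, instead of landing on a single rank.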
Describe alternatives you've considered
Perhaps a more accessible short-term improvement would be to at least process params in decreasing order of numel and, instead of assigning them to ranks in order, always assign the next param to the rank with the least total numel currently assigned. That way, the ranks storing the embedding and output layers would at least not be storing any other param on top of them. However, this would cause issues of its own: some ranks would end up with many small params and others with a few large ones, and since Muon's computation cost is not proportional to the number of parameters, the compute load would be imbalanced. I think there is no 'easy fix' for this, and some form of heterogeneous sharding will be necessary at some point.
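For reference, a minimal sketch of the greedy assignment described above (largest param first, always to the rank with the smallest total numel so far). The function name `assign_least_loaded` and the toy param sizes are hypothetical and only meant to illustrate the idea; this is not how LayerWiseOptimizer is written.

```python
import heapq


def assign_least_loaded(numels, world_size):
    """Assign params to ranks, largest first, each to the lightest rank.

    `numels` maps a param name to its number of elements.
    Returns a dict mapping rank -> list of param names.
    """
    # Min-heap of (total numel assigned so far, rank).
    heap = [(0, rank) for rank in range(world_size)]
    heapq.heapify(heap)
    assignment = {rank: [] for rank in range(world_size)}
    # Process params in decreasing order of numel.
    for name, numel in sorted(numels.items(), key=lambda kv: kv[1], reverse=True):
        load, rank = heapq.heappop(heap)
        assignment[rank].append(name)
        heapq.heappush(heap, (load + numel, rank))
    return assignment


# Toy sizes (made up): two huge Adam-updated matrices and six small layers.
toy = {"embedding": 1000, "output": 1000}
toy.update({f"layer_{i}": 100 for i in range(6)})

for rank, names in assign_least_loaded(toy, world_size=4).items():
    print(rank, names)
```

On this toy input, ranks 0 and 1 each hold only one of the two large matrices, while ranks 2 and 3 hold three small layers each. Memory is much better balanced, but the number of params per rank (and therefore the Muon work) is not, which is the imbalance mentioned above.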