Experimental support for fairscale ShardedDDP #9139
Conversation
Re: your notes on GPU memory consumption improvements: in my experience, checking GPU allocation often doesn't show the real difference, since PyTorch tends to use more memory than it strictly needs when there is spare memory, and can get by with less when memory is tight. To get better improvement stats, it's best to push the batch size up until it OOMs instead; that gives a more precise difference than just comparing allocated memory. All I'm saying is that the improvements are probably even better than they seem.
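(For illustration only, not from the PR: a rough sketch of that measurement approach. The `run_step` helper and the size limits are hypothetical placeholders.)

```python
import torch

def find_max_batch_size(run_step, start=8, limit=4096):
    """Probe the largest batch size that fits before CUDA OOM.

    `run_step(bs)` is assumed to run one forward/backward pass with
    batch size `bs` (hypothetical helper, not part of this PR).
    """
    bs, max_ok = start, None
    while bs <= limit:
        try:
            run_step(bs)
            max_ok = bs   # this batch size still fits
            bs *= 2       # try a bigger one
        except RuntimeError as e:
            # PyTorch raises RuntimeError with "out of memory" on CUDA OOM
            if "out of memory" in str(e):
                torch.cuda.empty_cache()
                break
            raise
    return max_ok
```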
finetune_trainer crashes with this option: could probably extend
Oh, it's just because it overrides the
OK, next we have this: coincidentally I have just had the same issue with the deepspeed integration when I enable its internal fp16 handling. I didn't get to the root of it yet, but removing [...]. Note: I'm switching to deepspeed fp16 handling there...
Is it FP16 with AMP or with Apex? I don't believe fairscale is compatible with Apex.
Native AMP. See the command line I'm testing with at:
```
    other choices will force the requested backend.
sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
    Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
    training only). This is an experimental feature.
```
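(Illustrative usage, not part of the diff: with the new flag, enabling sharded DDP from Python should look roughly like the sketch below when the script is started through a distributed launcher; the output directory and batch size are placeholders.)

```python
from transformers import TrainingArguments

# Sketch only: `output_dir` and the batch size are placeholder values.
training_args = TrainingArguments(
    output_dir="output",
    per_device_train_batch_size=8,
    sharded_ddp=True,  # the new experimental flag added by this PR
)
```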
|
|
|
|
```python
if is_fairscale_available():
    from fairscale.optim import OSS
```
OSS is a bit cryptic to me, but I think it's still better to use the "real" name rather than doing `import OSS as OptimizerStateSharding`, so this is good for me!
Yeah, I'm using the same convention they do, so as not to surprise any user.
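(For readers unfamiliar with the FairScale API, here is a minimal sketch of the pattern from the FairScale main example that this PR follows; the model, learning rate, and device handling are placeholders, and the actual Trainer integration may differ.)

```python
import torch
from fairscale.optim import OSS
from fairscale.nn.data_parallel import ShardedDataParallel

# Assumes torch.distributed has already been initialized by the launcher
# and that the model has been moved to the right device.
model = torch.nn.Linear(10, 10).cuda()

# OSS wraps a regular torch optimizer and shards its state across ranks.
optimizer = OSS(params=model.parameters(), optim=torch.optim.AdamW, lr=3e-5)

# ShardedDataParallel takes the place of DistributedDataParallel and reduces
# each gradient to the rank that owns the matching optimizer shard.
model = ShardedDataParallel(model, optimizer)
```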
Hey there, a bit late, but I'm one of the fairscale/ShardedDDP authors. The issue with the Apex (and vanilla Torch) grad scaler is that it does not know about the gradient sharding, so not all the ranks will have the same behaviour. Torch AMP is supported though; you just have to pass in the ShardedGradScaler as defined here: https://github.com/facebookresearch/fairscale/blob/master/fairscale/optim/grad_scaler.py#L24
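(A minimal sketch of what that looks like in a native AMP training loop, assuming `model` and `optimizer` are set up as in the sketch above; the dummy inputs and step count are only for illustration.)

```python
import torch
from fairscale.optim.grad_scaler import ShardedGradScaler

# Drop-in replacement for torch.cuda.amp.GradScaler that is aware of the
# gradient sharding done by OSS / ShardedDataParallel.
scaler = ShardedGradScaler()

for _ in range(10):  # a few dummy steps, just for illustration
    inputs = torch.randn(8, 10, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(inputs).sum()  # placeholder loss
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```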
Yes, we're passing that scaler :-) The issue was with AMP, not Apex. It looks like there is a problem with one of the models, with or without FP16.
What does this PR do?
This PR adds support for FairScale's sharded DDP training to save GPU memory when training models in a distributed setting. Initial tests indeed show a nice reduction in GPU memory used!
This follows the steps of the main example provided on the FairScale repo, integrating them into our Trainer API. To activate training with sharded DDP, one must pass along the flag `--sharded_ddp` in a distributed launch command.

Benchmarks tried:
- `bert_base_uncased`: goes from 5GB per GPU to 4GB per GPU, with no loss of accuracy.
- `xlnet_large-cased`: goes from 11.5GB per GPU to 8GB per GPU (didn't run to the end, so didn't check whether the final accuracy was the same; the training loss seemed equivalent).