
Experimental support for fairscale ShardedDDP#9139

Merged
sgugger merged 5 commits into master from sharded_ddp
Dec 16, 2020
Conversation

@sgugger
Collaborator

@sgugger sgugger commented Dec 15, 2020

What does this PR do?

This PR adds support for FairScale's sharded DDP training to save GPU memory when training models in a distributed setting. Initial tests indeed show a nice reduction in GPU memory usage!

This follows the steps of the main example provided in the FairScale repo, integrating them into our Trainer API. To activate training with sharded DDP, pass the --sharded_ddp flag in a distributed launch command.

Benchmarks tried:

  • fine-tuning on MRPC with bert-base-uncased -> goes from 5GB per GPU to 4GB per GPU, with no hit to accuracy
  • fine-tuning on SQuAD v2 with xlnet-large-cased -> goes from 11.5GB per GPU to 8GB per GPU (didn't run to the end, so I didn't check whether accuracy was the same; training loss seemed equivalent)

@stas00
Contributor

stas00 commented Dec 15, 2020

Re: your notes on GPU memory consumption improvements: from what I have seen, checking GPU allocation often doesn't show the real difference, since PyTorch tends to use more memory than it strictly needs when there is spare memory, and can get by with less when memory is tight. To get the best improvement stats, it's better to push the batch size until it OOMs; comparing the maximum batch sizes usually gives more precise improvement numbers than just comparing memory allocation. This is just my experience.

All I'm saying is that the improvements are probably even better than they seem.
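The probing procedure described above can be sketched in plain Python. This is a hypothetical helper, not part of the PR; `train_step` is a stand-in for a real training step that raises a CUDA OOM `RuntimeError` on a full GPU.

```python
def find_max_batch_size(train_step, start=1, limit=4096):
    """Double the batch size until it OOMs, then binary-search the boundary."""
    def fits(bs):
        try:
            train_step(bs)
            return True
        except RuntimeError as e:
            if "out of memory" in str(e):
                return False
            raise  # re-raise anything that isn't an OOM

    if not fits(start):
        return 0
    lo = start                       # largest size known to fit
    hi = start * 2
    while hi <= limit and fits(hi):  # exponential growth phase
        lo, hi = hi, hi * 2
    hi = min(hi, limit + 1)
    while lo + 1 < hi:               # binary-search the OOM boundary
        mid = (lo + hi) // 2
        if fits(mid):
            lo = mid
        else:
            hi = mid
    return lo
```

Running this once per setup (with and without --sharded_ddp) and comparing the two maxima gives the kind of number stas00 is suggesting.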

Comment thread src/transformers/trainer.py Outdated
Member

@LysandreJik LysandreJik left a comment


That's cool, very clean!

@stas00 stas00 changed the title Experimental stupport for fairscale ShardedDDP Experimental support for fairscale ShardedDDP Dec 15, 2020
@stas00
Contributor

stas00 commented Dec 15, 2020

finetune_trainer crashes with this option:

export BS=4; rm -r output_dir; CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../src USE_TF=0   python -m torch.distributed.launch --nproc_per_node=2  ./finetune_trainer.py --model_name_or_path sshleifer/distill-mbart-en-ro-12-4 --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_train --fp16 --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_train_batch_size $BS --sortish_sampler --src_lang en_XX --task translation --tgt_lang ro_RO --val_max_target_length 128 --warmup_steps 500 --n_train 500 --sharded_ddp
Traceback (most recent call last):
  File "./finetune_trainer.py", line 379, in <module>
    main()
  File "./finetune_trainer.py", line 315, in main
    trainer.train(
  File "/mnt/nvme1/code/huggingface/transformers-master/src/transformers/trainer.py", line 677, in train
    model = ShardedDDP(model, self.optimizer)
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/fairscale/nn/data_parallel/sharded_ddp.py", line 96, in __init__
    self._param_iterator = chain(*[optim.should_bucket_param.keys() for optim in self.sharded_optimizers])
TypeError: 'AdamW' object is not iterable

We could probably extend test_finetune_trainer.py to exercise this option when fairscale is available? But the CIs won't have it, and it's quite slow to build.

@sgugger
Collaborator Author

sgugger commented Dec 15, 2020

Oh, it's just because finetune_trainer overrides the create_optimizer_and_scheduler method. Will fix that method.
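A toy reconstruction of the failure mode, using dummy classes (this is not the real fairscale or transformers code): ShardedDDP expects its optimizer argument to be an OSS instance (or an iterable of them), so a subclass that overrides optimizer creation and hands back a plain AdamW gets iterated, which produces exactly the "TypeError: 'AdamW' object is not iterable" in the traceback above.

```python
class OSSLike:
    """Stand-in for fairscale.optim.OSS: carries sharding metadata."""
    def __init__(self, params):
        self.should_bucket_param = {p: True for p in params}

class PlainAdamW:
    """Stand-in for torch.optim.AdamW: no sharding metadata, not iterable."""

class ShardedDDPLike:
    """Stand-in for fairscale's ShardedDDP constructor logic."""
    def __init__(self, model, optimizer):
        if isinstance(optimizer, OSSLike):
            self.sharded_optimizers = [optimizer]
        else:
            # A non-OSS optimizer falls through here and gets iterated,
            # which raises TypeError for a plain optimizer object.
            self.sharded_optimizers = list(optimizer)

ShardedDDPLike(object(), OSSLike(["weight", "bias"]))  # fine
# ShardedDDPLike(object(), PlainAdamW())               # raises TypeError
```

The fix is to make sure the OSS wrapping happens even when create_optimizer_and_scheduler is overridden.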

@stas00
Contributor

stas00 commented Dec 15, 2020

OK, next we have this:

Traceback (most recent call last):
  File "./finetune_trainer.py", line 379, in <module>
    main()
  File "./finetune_trainer.py", line 315, in main
    trainer.train(
  File "/mnt/nvme1/code/huggingface/transformers-master/src/transformers/trainer.py", line 818, in train
    self.scaler.step(self.optimizer)
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 330, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.

Coincidentally, I have just hit the same issue with the deepspeed integration when I enabled its internal fp16 handling. I haven't gotten to the root of it yet, but removing the --fp16 arg, and thus disabling all the fp16 handling the trainer does, removed this error.

Note: I'm switching to deepspeed's fp16 handling there...

@sgugger
Collaborator Author

sgugger commented Dec 15, 2020

Is it FP16 with AMP or with apex? I don't believe fairscale is compatible with apex.

@stas00
Contributor

stas00 commented Dec 15, 2020

native amp

See the command line I'm testing with at:
#9139 (comment)

@stas00
Contributor

stas00 commented Dec 16, 2020

If you're joining in and discover you can't build fairscale, please see this and perhaps that.

other choices will force the requested backend.
sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
training only). This is an experimental feature.
Contributor


Nice!



if is_fairscale_available():
from fairscale.optim import OSS
Contributor


OSS is a bit cryptic to me, but I think it's still better to use the "real" name than to import OSS as OptimizerStateSharding -> so it's good for me!
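For readers wondering what the name stands for: OSS is optimizer state sharding, where each rank keeps optimizer state for only roughly 1/world_size of the parameters. A purely illustrative sketch of the partitioning idea (not fairscale's actual algorithm) assigns parameters greedily by size to balance the shards:

```python
def partition_params(param_sizes, world_size):
    """Greedily assign each parameter to the currently lightest rank."""
    shards = [[] for _ in range(world_size)]
    loads = [0] * world_size
    # Largest-first greedy assignment keeps the shards roughly balanced.
    for idx, size in sorted(enumerate(param_sizes), key=lambda x: -x[1]):
        rank = loads.index(min(loads))
        shards[rank].append(idx)   # this rank owns optimizer state for idx
        loads[rank] += size
    return shards, loads
```

Each rank then only allocates optimizer state (e.g. Adam's moment buffers) for its own shard, which is where the memory savings in the benchmarks come from.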

Collaborator Author


Yeah, I'm using the same convention they do, so as not to surprise any user.

Contributor

@patrickvonplaten patrickvonplaten left a comment


Clean!

@sgugger sgugger merged commit 9a67185 into master Dec 16, 2020
@sgugger sgugger deleted the sharded_ddp branch December 16, 2020 18:47
@blefaudeux

blefaudeux commented Dec 17, 2020


Hey there, a bit late, but I'm one of the fairscale/ShardedDDP authors. The issue with the Apex (and vanilla Torch) grad scaler is that it does not know about the gradient sharding, so not all ranks will have the same behaviour. Torch AMP is supported though; you just have to pass in the ShardedGradScaler as defined here: https://github.com/facebookresearch/fairscale/blob/master/fairscale/optim/grad_scaler.py#L24
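The point about the scaler can be made concrete with a plain-Python sketch (no torch, and not fairscale's implementation): each rank only inspects its own gradient shard for inf/nan, so the verdict has to be combined across ranks before anyone steps, otherwise ranks would disagree about skipping the update and fall out of sync.

```python
def should_skip_step(grad_shards_per_rank):
    """Return each rank's decision after combining local inf/nan checks."""
    inf = float("inf")
    # Local check: does this rank's shard contain a non-finite gradient?
    # (g != g is the standard NaN test.)
    local_found_inf = [
        any(g != g or g in (inf, -inf) for g in shard)
        for shard in grad_shards_per_rank
    ]
    # Stand-in for the all-reduce: every rank sees the global verdict.
    global_found_inf = any(local_found_inf)
    return [global_found_inf] * len(grad_shards_per_rank)
```

A scaler that only checks locally would have some ranks step and others skip, which is the desync the ShardedGradScaler prevents.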

@sgugger
Collaborator Author

sgugger commented Dec 17, 2020

Yes, we're passing that scaler :-) The issue was with AMP, not Apex. It looks like there is a problem, with or without FP16, with one of the models.
Ah, reading more, I see there is a lot on the issue I posted, so I will look there. Thanks for coming to help us!
