My accelerate config
In which compute environment are you running? ([0] This machine, [1] AWS (Amazon SageMaker)): 0
Which type of machine are you using? ([0] No distributed training, [1] multi-CPU, [2] multi-GPU, [3] TPU): 0
How many different machines will you use (use more than 1 for multi-node training)? [1]: 1
Do you want to use DeepSpeed? [yes/NO]: yes
How many processes in total will you use? [1]: 1
Do you wish to use FP16 (mixed precision)? [yes/NO]: NO
Environment Info
Machine Info: V100 x 1
accelerate version: 0.5.1
(semi-)reproducible code
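```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # DeepSpeed enabled through `accelerate config`
model1 = torch.nn.Transformer()
model2 = torch.nn.Transformer()
opt = torch.optim.Adam(...)
loader = ...
# With DeepSpeed, model1 and model2 both come back as the last model passed
model1, model2, opt, loader = accelerator.prepare(model1, model2, opt, loader)
```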
Additional Explanation
When DeepSpeed is enabled, passing multiple models to `prepare` fails: every returned model becomes the same as the last one passed.
This is due to how `_prepare_deepspeed` handles its arguments, in particular:
```python
for obj in result:
    if isinstance(obj, torch.nn.Module):
        model = obj
    elif isinstance(obj, (torch.optim.Optimizer, dict)):
        optimizer = obj
```
As a result, `model` ends up holding only the last `nn.Module` passed, so the `engine` object created afterwards is built from the wrong model.
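To make the overwrite concrete without needing torch or DeepSpeed, here is a minimal, library-free sketch. `Module` and `Optimizer` are hypothetical stand-ins for `torch.nn.Module` and `torch.optim.Optimizer`; `pick_last` mirrors the quoted loop, and `pick_single` is a hypothetical guard (not Accelerate's actual code) that raises an error instead of silently keeping only the last model.

```python
# Hypothetical stand-ins for torch.nn.Module and torch.optim.Optimizer,
# so the sketch runs without torch or DeepSpeed installed.
class Module:
    def __init__(self, name):
        self.name = name


class Optimizer:
    pass


def pick_last(result):
    """Mimics the quoted loop: each Module overwrites `model`."""
    model = None
    optimizer = None
    for obj in result:
        if isinstance(obj, Module):
            model = obj  # later modules silently replace earlier ones
        elif isinstance(obj, (Optimizer, dict)):
            optimizer = obj
    return model, optimizer


def pick_single(result):
    """Hypothetical guard: refuse more than one Module instead of overwriting."""
    models = [obj for obj in result if isinstance(obj, Module)]
    if len(models) > 1:
        raise ValueError(f"DeepSpeed path expects one model, got {len(models)}")
    optimizers = [obj for obj in result if isinstance(obj, (Optimizer, dict))]
    return (models[0] if models else None,
            optimizers[-1] if optimizers else None)


model1, model2, opt = Module("model1"), Module("model2"), Optimizer()
model, _ = pick_last([model1, model2, opt])
print(model.name)  # model2: model1 was dropped

try:
    pick_single([model1, model2, opt])
except ValueError as err:
    print(err)  # DeepSpeed path expects one model, got 2
```

Failing loudly is only one possible design; supporting several models would instead require building one engine per model, which this sketch does not attempt.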