System Info
PyTorch 2.2.1
DeepSpeed 0.13.4
Information
Tasks
Reproduction
Abstract
We are considering supporting multiple models with DeepSpeed when using Accelerate. We use the terms "model" and "DeepSpeed engine" interchangeably.
Motivation and Background
Currently, Accelerate's DeepSpeed integration supports only a single model. This rules out use cases that involve multiple models, such as RLHF, GANs and knowledge distillation. There is also interest in this feature, as shown by the following feature requests:
- https://huggingface.slack.com/archives/C06CEE9C1M4/p1706821695816299
- huggingface/accelerate issue #253: "Passing multiple models with DeepSpeed will fail"
- huggingface/accelerate issue #1388: "Passing multiple models with DeepSpeed will fail"
Support is currently restricted to a single model for the following reasons:
- The user can provide only a single DeepSpeed config plugin/DeepSpeed config file, which corresponds to a single model. Ideally, the user would be able to give different models different DeepSpeed configs.
- DeepSpeed needs to keep track of the model, its optimizer and its scheduler. Therefore, there is currently only one global DeepSpeed engine wrapper, which controls the backward pass and the optimizer/scheduler steps.
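To make the second point concrete, here is a toy model of a design with one global engine slot. It is a minimal sketch under stated assumptions: `ToyEngine` and `ToySingleEngineAccelerator` are hypothetical stand-ins, not Accelerate's code, and in practice (as the linked issues report) passing multiple models with DeepSpeed simply fails. The toy only shows why a single slot cannot serve several engines: whichever model was prepared last receives every `backward` call.

```python
class ToyEngine:
    """Hypothetical stand-in for a DeepSpeed engine: counts backward calls."""

    def __init__(self, name):
        self.name = name
        self.backward_calls = 0

    def backward(self, loss):
        # A real engine would run the backward pass plus ZeRO bookkeeping.
        self.backward_calls += 1


class ToySingleEngineAccelerator:
    """Hypothetical simplification of a single-engine design (not Accelerate's code)."""

    def __init__(self):
        self.deepspeed_engine_wrapped = None  # one global slot

    def prepare(self, name):
        engine = ToyEngine(name)
        # Each call overwrites the previous engine: only the last one survives.
        self.deepspeed_engine_wrapped = engine
        return engine

    def backward(self, loss):
        # Nothing tells us which model `loss` came from, so the global engine is used.
        self.deepspeed_engine_wrapped.backward(loss)


acc = ToySingleEngineAccelerator()
actor = acc.prepare("actor")
critic = acc.prepare("critic")
acc.backward(loss=1.0)  # intended for the actor
print(actor.backward_calls, critic.backward_calls)  # 0 1
```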
Proposal
The aim is to solve the two challenges above. This requires:
- Support for multiple DeepSpeed configurations. I believe the `accelerate config` questionnaire and the minimal DeepSpeed plugin shouldn't change and should continue to describe a single config. That config would act as the default used by all models.
- Flexibility to pass a different DeepSpeed config per model as part of the `prepare` method. For example, given 4 models in an RLHF scenario, I should be able to do the following:
```python
# rlhf
...
# actor_model, critic_model, reference_model, reward_model and the loss/reward helpers
# below are placeholders; lr_*, w_*, n_*, ppo_steps and pad_token_id are hyperparameters.
model_1 = actor_model()
model_2 = critic_model()
model_3 = reference_model()
model_4 = reward_model()
optimizer_1 = torch.optim.AdamW(model_1.parameters(), lr=lr_1)
optimizer_2 = torch.optim.AdamW(model_2.parameters(), lr=lr_2)
# transformers.get_scheduler: linear warmup followed by cosine decay
scheduler_1 = get_scheduler("cosine", optimizer_1, num_warmup_steps=w_1, num_training_steps=n_1)
scheduler_2 = get_scheduler("cosine", optimizer_2, num_warmup_steps=w_2, num_training_steps=n_2)
# Uses the default DeepSpeed config passed via the Accelerate config
model_1, optimizer_1, scheduler_1 = accelerator.prepare(model_1, optimizer_1, scheduler_1)
# Proposed `deepspeed_config` kwarg: a path to a DeepSpeed JSON config or a dict
model_2, optimizer_2, scheduler_2 = accelerator.prepare(
    model_2, optimizer_2, scheduler_2, deepspeed_config="path_or_dict_to_deepspeed_config_json"
)
model_3 = accelerator.prepare(model_3, deepspeed_config="path_or_dict_to_deepspeed_config_json")
model_4 = accelerator.prepare(model_4, deepspeed_config="path_or_dict_to_deepspeed_config_json")
for batch in train_dataloader:
    prompts = batch["prompts"]
    # Rollout: no gradients are needed while collecting experience
    with torch.no_grad():
        generations = model_1.generate(prompts)  # outputs prompts + answers
        log_probs = model_1(generations)
        ref_log_probs = model_3(generations)
        reward_scores = model_4(generations)
        values = model_2(generations)
    # reward - KL divergence; fixed for all PPO steps on this rollout
    old_rewards = compute_rewards(prompts, log_probs, ref_log_probs, reward_scores)
    advantages, returns = get_advantages_and_returns(values, old_rewards)
    ppo_batch = {"input_ids": generations, "attention_mask": (generations != pad_token_id).long()}
    for ppo_step in range(ppo_steps):
        new_log_probs = model_1(**ppo_batch, use_cache=False).logits
        model_1_loss = compute_actor_loss(new_log_probs, log_probs, advantages)
        accelerator.backward(model_1_loss)  # challenge: which DeepSpeed engine to use?
        optimizer_1.step()
        scheduler_1.step()
        optimizer_1.zero_grad()
        new_value = model_2(**ppo_batch)
        model_2_loss = critic_loss_fn(new_value, values, returns)
        accelerator.backward(model_2_loss)  # challenge: which DeepSpeed engine to use?
        optimizer_2.step()
        scheduler_2.step()
        optimizer_2.zero_grad()
...
```
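One way the proposed `deepspeed_config` argument could be resolved is sketched below, assuming it accepts `None`, a dict, or a path to a JSON file, as the placeholder string above suggests. `resolve_deepspeed_config` and `is_zero3` are hypothetical helpers, not part of Accelerate; only the DeepSpeed config keys (`zero_optimization`, `stage`, `offload_param`, `bf16`) are real. Falling back to the default config when nothing is passed is what keeps existing single-model scripts unchanged.

```python
import copy
import json
import os
import tempfile


def resolve_deepspeed_config(deepspeed_config, default_config):
    """Hypothetical helper: pick the DeepSpeed config dict for one `prepare` call."""
    if deepspeed_config is None:
        # No per-model config: fall back to the default from the Accelerate config.
        return copy.deepcopy(default_config)
    if isinstance(deepspeed_config, dict):
        # Copy so that per-model edits do not leak into the caller's dict.
        return copy.deepcopy(deepspeed_config)
    if isinstance(deepspeed_config, (str, os.PathLike)):
        # Path to a DeepSpeed JSON config file.
        with open(deepspeed_config) as f:
            return json.load(f)
    raise TypeError(
        f"deepspeed_config must be None, a dict or a path, got {type(deepspeed_config).__name__}"
    )


def is_zero3(config):
    """True if this config requests ZeRO stage 3 (parameter sharding)."""
    return config.get("zero_optimization", {}).get("stage", 0) == 3


# Default (ZeRO-3) plus a per-model ZeRO-2 override for the reward model.
default_cfg = {"zero_optimization": {"stage": 3}, "bf16": {"enabled": True}}
reward_override = {"zero_optimization": {"stage": 2}, "bf16": {"enabled": True}}

with tempfile.TemporaryDirectory() as tmp:
    critic_path = os.path.join(tmp, "critic_ds_config.json")
    with open(critic_path, "w") as f:
        json.dump({"zero_optimization": {"stage": 3, "offload_param": {"device": "cpu"}}}, f)
    critic_cfg = resolve_deepspeed_config(critic_path, default_cfg)

actor_cfg = resolve_deepspeed_config(None, default_cfg)
reward_cfg = resolve_deepspeed_config(reward_override, default_cfg)
print([is_zero3(c) for c in (actor_cfg, critic_cfg, reward_cfg)])  # [True, True, False]
```

A per-model check like `is_zero3` is one possible input to the `zero_init` question raised below: only models whose resolved config is ZeRO-3 would need to be created under ZeRO-3 initialization.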
Challenges for which the user would need to do extra work:
- The first issue arises when `accelerator.backward(model_1_loss)` or `accelerator.backward(model_2_loss)` is called. Behind the scenes, Accelerate currently calls `self.deepspeed_engine_wrapped.backward(loss, **kwargs)`, because only one DeepSpeed engine is supported. With multiple DeepSpeed engines, how do we know which engine's `backward` to call? Should a kwarg be passed, such as `accelerator.backward(model_1_loss, model=model_1)`, with an internal mapping from each model to its DeepSpeed engine? However, such a kwarg deviates from Accelerate's minimal API.
- How do we handle `zero_init` if, for example, 2 models use ZeRO-3 while the remaining 2 use ZeRO-2? If the default DeepSpeed config passed by the user is ZeRO-3 with `zero_init=True`, the user must disable it when loading the ZeRO-2 models via the `with zero3_init_context_manager(enabled=False)` context manager.
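The `model=` kwarg idea from the first challenge can be sketched with an internal registry. This is a hypothetical toy, not Accelerate's API: `ToyEngine` and `ToyMultiEngineAccelerator` are invented names. It mirrors the fact that, with DeepSpeed in Accelerate, the object returned as the prepared model is the DeepSpeed engine itself, and it keeps the kwarg optional so the single-model path stays backward compatible.

```python
class ToyEngine:
    """Hypothetical stand-in for a DeepSpeed engine: counts backward calls."""

    def __init__(self, name):
        self.name = name
        self.backward_calls = 0

    def backward(self, loss):
        # A real engine would also handle loss scaling, gradient partitioning, etc.
        self.backward_calls += 1


class ToyMultiEngineAccelerator:
    """Hypothetical sketch of `backward(loss, model=...)`; not Accelerate's API."""

    def __init__(self):
        self._engines = {}  # id(prepared model) -> engine

    def prepare(self, name):
        # The returned "model" is the engine itself, as in the DeepSpeed path.
        engine = ToyEngine(name)
        self._engines[id(engine)] = engine
        return engine

    def backward(self, loss, model=None):
        if model is None:
            # Backward compatible path: unambiguous only with exactly one engine.
            if len(self._engines) != 1:
                raise ValueError("Several DeepSpeed engines are prepared; pass `model=`.")
            engine = next(iter(self._engines.values()))
        else:
            try:
                engine = self._engines[id(model)]
            except KeyError:
                raise ValueError(f"{model!r} was not passed through prepare()") from None
        engine.backward(loss)


acc = ToyMultiEngineAccelerator()
actor = acc.prepare("actor")
critic = acc.prepare("critic")
acc.backward(1.0, model=actor)
acc.backward(2.0, model=critic)
acc.backward(3.0, model=critic)
print(actor.backward_calls, critic.backward_calls)  # 1 2
```

The cost of this design is exactly the one noted above: every `backward` call in a multi-model script must name its model, which is extra API surface compared with today's `accelerator.backward(loss)`.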
Compatibility
This feature needs to be backward compatible with existing Accelerate code as well as with the Trainer. The Trainer API will not change.
Alternatives Considered
- At present, if only a single model needs to be trained and the remaining models are used only for inference and are small enough to fit in GPU memory, the user can simply avoid passing them to the `accelerator.prepare()` method.
- Users can call the DeepSpeed API directly to create DeepSpeed engines for the remaining models. This is what the TRL library does for DPO: the frozen reference model is wrapped with ZeRO Stage 3 so that it is sharded across GPUs.
- Create a super-model that encapsulates the different models in a single class.
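The super-model alternative can be sketched in plain PyTorch. This is a minimal illustration under assumptions: `ActorCritic` is a hypothetical wrapper and the `nn.Linear` layers stand in for real actor and critic models. Because the wrapper is one `nn.Module`, it would become one DeepSpeed engine with one optimizer. Per-model learning rates remain possible through parameter groups, but both sub-models would share a single DeepSpeed config (for example, one ZeRO stage), which is the main limitation of this approach.

```python
import torch
from torch import nn


class ActorCritic(nn.Module):
    """Hypothetical super-model: one nn.Module (hence one engine) owning two models."""

    def __init__(self, actor: nn.Module, critic: nn.Module):
        super().__init__()
        self.actor = actor
        self.critic = critic

    def forward(self, x: torch.Tensor):
        # One call runs both sub-models; summing their losses lets a single
        # backward()/step() update both.
        return self.actor(x), self.critic(x)


torch.manual_seed(0)
model = ActorCritic(actor=nn.Linear(8, 4), critic=nn.Linear(8, 1))  # toy stand-ins

# Separate learning rates via parameter groups; the DeepSpeed config stays shared.
optimizer = torch.optim.AdamW(
    [
        {"params": model.actor.parameters(), "lr": 1e-5},
        {"params": model.critic.parameters(), "lr": 5e-5},
    ]
)

x = torch.randn(2, 8)
logits, values = model(x)
loss = logits.pow(2).mean() + values.pow(2).mean()  # placeholder losses
loss.backward()
optimizer.step()
optimizer.zero_grad()
```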
Dependencies
- PyTorch
- DeepSpeed
Expected behavior
Enabling use cases involving multiple models with Accelerate's DeepSpeed integration.