
Offload lp_grads and lp_params as well for deepspeed >= 0.16.5 #947

Merged
hijkzzz merged 1 commit into OpenRLHF:main from HollowMan6:deepspeed_offload on Apr 8, 2025

Conversation

@HollowMan6
Member

Those offload types are fixed in deepspeedai/DeepSpeed#7050, and they have already been released in 0.16.5.


Signed-off-by: Hollow Man <hollowman@opensuse.org>
@hijkzzz
Collaborator

hijkzzz commented Apr 7, 2025

One issue is that when calling vLLM's update_weights, the actor's parameters must be on the GPU.

@HollowMan6
Member Author

One issue is that when calling vLLM's update_weights, the actor's parameters must be on the GPU.

That's interesting. I didn't experience any errors with this PR; maybe those offloads are not actually taking effect even when specified? I will double-check.

@HollowMan6
Member Author

HollowMan6 commented Apr 7, 2025

@hijkzzz Took a closer look, and we do get an additional GPU memory reduction with lp_grads and lp_params added (so the offload should be successful). The reason calling vLLM's update_weights still works is that the model parameters are reconstructed via deepspeed.zero.GatheredParameters, so the replicated tensor is still stored on the original device (the GPU).

Reference from the deepspeed code:
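The referenced DeepSpeed snippet is not reproduced in this thread. As a hedged illustration of the behavior described above, here is a minimal pure-Python stand-in (the `gathered` context manager, `shards`, and `original_device` are all illustrative names, not DeepSpeed's API): inside the context, the partitioned shards are re-assembled into a full replica that lives on the parameter's original device, which is why vLLM's update_weights still sees GPU-resident data after the low-precision states were offloaded.

```python
from contextlib import contextmanager

# Minimal stand-in (NOT DeepSpeed's API) for the behavior of
# deepspeed.zero.GatheredParameters described above: inside the context,
# the ZeRO-3 shards are gathered into a full replica placed on the
# parameter's *original* device; on exit the replica is released.
@contextmanager
def gathered(shards, original_device="cuda:0"):
    # Concatenate one shard per rank into the full parameter.
    replica = {"device": original_device, "data": [x for s in shards for x in s]}
    try:
        yield replica
    finally:
        replica["data"] = None  # replica freed / re-partitioned on exit

shards = [[1.0, 2.0], [3.0, 4.0]]  # one shard per rank under ZeRO-3
with gathered(shards) as param:
    # The full tensor is available on the original (GPU) device here,
    # even if the partitioned states had been offloaded to CPU.
    assert param["device"] == "cuda:0"
    assert param["data"] == [1.0, 2.0, 3.0, 4.0]
```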

@hijkzzz
Collaborator

hijkzzz commented Apr 8, 2025

Thank you for your discovery. I just thought of another issue: when enabling reward/ref offload together with DeepSpeed sleep, it seems like we should also offload the parameters of ref/reward in this way.
Also, we could allow users to choose whether to offload just the optimizer or the parameters as well. For example, when deepspeed_sleep_level = 1, we offload only the optimizer; when it's set to 2, we offload everything.
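The proposed two-level switch could be sketched as below. This is hypothetical: `deepspeed_sleep_level`, the `State` enum, and `states_for_sleep_level` are illustrative names that only mirror the offload-state types discussed in this thread, not an existing OpenRLHF or DeepSpeed interface.

```python
from enum import Enum

# Illustrative stand-in for DeepSpeed's offload-state enum; member names
# follow the states discussed in this thread.
class State(Enum):
    optim_states = "optim_states"
    contiguous_grad_buffer = "contiguous_grad_buffer"
    hp_params = "hp_params"
    lp_grads = "lp_grads"
    lp_params = "lp_params"

def states_for_sleep_level(level: int) -> list:
    """Hypothetical mapping for the proposed deepspeed_sleep_level flag:
    level 1 offloads only optimizer-related state; level 2 offloads
    everything, including lp_grads and lp_params."""
    optimizer_only = [
        State.optim_states,
        State.contiguous_grad_buffer,
        State.hp_params,
    ]
    if level == 1:
        return optimizer_only
    return optimizer_only + [State.lp_grads, State.lp_params]
```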

@hijkzzz hijkzzz merged commit b412819 into OpenRLHF:main Apr 8, 2025
1 check passed
@HollowMan6
Member Author

When enabling reward/ref offload together with DeepSpeed sleep, it seems like we should also offload the parameters of ref/reward in this way.

Just did another investigation, and it looks like DeepSpeed hasn't added support for dynamic offloading in pure inference mode (when no optimizer is specified).

When no optimizer is specified, engine.optimizer will be of type DeepSpeedZeRoOffload instead of DeepSpeedZeroOptimizer_Stage3. Reference: https://github.com/deepspeedai/DeepSpeed/blob/56005d2b256eb81a88cba0a1984375f9663a3110/deepspeed/runtime/engine.py#L1684-L1707

ray.exceptions.RayTaskError(AttributeError): ray::ReferenceModelRayActor.init_model_from_pretrained()
  File "openrlhf/trainer/ray/launcher.py", line 134, in init_model_from_pretrained
    self.offload_states()
  File "openrlhf/trainer/ray/launcher.py", line 106, in offload_states
    offload_deepspeed_states(model)
  File "openrlhf/utils/deepspeed/deepspeed_utils.py", line 143, in offload_deepspeed_states
    model.offload_states(
  File "deepspeed/runtime/engine.py", line 3904, in offload_states
    self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'DeepSpeedZeRoOffload' object has no attribute 'offload_states'
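One way to avoid this crash would be a defensive guard at the call site. This is a hedged sketch, not OpenRLHF or DeepSpeed API: `safe_offload_states` is a hypothetical helper, assuming an engine object whose .optimizer attribute may lack offload_states in pure inference mode.

```python
# Hypothetical guard (safe_offload_states is an illustrative name): skip
# dynamic offload when the engine was built without an optimizer, in which
# case engine.optimizer is a DeepSpeedZeRoOffload instance that has no
# offload_states method, instead of raising the AttributeError above.
def safe_offload_states(engine, **kwargs):
    optimizer = getattr(engine, "optimizer", None)
    if optimizer is None or not hasattr(optimizer, "offload_states"):
        return False  # pure inference mode: nothing to offload dynamically
    engine.offload_states(**kwargs)  # delegates to optimizer.offload_states
    return True
```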

Also, we could allow users to choose whether to offload just the optimizer or the parameters as well. For example, when deepspeed_sleep_level = 1, we offload only the optimizer; when it's set to 2, we offload everything.

It looks like the parameters offload was disabled in ac2f0a2, so I guess this is not needed now. If it is needed, this should be straightforward to implement.
