
Speed up ZeRO-3 generation with DPO #1543

@sngdng


Hi, a recent PR brought large speedups (about 10x) to PPO generation with ZeRO-3.
@lewtun, you mentioned on the PR that it could be adapted to other trainers. I gave it a quick try, and naively applying the context manager to trainers like DPO does not seem to work:

in remove_hooks
    if model.optimizer is not None and hasattr(
       ^^^^^^^^^^^^^^^^^^^^
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'GPTNeoXForCausalLM' object has no attribute 'optimizer'
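My guess (not confirmed) is that the helper in the traceback is written against the DeepSpeed engine, which exposes an optimizer attribute, while here it receives the bare model. The toy sketch below reproduces that kind of failure. BareModel, FakeEngine and remove_hooks_sketch are hypothetical stand-ins, not the real TRL or DeepSpeed classes.

```python
class BareModel:
    """Hypothetical stand-in for a plain model such as GPTNeoXForCausalLM (no .optimizer)."""


class FakeOptimizer:
    """Hypothetical stand-in for a ZeRO-3 optimizer that owns parameter-offload hooks."""

    def __init__(self):
        self.parameter_offload = "offload-hooks"


class FakeEngine:
    """Hypothetical stand-in for a DeepSpeed engine: wraps the model and exposes .optimizer."""

    def __init__(self, module):
        self.module = module
        self.optimizer = FakeOptimizer()


def remove_hooks_sketch(model):
    # Mirrors the condition shown in the traceback: the helper assumes `model`
    # is the engine, so it reads `model.optimizer` without checking first.
    if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
        return model.optimizer.parameter_offload
    return None


engine = FakeEngine(BareModel())
print(remove_hooks_sketch(engine))  # passing the engine works

try:
    remove_hooks_sketch(engine.module)  # passing the bare model fails
except AttributeError as err:
    print(err)
```

Passing the engine returns its offload handle, while passing the inner model raises the same kind of AttributeError as in the issue.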

There seems to be an inconsistency between the base classes. Is there a reason why DPO is built on Trainer from transformers while PPO is built on BaseTrainer? What would be the easiest way to add this feature to other trainers? Thanks!
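One possible direction, sketched under assumptions rather than offered as TRL's actual fix: transformers' Trainer documents a model_wrapped attribute that points to the outermost wrapper around the model (for example the DeepSpeed engine), whereas PPO, being built on BaseTrainer, presumably keeps the accelerator-prepared engine directly in self.model. A trainer-agnostic context manager could therefore resolve the engine from either kind of trainer before touching hooks. FakeEngine, resolve_engine, gather_for_generation and the hooks_active flag are hypothetical; the real context manager would also gather the ZeRO-3 partitioned parameters, which this toy omits.

```python
from contextlib import contextmanager
from types import SimpleNamespace


class FakeEngine:
    """Hypothetical stand-in for a DeepSpeed engine wrapping `module`."""

    def __init__(self, module):
        self.module = module
        self.optimizer = SimpleNamespace(parameter_offload=object())
        self.hooks_active = True  # stands in for the ZeRO-3 offload hooks


def resolve_engine(trainer):
    """Return the outermost wrapper, preferring `model_wrapped` when present."""
    # Trainer-based trainers (e.g. DPO) keep the bare model in `.model`;
    # a BaseTrainer-based trainer (e.g. PPO) keeps the prepared engine there.
    wrapped = getattr(trainer, "model_wrapped", None)
    candidate = wrapped if wrapped is not None else trainer.model
    if not hasattr(candidate, "optimizer"):
        raise TypeError(f"expected an engine, got {type(candidate).__name__}")
    return candidate


@contextmanager
def gather_for_generation(trainer):
    """Hypothetical: disable hooks while generating, then always restore them."""
    engine = resolve_engine(trainer)
    engine.hooks_active = False
    try:
        yield engine.module  # generate with the unwrapped module
    finally:
        engine.hooks_active = True  # restored even if generation raises


inner = object()  # stands in for the bare model
dpo_like = SimpleNamespace(model=inner, model_wrapped=FakeEngine(inner))
ppo_like = SimpleNamespace(model=FakeEngine(inner))

for trainer in (dpo_like, ppo_like):
    with gather_for_generation(trainer) as module:
        assert module is inner
        assert resolve_engine(trainer).hooks_active is False
    assert resolve_engine(trainer).hooks_active is True
```

Resolving the engine in one place keeps the hook handling identical for both base classes, and the finally clause makes sure training resumes with hooks in place even if generation fails.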
