[New Feature Enhancement] Generalize DPO with different f-divergences #1259

@1485840691

Why

There is a paper that generalizes DPO to a family of f-divergences, which helps the model better balance alignment performance and generation diversity. The present implementation, which uses log(), is one member of that family: reverse KL (α = 0).

According to the paper:

"Empirically, adopting these f-divergences ensures a balance between alignment performance and generation diversity. Importantly, f-DPO outperforms PPO-based methods in divergence efficiency, and divergence constraints directly influence expected calibration error (ECE)."

How

I would like to open a PR that adds further f-divergences (forward KL, JS divergence, and α-divergence) alongside the currently supported one (reverse KL).
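For reference, here is the generalized objective as I understand it from the paper, written as a sketch rather than a verbatim copy. The symbols $u$ (policy-to-reference probability ratio), $\beta$, $y_w$ (chosen) and $y_l$ (rejected) follow the usual DPO notation and are my choice here. Additive constants in $f'$ cancel in the difference, so they are dropped.

$$
\mathcal{L}_{f\text{-DPO}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\left[\log \sigma\!\left(\beta\, f'\!\left(\frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}\right) - \beta\, f'\!\left(\frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right)\right]
$$

$$
\begin{aligned}
\text{reverse KL:} \quad & f'(u) = \log u \\
\text{forward KL:} \quad & f'(u) = -\frac{1}{u} \\
\text{JS divergence:} \quad & f'(u) = \log u - \log(1 + u) \\
\alpha\text{-divergence:} \quad & f'(u) = \frac{1 - u^{-\alpha}}{\alpha}
\end{aligned}
$$

With reverse KL the difference reduces to the familiar DPO log-ratio difference, which is why the existing loss is the special case.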

Implementation should be straightforward. A simplified code update, for illustration purposes:

```python
import torch

# Log-ratios of policy to reference for each response
# (called "rewards" here, but not yet scaled by beta).
chosen_rewards = policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
chosen_rewards_exp = torch.exp(chosen_rewards)

rejected_rewards = policy_rejected_logps.to(self.accelerator.device) - reference_rejected_logps.to(self.accelerator.device)
rejected_rewards_exp = torch.exp(rejected_rewards)

# Compute `logits` with exactly one of the following, depending on the selected divergence.

# Forward KL
logits = -1 / chosen_rewards_exp - (-1 / rejected_rewards_exp)

# JS divergence
logits = (chosen_rewards - torch.log(1 + chosen_rewards_exp)) - (rejected_rewards - torch.log(1 + rejected_rewards_exp))

# Alpha divergence (self.alpha_div must be non-zero)
logits = (1 - chosen_rewards_exp ** (-self.alpha_div)) / self.alpha_div - (1 - rejected_rewards_exp ** (-self.alpha_div)) / self.alpha_div
```
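To sanity-check the formulas independently of the trainer, here is a minimal scalar sketch in plain Python (no torch). The function names `softplus`, `f_prime` and `f_dpo_logits` and the divergence strings are hypothetical and only for illustration. It works on log-ratios directly, which avoids exponentiating and then taking logs again.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def f_prime(log_ratio: float, divergence: str = "reverse_kl", alpha: float = 0.5) -> float:
    """f'(u) at u = exp(log_ratio), dropping constants that cancel in the difference."""
    if divergence == "reverse_kl":
        return log_ratio                                  # log(u)
    if divergence == "forward_kl":
        return -math.exp(-log_ratio)                      # -1 / u
    if divergence == "js":
        return -softplus(-log_ratio)                      # log(u) - log(1 + u)
    if divergence == "alpha":
        if alpha == 0:
            raise ValueError("alpha must be non-zero; use 'reverse_kl' for the limit case")
        return -math.expm1(-alpha * log_ratio) / alpha    # (1 - u**(-alpha)) / alpha
    raise ValueError(f"unknown divergence: {divergence!r}")


def f_dpo_logits(chosen_log_ratio: float, rejected_log_ratio: float,
                 divergence: str = "reverse_kl", alpha: float = 0.5) -> float:
    """Difference f'(u_chosen) - f'(u_rejected); the loss would be -logsigmoid(beta * logits)."""
    return (f_prime(chosen_log_ratio, divergence, alpha)
            - f_prime(rejected_log_ratio, divergence, alpha))


if __name__ == "__main__":
    for name in ("reverse_kl", "forward_kl", "js", "alpha"):
        print(name, f_dpo_logits(0.3, -0.2, name))
```

Two properties are useful as checks: with `alpha=1.0` the alpha-divergence logits coincide with the forward-KL logits, and as `alpha` approaches 0 they approach the reverse-KL logits.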

Possible update to the existing class (new argument `f_divergence_kwargs`):
`class DPOTrainer(Trainer): def __init__(..., f_divergence_kwargs: Optional[Dict] = None, ...)`
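One way the new argument could be validated is sketched below. Everything here is an assumption for discussion: the `FDivergenceType` enum, the `parse_f_divergence_kwargs` helper, and the `"type"` / `"alpha"` keys are hypothetical names, not existing API. Defaulting to reverse KL would keep current behaviour unchanged when the argument is omitted.

```python
from enum import Enum
from typing import Dict, Optional, Tuple


class FDivergenceType(Enum):
    """Hypothetical set of supported divergences."""
    REVERSE_KL = "reverse_kl"
    FORWARD_KL = "forward_kl"
    JS_DIVERGENCE = "js_divergence"
    ALPHA_DIVERGENCE = "alpha_divergence"


def parse_f_divergence_kwargs(
    f_divergence_kwargs: Optional[Dict] = None,
) -> Tuple[FDivergenceType, Optional[float]]:
    """Return (divergence type, alpha); alpha is None unless alpha-divergence is selected."""
    kwargs = dict(f_divergence_kwargs or {})
    # Enum lookup raises ValueError for an unknown type string.
    div_type = FDivergenceType(kwargs.pop("type", "reverse_kl"))
    alpha = kwargs.pop("alpha", None)
    if kwargs:
        raise ValueError(f"unknown f-divergence options: {sorted(kwargs)}")

    if div_type is FDivergenceType.ALPHA_DIVERGENCE:
        if alpha is None:
            raise ValueError("'alpha' is required for alpha_divergence")
        alpha = float(alpha)
        if alpha == 0.0:
            # The alpha-divergence logits divide by alpha.
            raise ValueError("'alpha' must be non-zero")
    elif alpha is not None:
        raise ValueError("'alpha' only applies to alpha_divergence")
    return div_type, alpha


if __name__ == "__main__":
    print(parse_f_divergence_kwargs())
    print(parse_f_divergence_kwargs({"type": "alpha_divergence", "alpha": 0.5}))
```

A plain dict keeps the constructor signature small, at the cost of needing explicit validation like the above; separate typed arguments would be an alternative.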

Please let me know if you have any concerns.
