Skip to content

error when using PPO in Gemma #1663

@mostafamdy

Description

@mostafamdy

System Info

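Hi,
I tried using PPO with a Gemma model, but I get the error below.
I think the issue is related to `is_encoder_decoder`. The traceback shows the model being run through `PeftModelForSeq2SeqLM.forward()`, which passes `decoder_input_ids` to `GemmaForCausalLM.forward()`, a decoder-only model that does not accept that argument.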

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[26], line 68
     66 print(response_tensors)
     67 #### Run PPO step
---> 68 stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
     69 ppo_trainer.log_stats(stats, batch, rewards)
     70 break

File /opt/conda/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File /opt/conda/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:721, in PPOTrainer.step(self, queries, responses, scores, response_masks)
    718 full_kl_penalty = self.config.kl_penalty == "full"
    720 with torch.no_grad():
--> 721     all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
    722         self.model,
    723         queries,
    724         responses,
    725         model_inputs,
    726         response_masks=response_masks,
    727         return_logits=full_kl_penalty,
    728     )
    729     with self.optional_peft_ctx():
    730         ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
    731             self.model if self.is_peft_model else self.ref_model,
    732             queries,
   (...)
    735             return_logits=full_kl_penalty,
    736         )

File /opt/conda/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File /opt/conda/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:994, in PPOTrainer.batched_forward_pass(self, model, queries, responses, model_inputs, return_logits, response_masks)
    992 if response_masks is not None:
    993     response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
--> 994 logits, _, values = model(**input_kwargs)
    996 if self.is_encoder_decoder:
    997     input_ids = input_kwargs["decoder_input_ids"]

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1568, in Module._call_impl(self, *args, **kwargs)
   1565     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1566     args = bw_hook.setup_input_hook(args)
-> 1568 result = forward_call(*args, **kwargs)
   1569 if _global_forward_hooks or self._forward_hooks:
   1570     for hook_id, hook in (
   1571         *_global_forward_hooks.items(),
   1572         *self._forward_hooks.items(),
   1573     ):
   1574         # mark that always called hook is run

File /opt/conda/lib/python3.10/site-packages/trl/models/modeling_value_head.py:171, in AutoModelForCausalLMWithValueHead.forward(self, input_ids, past_key_values, attention_mask, **kwargs)
    168 if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
    169     kwargs.pop("past_key_values")
--> 171 base_model_output = self.pretrained_model(
    172     input_ids=input_ids,
    173     attention_mask=attention_mask,
    174     **kwargs,
    175 )
    177 last_hidden_state = base_model_output.hidden_states[-1]
    178 lm_logits = base_model_output.logits

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/peft/peft_model.py:1326, in PeftModelForSeq2SeqLM.forward(self, input_ids, attention_mask, inputs_embeds, decoder_input_ids, decoder_attention_mask, decoder_inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1324     with self._enable_peft_forward_hooks(**kwargs):
   1325         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1326         return self.base_model(
   1327             input_ids=input_ids,
   1328             attention_mask=attention_mask,
   1329             inputs_embeds=inputs_embeds,
   1330             decoder_input_ids=decoder_input_ids,
   1331             decoder_attention_mask=decoder_attention_mask,
   1332             decoder_inputs_embeds=decoder_inputs_embeds,
   1333             labels=labels,
   1334             output_attentions=output_attentions,
   1335             output_hidden_states=output_hidden_states,
   1336             return_dict=return_dict,
   1337             **kwargs,
   1338         )
   1340 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1341 if decoder_attention_mask is not None:
   1342     # concat prompt attention mask

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:161, in BaseTuner.forward(self, *args, **kwargs)
    160 def forward(self, *args: Any, **kwargs: Any):
--> 161     return self.model.forward(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

TypeError: GemmaForCausalLM.forward() got an unexpected keyword argument 'decoder_input_ids'

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

.

Expected behavior

.

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions