
Add support for Pix2Struct #796

@NielsRogge

Description


Feature request

I wanted to apply PEFT to Pix2Struct; however, since the model expects `flattened_patches` rather than `input_ids`, I'm getting the following error:

/usr/local/lib/python3.10/dist-packages/peft/peft_model.py in forward(self, input_ids, attention_mask, inputs_embeds, decoder_input_ids, decoder_attention_mask, decoder_inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, **kwargs)
   1078         peft_config = self.active_peft_config
   1079         if not isinstance(peft_config, PromptLearningConfig):
-> 1080             return self.base_model(
   1081                 input_ids=input_ids,
   1082                 attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: Pix2StructForConditionalGeneration.forward() got an unexpected keyword argument 'input_ids'

`flattened_patches` is a pretty unique input name; I'm not sure any other model will ever use that name as an input. cc @younesbelkada
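Until PEFT handles this input name natively, one possible workaround is a thin wrapper that remaps the `input_ids` kwarg that `PeftModel.forward` forwards into the `flattened_patches` kwarg Pix2Struct expects. This is an untested sketch; `Pix2StructPeftWrapper` is a hypothetical name, not part of the PEFT API:

```python
import torch
import torch.nn as nn


class Pix2StructPeftWrapper(nn.Module):
    """Hypothetical shim: PeftModel.forward passes the main input as
    `input_ids`, but Pix2StructForConditionalGeneration.forward only
    accepts `flattened_patches`, so we rename the kwarg in between."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, input_ids=None, **kwargs):
        # Forward everything else (attention_mask, labels, ...) untouched;
        # only the main input gets renamed.
        return self.model(flattened_patches=input_ids, **kwargs)
```

One would wrap the base model before calling `get_peft_model` on it; whether LoRA module matching still resolves correctly through the extra `model.` prefix would need to be verified.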

A notebook to reproduce is here.

Motivation

Pix2Struct is a pretty heavy model, so leveraging LoRA/QLoRA instead of full fine-tuning would greatly benefit the community.

Your contribution

I'm not sure I can help here.
