
Pix2Struct -- mismatched output of cross attention weights #25175

@leitro

Description

System Info

Hi huggingface team!

The cross attention weights returned by the model are mismatched, as shown at https://github.com/huggingface/transformers/blob/05cda5df3405e6a2ee4ecf8f7e1b2300ebda472e/src/transformers/models/pix2struct/modeling_pix2struct.py#L1551C22-L1551C22.

In the code: all_cross_attentions = all_cross_attentions + (layer_outputs[3],)

Here layer_outputs[3] is still the self attention weights; the real cross attention weights are at layer_outputs[5].

Please correct me if I have made a mistake. Looking forward to the fix. Thank you! @amyeroberts @ArthurZucker @younesbelkada

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Highlight of the training code:

model = Pix2StructForConditionalGeneration.from_pretrained('google/pix2struct-docvqa-base')
outputs = model(**inputs, labels=labels, output_attentions=True)

Enable attention outputs by passing output_attentions=True, then read the cross attention weights from outputs.cross_attentions, which is where the bug appears.

Expected behavior

Change the index from 3 to 5 so the correct cross attention weights are selected; that should hopefully be all that is needed.
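To make the off-by-index issue concrete, here is a minimal sketch of the tuple ordering this report assumes, i.e. the T5-style decoder layer outputs when output_attentions=True and use_cache=True (hidden states, present key/value, self-attention position bias, self-attention weights, cross-attention position bias, cross-attention weights). The strings below stand in for real tensors; only the indexing logic is real:

```python
# Stand-in for a decoder layer's outputs, assuming T5-style ordering
# (this ordering is an assumption based on the report, not verified here).
layer_outputs = (
    "hidden_states",             # 0
    "present_key_value",         # 1
    "self_attn_position_bias",   # 2
    "self_attn_weights",         # 3  <- what the current line collects
    "cross_attn_position_bias",  # 4
    "cross_attn_weights",        # 5  <- what should be collected
)

# Current (buggy) accumulation, as in modeling_pix2struct.py:
all_cross_attentions = ()
all_cross_attentions = all_cross_attentions + (layer_outputs[3],)
# -> collects "self_attn_weights", the wrong tensor

# Proposed fix: index 5 instead of 3
all_cross_attentions_fixed = ()
all_cross_attentions_fixed = all_cross_attentions_fixed + (layer_outputs[5],)
# -> collects "cross_attn_weights"
```

Under this ordering, outputs.cross_attentions would currently contain a second copy of the self-attention weights (identical to outputs.decoder_attentions), which is one quick way to spot the bug from a forward pass.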
