[Refactor Attention mask handling] Moves attention mask processing to the Attention class#28132

Closed
ArthurZucker wants to merge 20 commits into main from refactor-attention-converesion

Conversation

@ArthurZucker
Collaborator

@ArthurZucker ArthurZucker commented Dec 19, 2023

What does this PR do?

This is more aligned with our philosophy; it simplifies things now and will keep simplifying them going forward.
It will help a lot with the static cache.

The only way to share the mask is to call LlamaAttention, but if you have a better way I'll update it!
This makes the attention class self contained, which is also pretty convenient for testing.
Ran the slow tests without FA2; I'll run them again on a DGX once approved.

cc @patrickvonplaten for visibility

@ArthurZucker ArthurZucker marked this pull request as ready for review December 20, 2023 10:18
Contributor

@gante gante left a comment


LGTM (yay, fewer config-dependent if/elses 🙌 )

BTW, for retrocompatibility, we may want to check whether the attention masks are 4D before expanding in the attention classes. As we've learned with the Cache refactor, other repos might rely on the interface of these internal classes, and this is technically an interface change (4D attention mask input -> 2D attention mask input).
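The guard gante suggests could be sketched as below: only expand the mask when it is still 2D, and pass an already-4D mask through unchanged. The function and argument names here are illustrative, not the actual transformers API.

```python
# Hypothetical sketch of the retrocompatibility check: expand a 2D
# padding mask to the 4D additive form, but leave 4D masks (the old
# interface) untouched.
import torch

def prepare_attention_mask(attention_mask, dtype=torch.float32):
    """Expand a 2D keep-mask (batch, seq_len) into the 4D additive
    bias (batch, 1, seq_len, seq_len); pass 4D masks through as-is."""
    if attention_mask.dim() == 4:
        # Caller already built a 4D mask: keep the old behavior.
        return attention_mask
    batch, seq_len = attention_mask.shape
    # (batch, 1, 1, seq_len), broadcastable over query positions
    expanded = attention_mask[:, None, None, :].to(dtype)
    expanded = expanded.expand(batch, 1, seq_len, seq_len)
    # Turn the {0, 1} keep-mask into an additive bias: 1 -> 0, 0 -> -inf
    return (1.0 - expanded) * torch.finfo(dtype).min

mask_2d = torch.ones(2, 5)
mask_2d[0, 3:] = 0  # pad out the tail of the first sequence
out = prepare_attention_mask(mask_2d)
print(out.shape)  # torch.Size([2, 1, 5, 5])
```

A 4D input short-circuits before any reshaping, so external code that still builds its own 4D mask would keep working.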

@ArthurZucker
Collaborator Author

We added support for 4D attention masks inside the converter, so it should be alright, but yeah, I'll check related issues!

```python
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    cached_mask = None
```
Member


As discussed with you offline, this cannot work as-is, because a class attribute is shared across model instances.

Collaborator Author


Yes, I'll either stop caching it at the class level and cache the tril at the instance level instead, or pass it as a kwarg. JAX does not seem to care, so it should not be too bad.
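The instance-level option could look roughly like the sketch below: register the tril as a per-module, non-persistent buffer, so each model owns its own mask. Class and method names here are illustrative, not the actual transformers implementation.

```python
# Hypothetical sketch: cache the causal tril per instance (as a
# non-persistent buffer) instead of as a class attribute, which every
# model instance would otherwise share.
import torch
from torch import nn

class ToyAttention(nn.Module):
    def __init__(self, max_positions: int):
        super().__init__()
        # Per-instance buffer: moves with the module on .to(device),
        # but is not saved in the state_dict (persistent=False).
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(max_positions, max_positions, dtype=torch.bool)),
            persistent=False,
        )

    def get_mask(self, q_len: int, kv_len: int) -> torch.Tensor:
        # Slice the cached tril down to the current query/key lengths.
        return self.causal_mask[kv_len - q_len : kv_len, :kv_len]

a = ToyAttention(8)
b = ToyAttention(16)
# Unlike a class attribute, the two buffers are distinct tensors, so
# mutating one instance's mask cannot leak into the other.
print(a.causal_mask.shape, b.causal_mask.shape)
```

With `persistent=False` the mask also stays out of checkpoints, which avoids breaking existing state_dicts.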

Contributor


Why can't we pass the attention_mask just into the cache.update(...) function?

Collaborator Author


I'll check that as well, but that doesn't help for all the attention types we have that need a pre-processed mask; it would work after the pre-processing, though.

@patrickvonplaten
Contributor

Can you specify how this helps with the static cache?

The static cache should also work with the attention_mask being passed at every forward call (it'll always have the same shape). I don't think it's a good idea to have the attention_mask be a class variable.

@ArthurZucker
Collaborator Author

It will not be a class variable, I forgot to update that; I'll follow what we do in JAX.
This will help because the cache length is different from the number of tokens seen so far, which you only know once you are inside the attention layer.

@ArthurZucker
Collaborator Author

I can give more details, but basically the new cache + attention combination was not behaving properly. This is going to be my priority this week anyway!

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Feb 6, 2024