Enabling Transformer fast path for not batch_first (MHA, TE, TEL)#106668

Closed
mikekgfb wants to merge 2 commits into pytorch:main from mikekgfb:export-D48095703

@mikekgfb
Contributor

mikekgfb commented Aug 5, 2023

Summary: The fast path for the forward() method in MultiheadAttention, TE, TEL only accepted batch_first = True. This diff enables fast path for batch_first=False as well.
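
For readers unfamiliar with the flag: `batch_first` only changes the order of the first two input dimensions, `(batch, seq, feature)` when `True` and `(seq, batch, feature)` when `False`. The sketch below uses plain nested lists instead of tensors to show that both layouts carry the same data. The helper name `swap_first_two_dims` is hypothetical, and this is an illustration of the layouts only, not a description of how this diff handles them internally.

```python
def swap_first_two_dims(x):
    """Swap the first two dimensions of a rectangular nested list.

    Converts (seq, batch, feature) to (batch, seq, feature) and back.
    Hypothetical helper for illustration; not part of PyTorch.
    """
    return [list(row) for row in zip(*x)]


# Layout for batch_first=False: (seq=3, batch=2, feature=1)
seq_first = [
    [[1.0], [10.0]],  # time step 0 for sample 0 and sample 1
    [[2.0], [20.0]],  # time step 1
    [[3.0], [30.0]],  # time step 2
]

# Layout for batch_first=True: (batch=2, seq=3, feature=1)
batch_first = swap_first_two_dims(seq_first)
```

Applying the helper twice returns the original layout, which is why either convention can in principle feed the same computation.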

Differential Revision: D48095703

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @kiukchung @d4l3k @LucasLLC

@pytorch-bot

pytorch-bot bot commented Aug 5, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106668

Note: Links to docs will display an error until the docs builds have been completed.

❌ 36 New Failures

As of commit 05eb0fe with merge base 81adbb6:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D48095703

@mikekgfb
Contributor Author

mikekgfb commented Aug 6, 2023

@wconstab any suggestion on how to handle the FSDP failure? My conjecture is that the underlying cause is a numerical difference between fastpath and standard execution. I think we avoided running into this so far because this test ran with batch_first=False, and until this diff we used the fastpath with batch_first=True only. (The fastpath is only on for inference with no_grad(), so eval() vs. eval() with no_grad() exercises two different paths.)

cc: @rohan-varma @awgu @mrshenli @drisspg

@wconstab
Contributor

wconstab commented Aug 8, 2023

Missed this comment before, but it sounds like you've found out that it's an issue of eval vs. train accuracy for this model, is that right?

@mikekgfb
Contributor Author

Missed this comment before, but it sounds like you've found out that it's an issue of eval vs. train accuracy for this model, is that right?

The underlying issue is that model.eval() vs model.eval() with no_grad() triggers different computational kernels. The only way to control this in a sane way is via a backend context manager for choosing between these backends. I added this in #107014

@mikekgfb added the enhancement and release notes: nn labels Sep 8, 2023
@mikekgfb
Contributor Author

mikekgfb commented Sep 8, 2023

Missed this comment before, but it sounds like you've found out that it's an issue of eval vs. train accuracy for this model, is that right?

Yep. The way "validation" is performed is that the test runs in train mode and again with eval/no_grad. The latter triggers the fastpath with the present batch, but because we're looking at different implementations, we can't expect bit-exact answers.
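
The point about bit-exactness can be reproduced without any PyTorch code. The sketch below is a generic floating-point illustration, not the comparison the FSDP test performs: it adds the same three numbers in two different orders, as two different kernels might, and shows that the results differ in the last bits while still agreeing within a tolerance. The tolerance value is an arbitrary assumption.

```python
import math

values = [0.1, 0.2, 0.3]

# Same mathematical sum, evaluated in two different orders.
left_to_right = (values[0] + values[1]) + values[2]
right_to_left = values[0] + (values[1] + values[2])

# Exact comparison fails because rounding happens at different points.
bit_exact = left_to_right == right_to_left

# Tolerance-based comparison accepts the two results as equal.
close_enough = math.isclose(left_to_right, right_to_left, rel_tol=1e-9)
```

Here `bit_exact` is `False` and `close_enough` is `True`, which is the situation a test faces when it compares outputs from two different implementations.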

One possible solution is a context manager that gives more control over the kernel chosen, similar to what we do for SDPA. #107163 is an implementation of this context manager for backend selection.

@github-actions
Contributor

github-actions bot commented Nov 8, 2023

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

Michael Gschwind added 2 commits November 30, 2023 12:18
…nd manager (pytorch#107163)

Summary:

Create fastpath backend context manager, similar to SDPA kernel backend manager

Test Plan: sandcastle, github

Differential Revision: D48325593
…torch#106668)

Summary:

The fast path for the `forward()` method in `MultiheadAttention`, `TE`, `TEL` only accepted `batch_first = True`. This diff enables fast path for `batch_first=False` as well.

Test Plan: sandcastle, github CI/CD

Differential Revision: D48095703

@albanD added the oncall: distributed label and removed the module: distributed label Dec 8, 2023
@github-actions bot closed this Jan 7, 2024

Labels

enhancement, fb-exported, oncall: distributed, release notes: nn, Stale

4 participants