Enabling Transformer fast path for not batch_first (MHA, TE, TEL) #106668
mikekgfb wants to merge 2 commits into pytorch:main from
Conversation
|
This pull request was exported from Phabricator. Differential Revision: D48095703 |
Force-pushed: 63203ec to 3a59a69
Force-pushed: 3a59a69 to 3719ea0
|
@wconstab any suggestion on how to handle the FSDP failure? I think (i.e., conjecture) the underlying cause is a numerical difference between fastpath vs. standard execution -- I think we avoided running into this so far because this test ran with batch_first=False, and until this diff we used the fastpath with batch_first=True only. (The fastpath is only on for inference with no_grad(), so eval() vs eval()+no_grad() exercises two different paths.) |
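The dispatch condition described in the comment above can be sketched as follows. This is illustrative only, assuming a hypothetical `select_path` helper; the names do not correspond to PyTorch internals. The point is that the fast path requires both eval mode and autograd disabled, so `eval()` alone and `eval()` + `no_grad()` exercise different kernels.

```python
# Hypothetical sketch (not PyTorch internals): the fast path is taken
# only during inference, i.e. eval mode AND gradients disabled.
def select_path(training: bool, grad_enabled: bool) -> str:
    if not training and not grad_enabled:
        return "fastpath"   # fused inference kernel
    return "standard"       # reference (standard) implementation

# eval() sets training=False but leaves autograd enabled:
assert select_path(training=False, grad_enabled=True) == "standard"
# eval() + no_grad() disables autograd too, enabling the fast path:
assert select_path(training=False, grad_enabled=False) == "fastpath"
```

Because the two branches are different implementations, their outputs need not be bit-exact, which is what trips the FSDP comparison test.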
|
missed this comment before, but it sounds like you've found out that it's an issue of eval vs train accuracy for this model, is that right? |
Force-pushed: 3719ea0 to 79fee71
The underlying issue is that model.eval() vs model.eval() with no_grad() triggers different computational kernels. The only sane way to control this is via a backend context manager for choosing between these kernels. I added this in #107014
Force-pushed: 79fee71 to d8ee9bc
Yep - the way that "validation" is performed is that the test runs in train mode, and then with eval/no_grad -- the latter triggers the fastpath with the present patch, but because we're looking at different implementations, we can't expect bit-exact answers. One possible solution is a context manager that gives more control over the kernel chosen, similar to what we do for SDPA -- #107163 is an implementation of this context manager for backend selection |
Force-pushed: d8ee9bc to c7ecb35
Force-pushed: c7ecb35 to 21ea697
|
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
Force-pushed: 21ea697 to ddd757f
Force-pushed: ddd757f to abdabaf
…nd manager (pytorch#107163)
Summary: Create fastpath backend context manager, similar to the SDPA kernel backend manager
Test Plan: sandcastle, github
Differential Revision: D48325593
…torch#106668)
Summary: The fast path for the `forward()` method in `MultiheadAttention`, `TE`, `TEL` only accepted `batch_first = True`. This diff enables the fast path for `batch_first=False` as well.
Test Plan: sandcastle, github CI/CD
Differential Revision: D48095703
Force-pushed: abdabaf to 05eb0fe
Summary: The fast path for the `forward()` method in `MultiheadAttention`, `TE`, `TEL` only accepted `batch_first = True`. This diff enables the fast path for `batch_first=False` as well.

Differential Revision: D48095703
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @kiukchung @d4l3k @LucasLLC