[GPT-NeoX] Add SDPA support #31031
Conversation
Bigger models were needed to match in generation. Same as in Llama; not sure if that's an issue.
Exchanged the old attention mask implementation here. Can revert, but I thought it kept things cleaner.
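For illustration, a minimal sketch (plain PyTorch, not the PR's actual helper) of the kind of 4D additive mask that gets handed to SDPA, combining the causal structure with a 2D padding mask:

```python
import torch

def make_sdpa_attention_mask(padding_mask: torch.Tensor, dtype=torch.float32):
    # padding_mask: (batch, seq_len), 1 for real tokens and 0 for padding.
    # Returns a (batch, 1, seq_len, seq_len) additive mask: 0.0 where
    # attention is allowed, dtype-min where it is blocked (causal + padding).
    bsz, seq_len = padding_mask.shape
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    allowed = causal[None, None, :, :] & padding_mask[:, None, None, :].bool()
    min_val = torch.finfo(dtype).min
    return torch.zeros(allowed.shape, dtype=dtype).masked_fill(~allowed, min_val)

mask = make_sdpa_attention_mask(torch.tensor([[1, 1, 0]]))
```

The real helper in transformers also handles KV-cache offsets and can skip the mask entirely in the fully-causal unpadded case; this sketch only shows the core shape and fill logic.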
Ran the tests across the different configurations; results:

[
{
"task_env": "padding_side=left, use_mask=True, batch_size=5, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.02271,
"std_of_mean_differences": 0.03539,
"min_of_mean_differences": 0.00616,
"max_of_mean_differences": 0.3262,
"total_fails": 87
},
{
"task_env": "padding_side=right, use_mask=True, batch_size=5, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.02072,
"std_of_mean_differences": 0.01623,
"min_of_mean_differences": 0.00693,
"max_of_mean_differences": 0.1079,
"total_fails": 86
},
{
"task_env": "padding_side=left, use_mask=False, batch_size=5, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.02172,
"std_of_mean_differences": 0.01623,
"min_of_mean_differences": 0.01025,
"max_of_mean_differences": 0.09912,
"total_fails": 80
},
{
"task_env": "padding_side=right, use_mask=False, batch_size=5, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.02172,
"std_of_mean_differences": 0.01623,
"min_of_mean_differences": 0.01025,
"max_of_mean_differences": 0.09912,
"total_fails": 80
},
{
"task_env": "padding_side=left, use_mask=False, batch_size=1, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.03034,
"std_of_mean_differences": 0.05975,
"min_of_mean_differences": 0.00812,
"max_of_mean_differences": 0.2949,
"total_fails": 21
},
{
"task_env": "padding_side=right, use_mask=False, batch_size=1, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.03034,
"std_of_mean_differences": 0.05975,
"min_of_mean_differences": 0.00812,
"max_of_mean_differences": 0.2949,
"total_fails": 21
},
{
"task_env": "padding_side=left, use_mask=False, batch_size=5, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.00205,
"std_of_mean_differences": 0.0008,
"min_of_mean_differences": 0.00095,
"max_of_mean_differences": 0.00397,
"total_fails": 17
},
{
"task_env": "padding_side=right, use_mask=False, batch_size=5, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.00205,
"std_of_mean_differences": 0.0008,
"min_of_mean_differences": 0.00095,
"max_of_mean_differences": 0.00397,
"total_fails": 17
},
{
"task_env": "padding_side=right, use_mask=True, batch_size=1, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.01624,
"std_of_mean_differences": 0.00679,
"min_of_mean_differences": 0.00653,
"max_of_mean_differences": 0.02954,
"total_fails": 17
},
{
"task_env": "padding_side=left, use_mask=True, batch_size=5, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.0026,
"std_of_mean_differences": 0.0014,
"min_of_mean_differences": 0.00119,
"max_of_mean_differences": 0.00635,
"total_fails": 15
},
{
"task_env": "padding_side=right, use_mask=True, batch_size=5, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.00245,
"std_of_mean_differences": 0.00108,
"min_of_mean_differences": 0.00119,
"max_of_mean_differences": 0.00473,
"total_fails": 15
},
{
"task_env": "padding_side=left, use_mask=True, batch_size=1, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.01135,
"std_of_mean_differences": 0.00861,
"min_of_mean_differences": 0.00629,
"max_of_mean_differences": 0.03516,
"total_fails": 9
},
{
"task_env": "padding_side=left, use_mask=False, batch_size=1, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.00454,
"std_of_mean_differences": 0.00205,
"min_of_mean_differences": 0.002,
"max_of_mean_differences": 0.00702,
"total_fails": 3
},
{
"task_env": "padding_side=right, use_mask=False, batch_size=1, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.00454,
"std_of_mean_differences": 0.00205,
"min_of_mean_differences": 0.002,
"max_of_mean_differences": 0.00702,
"total_fails": 3
},
{
"task_env": "padding_side=left, use_mask=True, batch_size=1, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.002,
"std_of_mean_differences": 0.0,
"min_of_mean_differences": 0.002,
"max_of_mean_differences": 0.002,
"total_fails": 1
}
]

The tests seem pretty flaky.
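For reference, a hedged sketch (names are mine, not the actual test harness) of how per-configuration summaries like the ones above can be computed from eager vs. SDPA outputs:

```python
import torch

def summarize_mismatch(eager_out, sdpa_out, atol=0.01, rtol=0.03):
    # Per-sample mean absolute difference between the two attention paths,
    # plus an element-wise fail count under the given tolerances. This
    # mirrors the shape of the table above but is an illustrative
    # reimplementation, not the script that produced those numbers.
    diff = (eager_out - sdpa_out).abs()
    per_sample = diff.reshape(diff.shape[0], -1).mean(dim=1)
    fails = ~torch.isclose(eager_out, sdpa_out, atol=atol, rtol=rtol)
    return {
        "mean_of_mean_differences": per_sample.mean().item(),
        "max_of_mean_differences": per_sample.max().item(),
        "total_fails": int(fails.sum().item()),
    }
```

Note that `torch.isclose` passes when `|a - b| <= atol + rtol * |b|`, so the atol/rtol pairs in the table translate directly into these element-wise checks.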
@fxmarty Could you do a first review?
I've generalised the projections and rope application. A lot of stuff is being repeated among all the attn implementations. Can revert if needed.
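As a sketch of what that factoring-out looks like (hypothetical helper name; GPT-NeoX's real code also handles partial rotary dims, position offsets, and the KV cache):

```python
import torch

def qkv_with_rope(hidden, qkv_proj, cos, sin, num_heads, head_dim):
    # Shared pre-attention step: fused QKV projection followed by RoPE,
    # factored out so the eager / SDPA / FA2 variants don't each repeat it.
    bsz, seq_len, _ = hidden.shape
    # GPT-NeoX-style fused layout: per head, [q; k; v] are concatenated.
    qkv = qkv_proj(hidden).view(bsz, seq_len, num_heads, 3 * head_dim)
    q, k, v = qkv.split(head_dim, dim=-1)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # (bsz, heads, seq, dim)

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    # Standard rotary embedding application to queries and keys only.
    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin
    return q, k, v
```

Each attention variant then only differs in how it consumes `(q, k, v)`, which is the repetition the refactor removes.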
padding_mask is never used in any attn implementation but to keep things common between the different attns, I added it to the method signature.
I've removed it from the function signature 👍
The new commits are all cosmetic/refactoring stuff which doesn't affect the logic except for the additional
Further removed an incorrect artefact in the code copied from Llama.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
- added head mask check to sdpa mask creation
- handle sdpa memory backend bug via own version flag
- fix flash_attn_2 stuff
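The head-mask check can be pictured like this (an illustrative sketch, not the PR's exact code): SDPA exposes no hook for a per-head mask, so the layer falls back to the eager path whenever one is supplied.

```python
import torch
import torch.nn.functional as F

def attention_forward(q, k, v, attn_mask=None, head_mask=None, use_sdpa=True):
    # Fast path: SDPA, which cannot apply a per-head mask to the
    # attention probabilities, so it is only taken when head_mask is None.
    if use_sdpa and head_mask is None:
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # Eager fallback: explicit softmax(QK^T / sqrt(d) + mask) @ V, where
    # the head mask can be multiplied into the probabilities.
    scores = q @ k.transpose(-1, -2) / (q.size(-1) ** 0.5)
    if attn_mask is not None:
        scores = scores + attn_mask
    probs = scores.softmax(dim=-1)
    if head_mask is not None:
        probs = probs * head_mask
    return probs @ v
```

On the math backend in fp32 the two paths agree to well within the tolerances used in the table above; the larger deviations come from fused kernels in half precision.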
amyeroberts left a comment
Great - thanks for all the work adding this!
All LGTM, would just like a second look from @ArthurZucker regarding the RoPE scaling logic
Before merge, we'll need to do a run on all the slow tests for the model. Pushing a commit with the message
Should be committed with the msg.
@vasqu Thanks for pushing and again for adding this capability. All looks good - let's merge 🤗
What does this PR do?
Adds torch's SDPA to the GPT-NeoX model architecture (as another attention module). Possibly relevant: #28005.

Added benchmarks based on @fxmarty's scripts for training and inference. Setup: RTX 3080 Ti (16GB), PyTorch 2.2.1, Ubuntu 22.04, using float16 with pythia-410m-deduped.

Training results:
Inference results:
Remaining relevant issues:

RUN_SLOW=True pytest tests/models/gpt_neox -k "test_eager_matches_sdpa_inference" -s -vvvvv leads to failures on bf16 -> not sure if it's an implementation issue on my side or because the model's RoPE is naturally stored in fp32? Help/advice appreciated. Edit: see the comment below for a summary of the captured failing tests.

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@fxmarty @ArthurZucker @amyeroberts