adding enable_gqa in SDPA#11097
Merged
wozeparrot merged 10 commits intotinygrad:masterfrom Jul 7, 2025
Merged
Conversation
geohot
reviewed
Jul 5, 2025
test/test_ops.py
Outdated
| repeat_factor = x.shape[1] // y.shape[1] | ||
| y_repeated = y.repeat_interleave(repeat_factor, dim=1) | ||
| z_repeated = z.repeat_interleave(repeat_factor, dim=1) | ||
| return torch.nn.functional.scaled_dot_product_attention(x, y_repeated, z_repeated) |
Collaborator
There was a problem hiding this comment.
Huh, why doesn't this test the torch flag?
Contributor
Author
There was a problem hiding this comment.
Huh, why doesn't this test the torch flag?
i think its not implemented yet: https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L5809
in experimental phase at least?
the current stable torch 2.2.2. doesnt have enable_gqa
Collaborator
There was a problem hiding this comment.
current stable torch is 2.7? and it does have enable_gqa?
Contributor
Author
geohot
reviewed
Jul 5, 2025
tinygrad/tensor.py
Outdated
|
|
||
| def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: | ||
| def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, | ||
| dropout_p:float=0.0, is_causal:bool=False, enable_gqa:bool=False) -> Tensor: |
Contributor
Changes |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
adding
enable_gqain SDPA like in pytorch.