
adding enable_gqa in SDPA #11097

Merged
wozeparrot merged 10 commits into tinygrad:master from NinoRisteski:enable_gqa on Jul 7, 2025

Conversation

@NinoRisteski (Contributor)

Adding enable_gqa in SDPA, like in PyTorch.

test/test_ops.py Outdated
repeat_factor = x.shape[1] // y.shape[1]
y_repeated = y.repeat_interleave(repeat_factor, dim=1)
z_repeated = z.repeat_interleave(repeat_factor, dim=1)
return torch.nn.functional.scaled_dot_product_attention(x, y_repeated, z_repeated)
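For context, the expansion the reference path above performs can be shown without torch. A minimal sketch (the `repeat_heads` helper is hypothetical, not from the PR) of what `repeat_interleave` along the head dimension does when broadcasting grouped KV heads up to the query head count:

```python
# Hedged sketch of the GQA head expansion done via repeat_interleave in
# the test above; `repeat_heads` is a hypothetical helper, not PR code.
def repeat_heads(kv_heads, repeat_factor):
    # Duplicate each KV head `repeat_factor` times, preserving order,
    # mirroring tensor.repeat_interleave(repeat_factor, dim=1) applied
    # to the head dimension.
    return [h for h in kv_heads for _ in range(repeat_factor)]

num_q_heads = 8               # query heads
kv_heads = ["kv0", "kv1"]     # 2 KV heads (grouped-query attention)
factor = num_q_heads // len(kv_heads)
expanded = repeat_heads(kv_heads, factor)
print(expanded)  # ['kv0', 'kv0', 'kv0', 'kv0', 'kv1', 'kv1', 'kv1', 'kv1']
```

After the expansion, each KV head serves a contiguous group of four query heads, which is exactly the setup plain SDPA expects.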
Collaborator

Huh, why doesn't this test the torch flag?

Contributor Author

> Huh, why doesn't this test the torch flag?

I think it's not implemented yet: https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L5809
In an experimental phase, at least?

The current stable torch, 2.2.2, doesn't have enable_gqa.

Collaborator

current stable torch is 2.7? and it does have enable_gqa?

Contributor Author

> current stable torch is 2.7? and it does have enable_gqa?

Thanks! I had trouble installing 2.7, but managed 2.6, and it has enable_gqa.
I updated the test -> 0989de6 @geohot


def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None,
dropout_p:float=0.0, is_causal:bool=False, enable_gqa:bool=False) -> Tensor:
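A rough sketch (not tinygrad's actual implementation) of what the new enable_gqa flag implies inside scaled_dot_product_attention: check that the query head count is an integer multiple of the KV head count, then expand KV heads by that factor before the usual attention math. Plain shape tuples stand in for tensors here, and `gqa_head_factor` is a hypothetical helper:

```python
# Hedged sketch of the enable_gqa pre-step; shapes are (batch, heads, seq, dim).
# This mirrors the head-count check and expansion factor, not tinygrad's code.
def gqa_head_factor(q_shape, kv_shape, enable_gqa):
    q_heads, kv_heads = q_shape[1], kv_shape[1]
    if not enable_gqa:
        # Without the flag, head counts must already match.
        assert q_heads == kv_heads, "head mismatch without enable_gqa"
        return 1
    # With GQA, query heads must divide evenly by KV heads; each KV head
    # is then repeated q_heads // kv_heads times along the head dim.
    assert q_heads % kv_heads == 0, "q heads must be a multiple of kv heads"
    return q_heads // kv_heads

# 8 query heads over 2 KV heads -> each KV head is repeated 4 times.
print(gqa_head_factor((2, 8, 16, 64), (2, 2, 16, 64), enable_gqa=True))  # 4
```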
Collaborator

Weird spacing



github-actions bot commented Jul 6, 2025

Changes

Name                  Lines    Diff    Tokens/Line    Diff
------------------  -------  ------  -------------  ------
tinygrad/tensor.py     1444      +4           20.7    -0.0


total line changes: +4

@wozeparrot merged commit a1a146a into tinygrad:master on Jul 7, 2025
38 checks passed
