Update non_xla attention to properly support paged_attention dynamo code path #7022

Merged: wonjoo-wj merged 2 commits into master from wonjoo/paged-attention/dynamo on May 3, 2024

wonjoo-wj (Collaborator) commented May 2, 2024

  • Update non_xla attention to properly support paged_attention dynamo code path
  • Fix the original broken dynamo unit tests with paged_attention
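As a rough illustration of the first bullet, here is a minimal sketch of a non-XLA fallback that only has to produce an output of the right shape when dynamo traces with fake (meta) tensors. The function name `fallback_attention_sketch` and the tensor shapes are hypothetical and are not taken from the PR; only `torch.empty_like` and the meta-device check are mentioned in the review discussion.

```python
import warnings

import torch


def fallback_attention_sketch(q, k_pages, v_pages):
    """Hypothetical stand-in for a non-XLA attention fallback (not the PR's code)."""
    # During dynamo tracing the inputs are expected to live on the meta device,
    # so anything else most likely means the op was called on real tensors.
    if k_pages.device.type != "meta":
        warnings.warn("attention fallback called on a non-meta tensor")
    # Only the output metadata matters here: same shape and dtype as q.
    return torch.empty_like(q)


# Illustrative shapes only; k_pages and v_pages deliberately differ from q.
q = torch.empty(2, 8, 128, device="meta")
k_pages = torch.empty(8, 16, 4, 128, device="meta")
v_pages = torch.empty(8, 16, 4, 128, device="meta")

out = fallback_attention_sketch(q, k_pages, v_pages)
print(out.shape, out.dtype, out.device)
```

The point of the sketch is that the fallback never computes attention; it only tells the tracer what the output looks like.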

Test plan:

root@1fdc3324aeef:/pytorch/xla# python test/test_pallas.py PallasTest.test_paged_attention_wrapper_with_dynamo
.
----------------------------------------------------------------------
Ran 1 test in 1.798s

OK

+ TPU CI

@wonjoo-wj wonjoo-wj requested review from JackCaoG and alanwaketan May 2, 2024 23:29
@JackCaoG JackCaoG added the tpuci label May 2, 2024

alanwaketan (Collaborator) left a comment

LGTM.

attn_output = attn_weight @ v
return attn_output
# Return original shape of q.
return torch.empty_like(q)
Collaborator

Do you know if this actually initializes anything? I hope not.

Collaborator

Yeah, it is worth checking. Above, we have a warning that q should be on the meta device, and I think running ops on a meta tensor will not allocate any device memory.

Collaborator Author

According to https://pytorch.org/docs/stable/generated/torch.empty_like.html, it seems like empty_like returns an uninitialized tensor. Along with the meta tensor check above, I think we should be good.
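
For anyone who wants to check this locally, the following is a small sketch (not part of the PR) of how `torch.empty_like` behaves on a meta tensor. The sizes are arbitrary and chosen only to make the point that no storage is backing the result.

```python
import torch

# A meta tensor carries shape/dtype metadata but has no backing storage,
# so even a very large logical size is fine to create.
q = torch.empty(1 << 40, dtype=torch.bfloat16, device="meta")

# empty_like mirrors the metadata of its input; on a meta input the result
# is again a meta tensor, so nothing is allocated or initialized.
out = torch.empty_like(q)

print(out.is_meta, out.shape, out.dtype)
```

On a real device `empty_like` would allocate uninitialized memory, which is why the meta-device check in the fallback matters.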

wonjoo-wj (Collaborator, Author) commented May 3, 2024

Thanks for the reviews, I'll go ahead and merge this as the CIs (including TPU CI) are all green.

@wonjoo-wj wonjoo-wj merged commit 2bce3f8 into master May 3, 2024