
Dynamo/AOTAutograd traceable flash attention #8654

Merged

zpcore merged 4 commits into master from piz/autograde_trace on Feb 1, 2025

Conversation

@zpcore (Member) commented on Jan 30, 2025

Resolves #8633
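For context, the general pattern for making a hand-written kernel traceable by Dynamo/AOTAutograd is to register it as a PyTorch custom op with a fake (meta) implementation for shape propagation and an explicit backward. The sketch below is illustrative only, not this PR's actual code: the `mylib::flash_attention` name is hypothetical, the body is plain reference attention math rather than the Pallas kernel in `torch_xla/experimental/custom_kernel.py`, and it assumes the `torch.library.custom_op` API available in PyTorch 2.4+.

```python
import torch
from torch.library import custom_op

# Hypothetical op; a real implementation would dispatch to the fused
# Pallas kernel instead of the reference math used here.
@custom_op("mylib::flash_attention", mutates_args=())
def flash_attention(q: torch.Tensor, k: torch.Tensor,
                    v: torch.Tensor) -> torch.Tensor:
    scale = q.shape[-1] ** -0.5
    p = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1)
    return p @ v

@flash_attention.register_fake
def _(q, k, v):
    # Shape/dtype propagation lets Dynamo/AOTAutograd trace the op
    # without executing the kernel.
    return torch.empty_like(q)

def _setup_context(ctx, inputs, output):
    q, k, v = inputs
    ctx.save_for_backward(q, k, v)

def _backward(ctx, grad_out):
    # Standard softmax-attention backward, written out by hand so the
    # custom op is differentiable under AOTAutograd.
    q, k, v = ctx.saved_tensors
    scale = q.shape[-1] ** -0.5
    p = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1)
    dv = p.transpose(-2, -1) @ grad_out
    dp = grad_out @ v.transpose(-2, -1)
    ds = p * (dp - (dp * p).sum(dim=-1, keepdim=True))
    dq = ds @ k * scale
    dk = ds.transpose(-2, -1) @ q * scale
    return dq, dk, dv

flash_attention.register_autograd(_backward, setup_context=_setup_context)
```

Because the op is opaque to the tracer, Dynamo records a single `mylib::flash_attention` node and AOTAutograd pairs it with the registered backward, instead of tracing into the kernel body.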

@zpcore force-pushed the piz/autograde_trace branch from f82f373 to a8c8f47 on January 31, 2025
@zpcore changed the title from "backward with spmd issue" to "Dynamo/AOTAutograd traceable flash attention" on January 31, 2025
@zpcore requested a review from @tengyifei on January 31, 2025
@tengyifei (Collaborator) left a comment:

This is really great

Review threads on test/test_pallas.py, torch_xla/experimental/custom_kernel.py, and test/test_pallas_spmd.py (several marked Outdated).
@zpcore merged commit 9ae017e into master on Feb 1, 2025
@zpcore deleted the piz/autograde_trace branch on February 1, 2025
tengyifei added a commit to AI-Hypercomputer/torchprime that referenced this pull request Mar 17, 2025
We replace the `for` loop in both Llama and Mixtral with an equivalent
`HomogenousSequential` layer, which can either run as a for loop or use
`torch_xla`'s scan operator. This is a clean-ish way to turn scan on/off
without cluttering the modeling code.

I also adjusted Mixtral slightly so that we can even run `scan` in
Mixtral with its static MoE implementation. Scanning over GMM, on the
other hand, won't work until GMM forward/backward is wrapped in a custom
op similar to pytorch/xla#8654.

Test: added unit test. Next PR will change the trainer to apply scan.
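A minimal sketch of the `HomogenousSequential` idea this commit message describes, under stated assumptions: the class name matches torchprime's, but the body here is illustrative rather than the actual torchprime code, and the scan path assumes a `scan_layers(layers, input)` helper in `torch_xla.experimental.scan_layers` (verify against your torch_xla version).

```python
import torch
import torch.nn as nn

class HomogenousSequential(nn.Sequential):
    """A Sequential whose children all share one architecture, so the stack
    can run either as an unrolled for loop or as a single scanned layer."""

    def __init__(self, *layers: nn.Module, use_scan: bool = False):
        super().__init__(*layers)
        self.use_scan = use_scan

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_scan:
            # Assumed torch_xla helper: traces the layer body once and
            # scans it over all children instead of unrolling the loop.
            from torch_xla.experimental.scan_layers import scan_layers
            return scan_layers(list(self.children()), x)
        # Plain eager path: an ordinary Python for loop over the layers.
        for layer in self.children():
            x = layer(x)
        return x

# Illustrative usage: four identical blocks; flip use_scan to switch modes
# without touching the modeling code.
model = HomogenousSequential(*(nn.Linear(64, 64) for _ in range(4)),
                             use_scan=False)
out = model(torch.randn(2, 64))
```

The design point, per the message above, is that the loop-vs-scan choice lives in the container rather than in each model's forward.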
tengyifei added a commit to AI-Hypercomputer/torchprime that referenced this pull request Mar 18, 2025
* Make models amenable to scan (same commit message as the Mar 17 commit above)

* Address comments

Development

Successfully merging this pull request may close these issues.

Make flash_attention Dynamo/AOTAutograd traceable

2 participants