Skip to content
This repository was archived by the owner on Mar 3, 2026. It is now read-only.

Make models amenable to scan#157

Merged
tengyifei merged 2 commits intomainfrom
yifeit/scan-1
Mar 18, 2025
Merged

Make models amenable to scan#157
tengyifei merged 2 commits intomainfrom
yifeit/scan-1

Conversation

@tengyifei
Copy link
Copy Markdown
Contributor

@tengyifei tengyifei commented Mar 17, 2025

We replace the for loop in both Llama and Mixtral with an equivalent HomogenousSequential layer, which can be either run a for loop or use torch_xla's scan operator. This is a clean-ish way to turn scan on/off without cluttering the modeling code.

I also adjusted Mixtral slightly so that we can even run scan in Mixtral with its static MoE implementation. In order to integrate with scan, we need to refactor the Mixtral decoder for loop into a format where results from the previous iteration feed into the next iteration. Scanning over GMM on the other hand won't work until GMM forward/backward is wrapped in a custom op similar to pytorch/xla#8654.

Cleanup the README that got jumbled in #111 while I'm here.

Test: added unit test. Next PR will change the trainer to apply scan.

We replace the `for` loop in both Llama and Mixtral with an equivalent
`HomogenousSequential` layer, which can be either run a for loop or use
`torch_xla`'s scan operator. This is a clean-ish way to turn scan on/off
without cluttering the modeling code.

I also adjusted Mixtral slightly so that we can even run `scan` in
Mixtral with its static MoE implementation. Scanning over GMM on the
other hand won't work until GMM forward/backward is wrapped in a custom
op similar to pytorch/xla#8654.

Test: added unit test. Next PR will change the trainer to apply scan.
@tengyifei tengyifei marked this pull request as ready for review March 17, 2025 08:24
@tengyifei tengyifei requested review from bhavya01, qihqi and zpcore March 17, 2025 17:14
Comment thread torchprime/torch_xla_models/mixtral/model.py
Comment thread torchprime/torch_xla_models/mixtral/model.py Outdated
Comment thread torchprime/torch_xla_models/scan_layers.py
@tengyifei tengyifei requested a review from bhavya01 March 17, 2025 22:40
Comment thread README.md
Comment thread torchprime/torch_xla_models/tests/test_llama.py
Comment thread torchprime/torch_xla_models/tests/test_mixtral.py
Comment thread torchprime/torch_xla_models/mixtral/model.py
Comment thread torchprime/layers/sequential.py
@tengyifei tengyifei requested a review from zpcore March 18, 2025 00:08
@tengyifei
Copy link
Copy Markdown
Contributor Author

@zpcore i saw you added a number of comments but didn't press "Request Changes" or "Approve" -- let me know if you would like to request changes or approve.

Copy link
Copy Markdown
Collaborator

@zpcore zpcore left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@tengyifei tengyifei enabled auto-merge (squash) March 18, 2025 00:35
@tengyifei tengyifei merged commit d6f2452 into main Mar 18, 2025
@tengyifei tengyifei deleted the yifeit/scan-1 branch March 18, 2025 01:36
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants