Skip to content

Introduce a GRU module implemented with scan#8777

Merged
qihqi merged 6 commits intomasterfrom
yifeit/rnn-scan
Mar 3, 2025
Merged

Introduce a GRU module implemented with scan#8777
qihqi merged 6 commits intomasterfrom
yifeit/rnn-scan

Conversation

@tengyifei
Copy link
Copy Markdown
Collaborator

@tengyifei tengyifei commented Mar 2, 2025

Fixes #8655

Given that the experimental launch of scan operator that lowers to XLA's WhileOp, we should leverage it to implement performant RNN layers. This PR adds support for a common RNN: Gated Recurrent Unit. It's mostly API compatible with the GRU module found in PyTorch upstream except we only support uni-directional RNN for now.

It should great to leverage it in place of the for loop that loops throught the time dimension, which could be large.

@tengyifei tengyifei requested a review from qihqi March 3, 2025 04:51
@tengyifei tengyifei marked this pull request as ready for review March 3, 2025 04:51
@tengyifei tengyifei requested a review from bhavya01 March 3, 2025 04:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

RNN / GRU / LSTM implementation for torch_xla

2 participants