
Conversation

@kailums (Contributor) commented Nov 15, 2023

Description

Change the RotaryEmbedding op implementation to add support for 4D input tensors with shape [batch, num_heads, seq_len, head_size].
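
As a rough illustration of what the op now accepts, here is a minimal NumPy sketch of rotary position embedding applied to a 4D tensor. It assumes the non-interleaved (half-split) rotation convention and a `[seq_len, head_size // 2]` cos/sin cache layout; it is a sketch of the technique, not the actual kernel.

```python
import numpy as np

def apply_rotary_4d(x, cos, sin):
    # x:        [batch, num_heads, seq_len, head_size]
    # cos, sin: [seq_len, head_size // 2] position caches (assumed layout)
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]   # split head_size in two
    c = cos[None, None, :, :]               # broadcast over batch and head dims
    s = sin[None, None, :, :]
    # rotate each (x1, x2) pair by its position-dependent angle
    return np.concatenate([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1)

batch, num_heads, seq_len, head_size = 2, 8, 16, 64
x = np.random.randn(batch, num_heads, seq_len, head_size).astype(np.float32)
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_size, 2) / head_size))
angles = np.arange(seq_len)[:, None] * inv_freq[None, :]  # [seq_len, head_size // 2]
out = apply_rotary_4d(x, np.cos(angles), np.sin(angles))
print(out.shape)  # (2, 8, 16, 64)
```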

Motivation and Context

The current RotaryEmbedding op only supports 3D input tensors with shape [batch, seq_len, hidden_size].

For the LLaMA v2 model, when FusionRotaryEmbeddings is used to fuse only the RotaryEmbedding op, a transpose is applied to the query and key, so the input tensor to RotaryEmbedding becomes 4D with shape [batch, num_heads, seq_len, head_size].

The current RotaryEmbedding implementation can't handle this scenario, so 4D input support is needed.
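
For context, a minimal sketch (with made-up dimensions) of how the reshape/transpose around the fusion turns the 3D layout into the 4D one:

```python
import numpy as np

batch, seq_len, num_heads, head_size = 2, 16, 8, 64
hidden_size = num_heads * head_size  # 512

q = np.random.randn(batch, seq_len, hidden_size)     # 3D layout the op used to require
q = q.reshape(batch, seq_len, num_heads, head_size)  # split hidden_size into heads
q = q.transpose(0, 2, 1, 3)                          # -> [batch, num_heads, seq_len, head_size]
print(q.shape)  # (2, 8, 16, 64)
```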

@kailums kailums merged commit 1a29460 into main Nov 17, 2023
@kailums kailums deleted the kailums/rope-4d branch November 17, 2023 12:38
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024
