🐛 Bug
Input tensors to attention must be in the format `[B, M, H, K]`, where `B` is the batch size, `M` the sequence length, `H` the number of heads, and `K` the embedding size per head, as documented here.
Hence positional embedding (e.g., rotary embedding) should be applied along `dim=1` (the sequence dimension). However, in the `RotaryEmbedding` class, `dim=-2` is being passed, which corresponds to `dim=2`, as seen here.
```python
def forward(
    self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
        k, seq_dimension=-2  # should be seq_dimension=1, or no argument should be passed since the default value is correct
    )

    return (
        apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
        apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
    )
```
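To illustrate the off-by-one dimension, here is a minimal sketch using plain PyTorch (the shapes are hypothetical and chosen only for illustration, not taken from xformers): for a 4-D `[B, M, H, K]` tensor, `dim=-2` resolves to `dim=2`, the heads axis, so the cos/sin tables end up sized to `H` instead of the sequence length `M`.

```python
import torch

# Hypothetical shapes: batch 2, sequence length 8, 4 heads, 16 dims per head.
B, M, H, K = 2, 8, 4, 16
k = torch.randn(B, M, H, K)

seq_dimension = -2
# For a 4-D tensor, -2 resolves to axis 2, i.e. the heads axis, not the sequence axis.
print(seq_dimension % k.ndim)  # 2
print(k.shape[seq_dimension])  # 4  -> H, the number of heads
print(k.shape[1])              # 8  -> M, the sequence length actually needed
```

With `seq_dimension=1` (or the default), the cached tables would instead be built from `k.shape[1] = M`, matching the positions being embedded.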
Additional context
Thanks to @jmercat who found symptoms of this problem downstream of xformers!