In the DSA implementation proposed by DeepSeek-V3.2, the RoPE-related split for Q and K is performed as:
    q_pe, q_nope = torch.split(
        q,
        [self.rope_head_dim, self.head_dim - self.rope_head_dim],
        dim=-1,
    )
However, in your implementation, the split order appears to be:
    x_nope, x_pe = torch.split(
        x,
        [self.index_head_dim - self.qk_pos_emb_head_dim, self.qk_pos_emb_head_dim],
        dim=-1,
    )
This means the RoPE-applied and non-RoPE parts are laid out in the opposite order from DeepSeek-V3.2: there the RoPE channels come first in the head dimension, while here they come last.
Could this difference in the split order of the RoPE and non-RoPE components for Q/K introduce any numerical or accuracy discrepancies during training or inference?
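
To make the concern concrete, here is a minimal pure-Python sketch of what I would expect, not a statement about either codebase. The sizes `ROPE_DIM` and `HEAD_DIM`, the toy `rope` function (adjacent-pair rotation), and the helper `to_nope_first` are all assumptions made up for illustration. It compares the Q·K score under the two layouts, once with the same channels fed to both and once with the channels permuted to match the layout.

```python
import math
import random

ROPE_DIM = 4   # hypothetical sizes, chosen only for illustration
HEAD_DIM = 8

def rope(vec, pos, base=10000.0):
    """Toy RoPE: rotate adjacent channel pairs by position-dependent angles."""
    out = []
    for i in range(0, len(vec), 2):
        angle = pos * base ** (-i / len(vec))
        c, s = math.cos(angle), math.sin(angle)
        out += [vec[i] * c - vec[i + 1] * s, vec[i] * s + vec[i + 1] * c]
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def score_pe_first(q, k, q_pos, k_pos):
    # Layout [pe | nope]: RoPE is applied to the first ROPE_DIM channels.
    q_rot = rope(q[:ROPE_DIM], q_pos) + q[ROPE_DIM:]
    k_rot = rope(k[:ROPE_DIM], k_pos) + k[ROPE_DIM:]
    return dot(q_rot, k_rot)

def score_nope_first(q, k, q_pos, k_pos):
    # Layout [nope | pe]: RoPE is applied to the last ROPE_DIM channels.
    split = HEAD_DIM - ROPE_DIM
    q_rot = q[:split] + rope(q[split:], q_pos)
    k_rot = k[:split] + rope(k[split:], k_pos)
    return dot(q_rot, k_rot)

def to_nope_first(vec):
    # Channel permutation mapping a [pe | nope] vector to [nope | pe].
    return vec[ROPE_DIM:] + vec[:ROPE_DIM]

rng = random.Random(0)
q = [rng.gauss(0, 1) for _ in range(HEAD_DIM)]
k = [rng.gauss(0, 1) for _ in range(HEAD_DIM)]

# Score under the [pe | nope] layout.
reference = score_pe_first(q, k, q_pos=5, k_pos=2)
# Same channels fed to the [nope | pe] layout without reordering.
unpermuted = score_nope_first(q, k, q_pos=5, k_pos=2)
# Channels reordered to match the [nope | pe] layout.
permuted = score_nope_first(to_nope_first(q), to_nope_first(k), q_pos=5, k_pos=2)

print(reference, unpermuted, permuted)
```

In this toy setup the permuted score matches the reference up to floating-point rounding, while the unpermuted one does not. If that carries over to the real code, the order should not matter when training from scratch, since the projection weights can learn either layout, but weights trained with one layout would need their output channels reordered before being loaded into the other. I would appreciate confirmation of whether that is the case here.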