Hi, in flashmla_backend,
when I print self.num_q_heads and self.num_kv_heads, both values are 16.
However, in the original flashmla repo, num_kv_heads is expected to be 1.
Could you clarify why the two differ?
Thanks.
mla_metadata, num_splits = get_mla_metadata(
    forward_batch.seq_lens.to(torch.int32),
    Q_LEN * self.num_q_heads // self.num_kv_heads,
    self.num_kv_heads,
)
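To make the difference concrete, here is a small sketch of what the second and third arguments evaluate to in each case. It only reproduces the arithmetic from the call above; the helper name `mla_metadata_args` is hypothetical, and `Q_LEN = 1` is my assumption for a single-token decode step.

```python
def mla_metadata_args(q_len, num_q_heads, num_kv_heads):
    """Hypothetical helper: returns the second and third arguments
    passed to get_mla_metadata in the call above."""
    return q_len * num_q_heads // num_kv_heads, num_kv_heads


Q_LEN = 1  # assumption: one query token per sequence during decode

# Values I observe in flashmla_backend (both head counts are 16).
backend_args = mla_metadata_args(Q_LEN, num_q_heads=16, num_kv_heads=16)

# Values I would expect if num_kv_heads were 1, as in the original repo.
reference_args = mla_metadata_args(Q_LEN, num_q_heads=16, num_kv_heads=1)

print(backend_args)    # (1, 16)
print(reference_args)  # (16, 1)
```

So with the values I see, the call receives (1, 16) rather than (16, 1), which is why I am asking whether this is intentional.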