Consider the fast rope code:

    def backward(ctx, dQ, dK):
        batch, _, _, head_dim = dQ.shape

        rope_ptr = (
            ctx.rope_indices
            if ctx.has_indices
            else ctx.cos.new_empty(1, dtype = torch.int32)
        )

        Q_batch_stride, Q_head_stride, Q_seq_stride = (
            dQ.stride(0),
            dQ.stride(1),
            dQ.stride(2),
        )
        K_batch_stride, K_head_stride, K_seq_stride = (
            dK.stride(0),
            dK.stride(1),
            dK.stride(2),
        )

        # Inplace rotary embedding is generally fine
        dQ_out = dQ.clone() if not dQ.is_contiguous else dQ
        dK_out = dK.clone() if not dK.is_contiguous else dK

When a zero-strided tensor `dQ` or `dK` comes in, the strides `Q_batch_stride`, `Q_head_stride`, `Q_seq_stride` and `K_batch_stride`, `K_head_stride`, `K_seq_stride` are all set to zero.
To my knowledge, this is a bug that can happen with debugging losses. For example,

    out = fast_rope_embedding(x.clone(), x.clone(), cos, sin)
    (out[0].sum() + out[1].sum()).backward()

This code passes the backward function a zero-strided gradient, which should be fully materialized before its strides are read.
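
The zero-stride behavior can be reproduced without the RoPE kernel. The following is a minimal sketch, assuming only PyTorch is installed: `RecordGradLayout` is a hypothetical identity-like `torch.autograd.Function` written for illustration (it is not part of unsloth) that records the layout of the gradient it receives from a `.sum()` loss, then materializes it with `.contiguous()`.

```python
import torch


class RecordGradLayout(torch.autograd.Function):
    """Hypothetical identity-like op that records the incoming gradient layout."""

    seen = {}

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad):
        # Layout exactly as autograd hands the gradient over.
        RecordGradLayout.seen["raw_strides"] = grad.stride()
        RecordGradLayout.seen["raw_contiguous"] = grad.is_contiguous()
        # Materialize the gradient before reading strides from it.
        grad = grad.contiguous()
        RecordGradLayout.seen["fixed_strides"] = grad.stride()
        return grad


x = torch.randn(2, 4, 8, 16, requires_grad=True)
# A plain .sum() loss, as in the debugging example above.
RecordGradLayout.apply(x).sum().backward()
print(RecordGradLayout.seen)
```

Calling `.contiguous()` is one possible way to materialize the gradient; it is shown here as an assumption and is not necessarily what the eventual fix does.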
Furthermore, the forward and backward passes never actually clone the given `Q`, `K`, `dQ`, and `dK` tensors. See,

    # Inplace rotary embedding is generally fine
    dQ_out = dQ.clone() if not dQ.is_contiguous else dQ
    dK_out = dK.clone() if not dK.is_contiguous else dK

and,

    # Inplace rotary embedding is generally fine
    Q_out = Q.clone() if not Q.is_contiguous else Q
    K_out = K.clone() if not K.is_contiguous else K

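`X.is_contiguous` is a method, so it must be called as `X.is_contiguous()`. Without the parentheses, the expression evaluates to a bound method object, which is always truthy, so `not X.is_contiguous` is always `False` and the `clone()` branch is never taken.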
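
The effect of the missing parentheses is plain Python behavior and can be shown without PyTorch. In this minimal sketch, `FakeTensor`, `pick_buggy`, and `pick_fixed` are hypothetical names introduced for illustration; `FakeTensor` only models `is_contiguous()` and `clone()`.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor; only models is_contiguous() and clone()."""

    def __init__(self, contiguous):
        self._contiguous = contiguous

    def is_contiguous(self):
        return self._contiguous

    def clone(self):
        # A clone is modeled as a fresh, contiguous object.
        return FakeTensor(True)


def pick_buggy(t):
    # Mirrors the original expression: the method is referenced, not called,
    # so the condition is always False and t is returned unchanged.
    return t.clone() if not t.is_contiguous else t


def pick_fixed(t):
    # The method is called, so non-contiguous inputs are cloned.
    return t.clone() if not t.is_contiguous() else t


strided = FakeTensor(contiguous=False)
buggy_cloned = pick_buggy(strided) is not strided
fixed_cloned = pick_fixed(strided) is not strided
print(buggy_cloned, fixed_cloned)  # prints: False True
```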
I have a PR with the fix that I will submit shortly.

The excerpts above are from unsloth/unsloth/kernels/rope_embedding.py at d83fbf6: lines 377 to 400 (the `backward` excerpt), lines 397 to 399 (the `dQ_out`/`dK_out` lines), and lines 314 to 316 (the `Q_out`/`K_out` lines).