Add FusedLinearJSD #300
    out=grad_weight,
)

loss = torch.sum(loss_1d) / BT
I just noticed that torch.sum might overflow when BT is large. It can be fixed by either:
- dividing before summing, i.e. torch.sum(loss_1d / BT)
- moving the division inside the JSD kernel
I prefer the second solution, but I'm not sure if it is OK to modify another kernel in this PR.
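The ordering concern above can be illustrated with a small standalone sketch. It is not the PR's code: the thread does not state the dtype of loss_1d, so this assumes a half-precision accumulator, simulated with the standard library's struct "e" format. The helpers to_fp16 and fp16_sum are hypothetical, and the row count and loss values are contrived so that every intermediate value is exact in fp16.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 half precision; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack finite values beyond the fp16 range,
        # so map them to infinity the way a saturating fp16 add would.
        return math.copysign(math.inf, x)


def fp16_sum(values):
    """Accumulate left to right, rounding to half precision after every add."""
    acc = 0.0
    for v in values:
        acc = to_fp16(acc + to_fp16(v))
    return acc


BT = 64                  # hypothetical number of rows (assumption)
loss_1d = [2048.0] * BT  # contrived per-row losses, exact in fp16 (assumption)

# Option A: sum first, divide afterwards. The running sum leaves the fp16 range.
sum_then_divide = to_fp16(fp16_sum(loss_1d) / BT)

# Option B: divide each element first, then sum. Every partial sum stays small.
divide_then_sum = fp16_sum(v / BT for v in loss_1d)

print(sum_then_divide)  # inf
print(divide_then_sum)  # 2048.0
```

Both orderings are mathematically the same mean; only the size of the intermediate sum differs, which is why moving the division earlier (or into the kernel) sidesteps the overflow.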
Can the existing division in JSD handle this? 🤔 I saw you change the n_row for each chunk to be BT now.
Currently the n_row parameter in JSD is only used for the gradient calculation. It can be modified to work like cross_entropy does, which computes the loss according to the expected reduction method.
> I saw you change the n_row for each chunk to be BT now.
Simply passing BT gives correct results without further alpha tweaking like flce does.
> Can the existing division in JSD handle this?
It has a potential overflow issue too; that's one of the reasons why I think moving the division into the kernel is better.
Gotcha. I'm good with moving it inside the JSD kernel. For ce/flce, let's keep them for now. Thanks!
qingquansong left a comment
LGTM! We can add the ignore index later similar to here so it can be easily used for the SFT context. Great work!
@triton.jit
def element_mul_kernel(
Maybe we can delete the one in the original ce and share this one.
yeah, I can add the new ce file
Summary
Similar to the fused linear CE.
It handles the forward and backward passes of the final linear layer followed by JSD while avoiding materialization of the large logits tensor. Since JSD is the last layer, the gradient can be computed during the forward pass.
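To give a rough sense of why avoiding the full logits tensor matters, here is a back-of-the-envelope calculation using the vocab size reported in this PR's testing configuration. The row count BT, the chunk size, and the float32 element size are assumptions chosen for illustration, and tensor_gib is a hypothetical helper, not part of the PR.

```python
def tensor_gib(rows: int, cols: int, bytes_per_elem: int) -> float:
    """Size of a dense rows x cols tensor in GiB."""
    return rows * cols * bytes_per_elem / 2**30


VOCAB = 128256     # vocab size from the PR's testing configuration
BT = 8192          # assumption: batch size times sequence length
CHUNK_ROWS = 1024  # assumption: rows processed per chunk
BYTES = 4          # assumption: float32 logits

full_logits = tensor_gib(BT, VOCAB, BYTES)           # all rows at once
chunk_logits = tensor_gib(CHUNK_ROWS, VOCAB, BYTES)  # one chunk at a time

print(f"full logits:  {full_logits:.2f} GiB")   # full logits:  3.91 GiB
print(f"chunk logits: {chunk_logits:.2f} GiB")  # chunk logits: 0.49 GiB
```

Under these assumptions, processing one chunk at a time keeps the peak logits allocation at a fraction of the full tensor, which is the saving the fused approach targets.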
Testing Done
Hidden size: 4096, Vocab size: 128256


- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence