Add beta support for jsd #290
Conversation
def forward(self, p, q):
    return LigerJSDFunction.apply(p, q)
def forward(self, log_q, log_p):
This is the correct order of input and target (student and teacher, respectively). Would it be too confusing?
Yeah, the names are a bit confusing. We could add a description here to clarify.
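To illustrate why the argument order is worth documenting, here is a minimal pure-Python sketch (not the Triton kernel from this PR). The helper `jsd_reference` is hypothetical, and it assumes the convention $M=\beta P + (1-\beta)Q$ with the loss $\beta\,KL(P\|M) + (1-\beta)\,KL(Q\|M)$. Under that assumption, swapping student and teacher changes the result whenever `beta` is not 0.5.

```python
import math

def jsd_reference(log_q, log_p, beta=0.5):
    """Hypothetical reference for the generalized JSD (illustration only).

    log_q: student log-probabilities (the input)
    log_p: teacher log-probabilities (the target)
    beta:  weight on the teacher term, with M = beta * P + (1 - beta) * Q
    """
    total = 0.0
    for x, y in zip(log_q, log_p):
        q, p = math.exp(x), math.exp(y)
        m = beta * p + (1.0 - beta) * q
        # beta * P * log P + (1 - beta) * Q * log Q - M * log M
        total += beta * p * y + (1.0 - beta) * q * x - m * math.log(m)
    return total

student = [math.log(0.5), math.log(0.5)]
teacher = [math.log(0.9), math.log(0.1)]

# With beta = 0.5 the loss is symmetric; with beta = 0.1 the order matters.
symmetric_gap = abs(jsd_reference(student, teacher, 0.5) - jsd_reference(teacher, student, 0.5))
asymmetric_gap = abs(jsd_reference(student, teacher, 0.1) - jsd_reference(teacher, student, 0.1))
print(symmetric_gap, asymmetric_gap)
```

A docstring along these lines on `forward` would make the student-first order explicit without renaming the parameters.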
@qingquansong @yundai424 ready for review!
qingquansong left a comment
LGTM in general! In case you're interested, I think good future work would be to make the KL and JSD losses similar to the fused CE loss: feed the last projection layers of the teacher and student models to the kernel and fuse the projection with the loss. Here the teacher weight does not need gradients, while the student weight does.
awesome work! waiting for the final nit review
@qingquansong sure, I'm in.
Forgot to add jsd in readme and liger_kernel.transformer
Head branch was pushed to by a user without write access
Summary
Resolve #278.
Details
Forward:
where $X=\log Q$, $Y=\log P$, and $M=\beta P + (1-\beta)Q$.
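With these definitions, the generalized JSD can be written out term by term. The following is a sketch that assumes the common convention of weighting the teacher term by $\beta$, that is $\beta\,KL(P\|M) + (1-\beta)\,KL(Q\|M)$; the exact form used in the kernel may differ:

$$
JSD_\beta(P \,\|\, Q) = \beta\,KL(P\|M) + (1-\beta)\,KL(Q\|M) = \sum_i \left[ \beta P_i Y_i + (1-\beta) Q_i X_i - M_i \log M_i \right]
$$

The second equality uses $\beta P_i + (1-\beta) Q_i = M_i$ to collect the two $\log M_i$ terms.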
Gradients:
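Assuming the forward value is $\sum_i [\beta P_i Y_i + (1-\beta) Q_i X_i - M_i \log M_i]$ with $Q_i = e^{X_i}$ (an assumption about the convention, not a quote from the kernel), differentiating with respect to the student input gives $\partial\, JSD_\beta / \partial X_i = (1-\beta)\, Q_i\, (X_i - \log M_i)$; the teacher input needs no gradient. The sketch below checks that expression against central finite differences in plain Python. All function names are hypothetical.

```python
import math

def jsd_forward(x, y, beta):
    """Assumed forward: sum(beta*P*Y + (1-beta)*Q*X - M*log M)."""
    total = 0.0
    for xi, yi in zip(x, y):
        q, p = math.exp(xi), math.exp(yi)
        m = beta * p + (1.0 - beta) * q
        total += beta * p * yi + (1.0 - beta) * q * xi - m * math.log(m)
    return total

def jsd_grad_x(x, y, beta):
    """Analytic gradient w.r.t. the student log-probabilities X."""
    grads = []
    for xi, yi in zip(x, y):
        q, p = math.exp(xi), math.exp(yi)
        m = beta * p + (1.0 - beta) * q
        grads.append((1.0 - beta) * q * (xi - math.log(m)))
    return grads

def numeric_grad_x(x, y, beta, eps=1e-6):
    """Central finite differences, perturbing one entry of X at a time."""
    grads = []
    for i in range(len(x)):
        hi = list(x)
        hi[i] += eps
        lo = list(x)
        lo[i] -= eps
        grads.append((jsd_forward(hi, y, beta) - jsd_forward(lo, y, beta)) / (2 * eps))
    return grads

x = [math.log(0.2), math.log(0.3), math.log(0.5)]  # student log-probs
y = [math.log(0.6), math.log(0.3), math.log(0.1)]  # teacher log-probs
analytic = jsd_grad_x(x, y, 0.1)
numeric = numeric_grad_x(x, y, 0.1)
max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
print(max_err)
```

Note that the gradient vanishes wherever the student and teacher probabilities agree, since then $M_i = Q_i$ and $X_i - \log M_i = 0$.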
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence