Add JSD kernel#264
Conversation
|
@yundai424 btw, should I change jsd to js_div? |
jsd is fne~ |
|
Added jsd benchmark script |
|
GPU CI failing on irrelevant tests |
|
pushed an implementation without inplace operations on input. see #262 (comment). we can decide which one to use in future. |
There was a problem hiding this comment.
Thanks for the efforts! Looks good to me in general and I'm assuming the log scale input is for considering log_softmax as input for common cases which is more numerically stable? (One minor useful feature we can add but could be a future pr is the ignore index part (suppose we can provide a extra input label tensor for it but might be too specific since general JSD does not have hard label as input). )
yes, and torch.KLDivLoss also takes input in the log-space. I think it would be better to have similar arguments for users. |
yundai424
left a comment
There was a problem hiding this comment.
LGTM, thanks for the contribution! @qingquansong could you help to create two follow up issues for 1) adding ignore index support for divergence losses, and 2) add general JSD (w/ beta) support? Thanks a lot!!
qingquansong
left a comment
There was a problem hiding this comment.
LGTM! Let me know when the CI test got fixed and I can give a quick shipping stamp. Thanks!
Summary
Resolve #252
Details
JSD
We expect input$X$ and target $Y$ are distributions in log-space, i.e., $X = log Q$ and $Y = log P$ .$P$ and $Q$ is defined as:
Jenson-Shannon Divergence between two distributions
where$M = \frac{1}{2}(P + Q)$ is the average distribution and $KL$ is the Kullback-Leibler divergence.$X = log Q$ and $Y = log P$ , we can simplify JSD expression to:
Given that
We define the point-wise JSD as:
With point-wise JSD, it's easier to implement JSDs with respect to different reduction methods in future.
The only downside is that it creates a torch.float32 tensor with the same shape as input's.
Current implementation is hardcoded to batchmean which is the original JSD definition.
Gradients
Given:
where$Q = e^X$ , $P = e^Y$ , and $M = \frac{1}{2}(e^X + e^Y)$ .
Gradients of$KL(P\ \Vert\ M)$ with respect to $X_i$ :
Gradients of$KL(Q\ \Vert\ M)$ with respect to $X_i$ :
Final gradients of JSD:
Combine the results from two KL divergence terms:
Simplify this to:
We store gradients at X_ptr in forward pass to save memory, then retrieve it through ctx in backward function as cross_entropy does. (inplace)
note: inplace operations on inputs might cause an issue with gradient computation.
Testing Done
With inplace (Storing gradients to inputs)
reduce memory usage by 61.54%


increase speed by 53.64%
Without inplace
reduce memory usage by 53%


increase speed by 61%
make testto ensure correctnessmake checkstyleto ensure code stylemake test-convergenceto ensure convergence