Support for target with class probs in CrossEntropyLoss #61044
jbschlosser wants to merge 12 commits into pytorch:master from
Conversation
💊 CI failures summary and remediations — as of commit 9ca7ae3 (more details on the Dr. CI page): ✅ None of the CI failures appear to be your fault 💚
1 job timed out.
🚧 1 fixed upstream failure: these were probably caused by upstream breakages that were already fixed. Please rebase on the
jbschlosser force-pushed the branch from e95f9ae to abd8c30
@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
zou3519
left a comment
This looks pretty good; some suggestions and comments below.
aten/src/ATen/native/LossNLL.cpp (Outdated)
This... seems like a good argument for why a user would expect `reduction='mean'` to return `-(input * target * weight_).mean()`. I'm having a hard time coming up with a use case where someone wants probability targets and also wants a weighted mean over the probabilities and weights.
At any rate, we should probably be consistent with our hard-target `cross_entropy` function...
Yeah, I agree it doesn't make sense to do a weighted mean over probabilities and weights. I did it this way here to maintain consistency with the hard-target cross-entropy loss: with one-hot targets, the results are equivalent between soft and hard if done like this :/
Also, to be fully precise: I think a mean computation that fits user intuition would be `-(input * target * weight_).sum(1).mean()`. As in the non-weighted calculation, the `sum(1)` over classes should be taken first, before the mean, to be correct.
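To make the distinction above concrete, here is a small standalone sketch (the shapes, seed, and per-class weights are made up for illustration; `log_probs`, `target`, and `weight` stand in for the `input`, `target`, and `weight_` tensors from the comment). Taking an elementwise `.mean()` divides by `N * C`, while summing over the class dimension first and then averaging divides only by `N`, so the two "means" differ by exactly a factor of `C`:

```python
import torch

torch.manual_seed(0)
N, C = 4, 3
log_probs = torch.log_softmax(torch.randn(N, C), dim=1)  # predicted log-probabilities
target = torch.softmax(torch.randn(N, C), dim=1)         # soft targets (class probabilities)
weight = torch.tensor([1.0, 2.0, 0.5])                   # hypothetical per-class weights

# Elementwise weighted cross-entropy terms, shape (N, C).
per_elem = -log_probs * target * weight

# Plain elementwise mean: divides the total by N * C.
elementwise_mean = per_elem.mean()

# Per-sample loss first (sum over classes), then mean over the batch:
# divides the total by N only.
per_sample_mean = per_elem.sum(1).mean()

# The two differ by exactly a factor of C.
print(torch.allclose(per_sample_mean, C * elementwise_mean))  # True
```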
jbschlosser force-pushed the branch from 0ba16ea to 7e7cace
zou3519
left a comment
Some minor comments but otherwise this LGTM!
aten/src/ATen/native/LossNLL.cpp (Outdated)
NRVO yes, but I'd expect the compiler to do RVO. Not sure how to test for this though; feel free to leave the code as-is.
torch/nn/modules/loss.py (Outdated)
The mean case is only true if the input and target are of size `(N, C)`. Otherwise, we divide by a factor that isn't the batch size: for a tensor of shape `(N, C, d1, d2, ..., dk)` we end up dividing by a factor of `tensor.numel() / C`, right?
Maybe this is OK because we can view data of shape `(N, C, d1, d2, ..., dk)` as being a "batch" of `(N, d1, d2, ..., dk)` distributions.
Yeah, `N` is doing a lot of work implicitly here. I do think that data of shape `(N, C, d1, d2, ..., dk)` is conceptually a batch of `(N, d1, d2, ..., dk)` distributions (and, as mentioned before, I think `d1, ..., dk` should have been added to the left of `C` for this to be clearer, but that ship has sailed).
While this was carried over to some extent from the old docs, each item in the formula is now more explicitly defined, so I think it needs to be more precise. Specifically, "`N` is the batch size" should change. Borrowing some terminology from `KLDivLoss`, we could do something like:
`N` spans the minibatch dimension as well as dimensions `d1, ..., dk` in the case of K-dimensional loss
wdyt?
N spans the minibatch dimension as well as dimensions d1, ..., dk in the case of K-dimensional loss
That sounds good
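A quick sketch of the K-dimensional point above (shapes and seed are arbitrary, chosen only for illustration). Summing the cross-entropy terms over the class dimension leaves one loss value per distribution, and there are `tensor.numel() / C` of those, so an unweighted mean divides by `N * d1 * d2`, not by `N`:

```python
import torch

torch.manual_seed(0)
N, C, d1, d2 = 2, 3, 4, 5
logits = torch.randn(N, C, d1, d2)
log_probs = torch.log_softmax(logits, dim=1)              # class dim is dim 1
target = torch.softmax(torch.randn(N, C, d1, d2), dim=1)  # soft targets per spatial location

# One cross-entropy term per "distribution": summing over the class dim
# leaves shape (N, d1, d2) -- a batch of N * d1 * d2 distributions.
per_dist = -(log_probs * target).sum(1)
print(per_dist.numel() == logits.numel() // C)  # True

# An unweighted mean therefore divides by N * d1 * d2, not by N.
mean_loss = per_dist.mean()
print(torch.allclose(mean_loss, per_dist.sum() / (N * d1 * d2)))  # True
```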
jbschlosser force-pushed the branch from 5137b89 to 8dab918
@jbschlosser merged this pull request in a42345a.
the commit message in a42345a (from the first post of this PR) says
@VitamintK That's right; it was only added to

Fixes #11959
Alternative approach to creating a new `CrossEntropyLossWithSoftLabels` class. This PR simply adds support for "soft targets", a.k.a. class probabilities, in the existing `CrossEntropyLoss` and `NLLLoss` classes. The implementation is dumb and simple right now, but future work can add higher-performance kernels for this case.
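A minimal sketch of the added behavior, assuming a PyTorch build that includes this change (the shapes, seed, and target values are made up for illustration). `CrossEntropyLoss` dispatches on the target: an integer tensor of shape `(N,)` is treated as class indices, while a floating-point tensor of shape `(N, C)` is treated as class probabilities. With one-hot probabilities and no weights, the two paths agree, which is the consistency property discussed in the review above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C = 4, 3
logits = torch.randn(N, C, requires_grad=True)

loss_fn = nn.CrossEntropyLoss()

# Hard targets: class indices of shape (N,).
hard_target = torch.tensor([0, 2, 1, 0])
hard_loss = loss_fn(logits, hard_target)

# Soft targets (added by this PR): class probabilities of shape (N, C).
# Here we build one-hot rows from the hard targets.
soft_target = torch.zeros(N, C).scatter_(1, hard_target.unsqueeze(1), 1.0)
soft_loss = loss_fn(logits, soft_target)

# With one-hot probabilities, the soft-target loss matches the
# hard-target loss.
print(torch.allclose(hard_loss, soft_loss))  # True
```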