[Fix]Fix bug of classwise weight in bce loss #5776
[Fix]Fix bug of classwise weight in bce loss #5776Ezra-Yu wants to merge 8 commits intoopen-mmlab:masterfrom
Conversation
|
I think here we need @xvjiarui, refers to ElectronicElephant@f526be0 |
xvjiarui
left a comment
There was a problem hiding this comment.
LGTM
Just to make sure there is no BC breaking.
Does Yolov3 need to use |
| weight = weight.float() | ||
| loss = F.binary_cross_entropy_with_logits( | ||
| pred, label.float(), pos_weight=class_weight, reduction='none') | ||
| pred, label.float(), weight=class_weight, reduction='none') |
There was a problem hiding this comment.
Please take a look at the docstring of F.binary_cross_entropy_with_logits.
weight should be a tensor that matches the input tensor shape. It is Not the class-aware weight.
pos_weight should be a vector with a length equal to the number of classes.
There was a problem hiding this comment.
it will broadcasts to (....., C)
Of course, the premise is to convert pred tensor into shape of (..., C).
There was a problem hiding this comment.
pos_weight is a vector with a length equal to the number of classes, but it is different with the classwise weight. Using pos_weight will pos items and neg items loss unblanced.
tests/test_metrics/test_losses.py
Outdated
| assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.)) | ||
|
|
||
| # test bce_loss | ||
| cls_score = torch.Tensor([[-200, 100], [500, -1000], [300, -300]]) |
There was a problem hiding this comment.
bce_loss now only supports the input tensor with shape (n, 1).
There was a problem hiding this comment.
yolov3 cls head uses bce loss, it's pred input tensor has shpae (batchsize, number_booxes, number_class)
Codecov Report
@@ Coverage Diff @@
## master #5776 +/- ##
==========================================
- Coverage 66.47% 66.10% -0.38%
==========================================
Files 281 281
Lines 22019 21870 -149
Branches 3659 3659
==========================================
- Hits 14638 14458 -180
- Misses 6631 6651 +20
- Partials 750 761 +11
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
After discussion, we think it is not a bug and we do not need to change the interface of Please check whether the above items are in reason, I will close this PR if you agree with me. |
|
OK,understand that. |
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.
Motivation
fix bug of classwise_weight in bce loss, and add unit test
Modification
this line has a bug. the loss only scale in the location where the label is 1. The class_weight has no effect on the location where the label is 0.
my exp:
