
[Fix]Fix bug of classwise weight in bce loss #5776

Closed
Ezra-Yu wants to merge 8 commits intoopen-mmlab:masterfrom
Ezra-Yu:classwise-weight-bce

Conversation


@Ezra-Yu Ezra-Yu commented Aug 3, 2021

Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation

fix bug of classwise_weight in bce loss, and add unit test

Modification

This line has a bug: the loss is only scaled at locations where the label is 1. The class_weight has no effect at locations where the label is 0.

My experiment: [screenshot of experiment output]
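A minimal sketch of the reported behavior (toy values, not taken from the PR): passing the per-class weights as pos_weight rescales only the entries where the label is 1, while label-0 entries are left untouched.

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([[2.0, -1.0]])          # (N, C) logits
label = torch.tensor([[1.0, 0.0]])          # one positive, one negative entry
class_weight = torch.tensor([10.0, 10.0])   # (C,) per-class weights

base = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
scaled = F.binary_cross_entropy_with_logits(
    pred, label, pos_weight=class_weight, reduction='none')

# pos_weight only multiplies the positive term: the label==1 entry is
# scaled by 10, the label==0 entry is unchanged (ratio 1).
print(scaled / base)  # first ratio is 10, second is 1
```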


CLAassistant commented Aug 3, 2021

CLA assistant check
All committers have signed the CLA.

@jshilong jshilong changed the title [bug] fix bug of classwise weight in bce loss [Bug] fix bug of classwise weight in bce loss Aug 3, 2021
@jshilong jshilong changed the title [Bug] fix bug of classwise weight in bce loss [Bug] Fix bug of classwise weight in bce loss Aug 3, 2021
@ZwwWayne ZwwWayne requested review from AronLin and jshilong August 4, 2021 02:17
Member

mzr1996 commented Aug 4, 2021

I think we need @xvjiarui here; refer to ElectronicElephant@f526be0

Collaborator

@xvjiarui xvjiarui left a comment


LGTM
Just to make sure there is no BC breaking.

Member

mzr1996 commented Aug 5, 2021

> LGTM
> Just to make sure there is no BC breaking.

Does YOLOv3 need to use pos_weight? I notice it's modified in YOLOv3-mmdetection

@Ezra-Yu Ezra-Yu changed the title [Bug] Fix bug of classwise weight in bce loss [Bug] Fix bug of classwise weight in bce loss [WIP] Aug 5, 2021
@Ezra-Yu Ezra-Yu changed the title [Bug] Fix bug of classwise weight in bce loss [WIP] [WIP]Fix bug of classwise weight in bce loss Aug 5, 2021
  weight = weight.float()
  loss = F.binary_cross_entropy_with_logits(
-     pred, label.float(), pos_weight=class_weight, reduction='none')
+     pred, label.float(), weight=class_weight, reduction='none')
Contributor

@AronLin AronLin Aug 5, 2021


Please take a look at the docstring of F.binary_cross_entropy_with_logits.
weight should be a tensor that matches the input tensor shape; it is not the class-aware weight.
pos_weight should be a vector with a length equal to the number of classes.
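The two arguments can be checked against the loss formula directly. A small sketch (random toy tensors, not from the PR): `weight` multiplies every element of the loss, and a (C,) tensor is accepted only because it broadcasts against the (N, C) input, while `pos_weight` scales only the positive term `y * log σ(x)`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pred = torch.randn(3, 2)                       # (N, C) logits
label = torch.randint(0, 2, (3, 2)).float()    # binary targets
w = torch.tensor([0.1, 0.9])                   # (C,)

# `weight` is an elementwise factor on the whole loss; the (C,) tensor
# broadcasts over the batch dimension.
elementwise = F.binary_cross_entropy_with_logits(
    pred, label, weight=w, reduction='none')
manual = w * F.binary_cross_entropy_with_logits(pred, label, reduction='none')
assert torch.allclose(elementwise, manual)

# `pos_weight` enters the formula only on the positive term:
# loss = -(pos_weight * y * log sigmoid(x) + (1 - y) * log(1 - sigmoid(x)))
pos = F.binary_cross_entropy_with_logits(
    pred, label, pos_weight=w, reduction='none')
sig = torch.sigmoid(pred)
manual_pos = -(w * label * sig.log() + (1 - label) * (1 - sig).log())
assert torch.allclose(pos, manual_pos, atol=1e-5)
```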

Author

@Ezra-Yu Ezra-Yu Aug 5, 2021


It will broadcast to (..., C).
Of course, the premise is that the pred tensor is converted to a shape of (..., C).

Author


pos_weight is a vector with a length equal to the number of classes, but it is different from the classwise weight. Using pos_weight makes the loss of positive and negative items unbalanced.

assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))

# test bce_loss
cls_score = torch.Tensor([[-200, 100], [500, -1000], [300, -300]])
Contributor


bce_loss now only supports the input tensor with shape (n, 1).

Author


YOLOv3's cls head uses BCE loss; its pred input tensor has a shape of (batch_size, num_boxes, num_classes).
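A quick sketch of this point (hypothetical YOLOv3-like shapes, not taken from the PR): a (C,) class weight passed through the `weight` argument broadcasts over the leading batch and box dimensions of a 3-D prediction tensor, so the same per-class scaling applies to every box.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch_size, num_boxes, num_classes)
pred = torch.randn(2, 5, 3)
label = torch.randint(0, 2, (2, 5, 3)).float()
class_weight = torch.tensor([0.2, 0.3, 0.5])   # (C,)

# The (C,) weight broadcasts across the leading dimensions, so no
# reshaping is needed as long as the last dimension is C.
loss = F.binary_cross_entropy_with_logits(
    pred, label, weight=class_weight, reduction='none')
assert loss.shape == (2, 5, 3)
```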

@Ezra-Yu Ezra-Yu changed the title [WIP]Fix bug of classwise weight in bce loss [Fix]Fix bug of classwise weight in bce loss Aug 6, 2021
@codecov

codecov bot commented Aug 10, 2021

Codecov Report

Merging #5776 (4d326d0) into master (87eda06) will decrease coverage by 0.37%.
The diff coverage is n/a.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #5776      +/-   ##
==========================================
- Coverage   66.47%   66.10%   -0.38%     
==========================================
  Files         281      281              
  Lines       22019    21870     -149     
  Branches     3659     3659              
==========================================
- Hits        14638    14458     -180     
- Misses       6631     6651      +20     
- Partials      750      761      +11     
Flag Coverage Δ
unittests 66.10% <ø> (-0.35%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
mmdet/models/losses/cross_entropy_loss.py 100.00% <ø> (ø)
mmdet/utils/contextmanagers.py 0.00% <0.00%> (-18.58%) ⬇️
mmdet/models/dense_heads/dense_test_mixins.py 38.46% <0.00%> (-5.13%) ⬇️
mmdet/models/utils/normed_predictor.py 83.33% <0.00%> (-4.77%) ⬇️
mmdet/models/roi_heads/test_mixins.py 70.21% <0.00%> (-4.26%) ⬇️
mmdet/models/detectors/base.py 53.65% <0.00%> (-2.50%) ⬇️
mmdet/models/roi_heads/base_roi_head.py 85.29% <0.00%> (-2.21%) ⬇️
mmdet/datasets/pipelines/formating.py 64.22% <0.00%> (-2.16%) ⬇️
mmdet/core/bbox/coder/yolo_bbox_coder.py 58.97% <0.00%> (-2.01%) ⬇️
mmdet/core/anchor/point_generator.py 51.57% <0.00%> (-1.96%) ⬇️
... and 49 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 87eda06...4d326d0. Read the comment docs.

Contributor

AronLin commented Aug 12, 2021

pred = torch.Tensor([[2, 4],
                     [-2, 1], 
                     [-1, 3]])  # (N, C)
label = torch.Tensor([[0, 1], 
                     [1, 0], 
                     [1, 1]])  # (N, C)
class_weight = torch.tensor([0.1, 0.9])  # (C, )

# loss without class_weight
loss_cls_cfg = dict(
    type='CrossEntropyLoss',
    use_sigmoid=True,
    loss_weight=1.0,
    reduction='none'
)
loss_cls = build_loss(loss_cls_cfg)
origin_loss = loss_cls(pred, label)
weight_between_classes = loss_cls(pred, label, weight=class_weight)

# loss with class_weight
loss_cls_cfg = dict(
    type='CrossEntropyLoss',
    use_sigmoid=True,
    loss_weight=1.0,
    class_weight = torch.tensor([0.1, 0.9]),
    reduction='none'
)
loss_cls = build_loss(loss_cls_cfg)
weight_between_samples = loss_cls(pred, label)

# output
origin_loss: tensor([[2.1269, 0.0181],
                     [2.1269, 1.3133],
                     [1.3133, 0.0486]])
weight_between_classes: tensor([[0.2127, 0.0163],
                                [0.2127, 1.1819],
                                [0.1313, 0.0437]])
weight_between_samples: tensor([[2.1269, 0.0163],
                                [0.2127, 1.3133],
                                [0.1313, 0.0437]])

After discussion, we think it is not a bug and we do not need to change the interface of class_weight, for the following reasons:
(1) In the original design, class_weight is used to adjust the weight between the positive and negative samples of each class, so class_weight is applied to positive samples.
(2) If we want to adjust the importance between classes, we can use the weight argument in the forward function.
(3) If we approve this PR, it will be hard to implement the function in (1), but the function in (2) is easy to implement without this PR.

Please check whether the above items are reasonable; I will close this PR if you agree with me.
Kindly ping @ZwwWayne @xvjiarui @mzr1996 @Ezra-Yu .

Author

Ezra-Yu commented Aug 12, 2021

OK, I understand.


6 participants