Add pos_weight argument to nn.BCEWithLogitsLoss (#5660)#6856
Add pos_weight argument to nn.BCEWithLogitsLoss (#5660)#6856ssnl merged 2 commits intopytorch:masterfrom
Conversation
…tropy_with_logits (pytorch#5660) - Add an option to control precision/recall in imbalanced datasets - Add tests (but new_criterion_tests)
torch/nn/functional.py
Outdated
| target: Tensor of the same shape as input | ||
| weight (Tensor, optional): a manual rescaling weight | ||
| if provided it's repeated to match input tensor shape | ||
| pos_weight (Tensor, optional): a weight of positive examples. |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/modules/loss.py
Outdated
| weight (Tensor, optional): a manual rescaling weight given to the loss | ||
| of each batch element. If given, has to be a Tensor of size | ||
| "nbatch". | ||
| pos_weight (Tensor, optional): a weight of positive examples. |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
ezyang
left a comment
There was a problem hiding this comment.
Accept assuming the doc issues are fixed
`pos_weight` was moved to the end because it is the last argument in both `nn.BCEWithLogitsLoss` and `binary_cross_entropy_with_logits`
| is ``False``, returns a loss per input/target element instead and ignores | ||
| :attr:`size_average`. Default: ``True`` | ||
| pos_weight (Tensor, optional): a weight of positive examples. | ||
| Must be a vector with length equal to the number of classes. |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| def __init__(self, weight=None, size_average=True, reduce=True, pos_weight=None): | ||
| super(BCEWithLogitsLoss, self).__init__(size_average, reduce) | ||
| self.register_buffer('weight', weight) | ||
| self.register_buffer('pos_weight', pos_weight) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
@pytorchbot retest this please |
4 similar comments
|
@pytorchbot retest this please |
|
@pytorchbot retest this please |
|
@pytorchbot retest this please |
|
@pytorchbot retest this please |
* upstream/master: (42 commits) [c10d] No default device for ProcessGroupGloo (pytorch#8888) Fix default values for affine= in the docstrings of InstanceNormXd (pytorch#8895) Stop making dynamic allocations of PinnedMemoryAllocator. (pytorch#8896) [C++ API] Rework optimization package (pytorch#8815) Mention MPICH_MAX_THREAD_SAFETY=multiple. (pytorch#8580) Unify isViewable, handle n-dimensional empty tensors. (pytorch#8883) Add pos_weight argument to nn.BCEWithLogitsLoss (pytorch#5660) (pytorch#6856) [build] Enable clang-specific warnings only when using clang (pytorch#8869) Fix cmake cudnn autodetection (pytorch#8891) [c10d] Fix link order for building C++ tests (pytorch#8889) directly add_subdirectory(nanopb) from torch CMakeLists (pytorch#8870) [C++ API] Bag of fixes (pytorch#8843) [build] Raise in cmake when seeing NVCC{9/9.1} + GCC6 combo (pytorch#8863) Create avg_pool1d in ATen (pytorch#8880) throw error when grid_sample is passed unsupported mode (pytorch#8884) Allow autograd to work even when the shape of values cannot be determined (pytorch#8641) Make at::Tensor::to() const (pytorch#8839) [auto] Update onnx to 458c521 - Fix typo (onnx/onnx#1143) onnx/onnx@458c521 [Caffe2] Fix gradient_check on in-place ops (pytorch#8828) Fix as_strided_backward (pytorch#8721) ...
Multiplier
1 + (pos_weight - 1) * targetis the only significant difference.Notes:
F.binary_cross_entropy/nn.BCELoss. (It uses implementation from torch._C.) I can add a straightforward and numerically unstable implementation to this function but I'm not sure if it is really needed.pos_weightas last argument to prevent errors in the code that doesn't use names for keyword arguments. But it looks quite ugly.P.S. I proposed these changes in #5660