Inconsistency with loss function of PyTorch when using ingore_index arg #7367
Description
Hello, I'm a user of mmdet3d. Since that toolbox imports many functions from mmdet, I am reporting the issue here.
A concrete example is the commonly used CrossEntropyLoss module with the ignore_index argument.
Here is the definition from PyTorch
ignore_index (int, optional): Specifies a target value that is ignored
and does not contribute to the input gradient. When :attr:`size_average` is
``True``, the loss is averaged over non-ignored targets.
And the following is from MMDetection
ignore_index (int | None): The label index to be ignored.
If None, it will be set to default value. Default: -100.
This is the code segment from mmdet:

def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None,
                  ignore_index=-100):
    """Calculate the CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int | None): The label index to be ignored.
            If None, it will be set to default value. Default: -100.

    Returns:
        torch.Tensor: The calculated loss
    """
    # The default value of ignore_index is the same as F.cross_entropy
    ignore_index = -100 if ignore_index is None else ignore_index
    # element-wise losses
    loss = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)
    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
    return loss

The key point is that the results are inconsistent. If someone uses mmdet's CrossEntropyLoss with its defaults and passes an ignore_index such as 255, the torch version will ignore the marked elements, while the mmdet version first calculates the element-wise loss and then, by default, applies a mean reduction over all elements.
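To make the discrepancy concrete, here is a minimal plain-Python sketch (no torch dependency; the loss values and mask are made up for illustration). F.cross_entropy with reduction='none' assigns a loss of exactly 0 to positions whose label equals ignore_index, so the two reduction strategies diverge:

```python
# Element-wise losses as F.cross_entropy(..., reduction='none') would return:
# positions whose label equals ignore_index contribute a loss of exactly 0.
losses = [1.0, 2.0, 0.0, 0.0, 3.0]           # made-up values
ignored = [False, False, True, True, False]  # True where label == ignore_index

# PyTorch's reduction='mean': average over NON-ignored elements only.
torch_mean = sum(l for l, m in zip(losses, ignored) if not m) / ignored.count(False)

# mmdet's default path (weight=None, avg_factor=None): plain mean over
# ALL elements, ignored positions included.
mmdet_mean = sum(losses) / len(losses)

print(torch_mean)  # 2.0 -> (1.0 + 2.0 + 3.0) / 3
print(mmdet_mean)  # 1.2 -> (1.0 + 2.0 + 3.0) / 5
```

The more elements are ignored, the further the mmdet-style mean is dragged toward zero relative to the torch-style mean.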
In my scenario of point cloud segmentation, the ground truth is a voxelized tensor in which most labels are ignored. With the torch version, the mean reduction is performed over only the non-ignored voxels, whereas the mmdet version averages over all elements.
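Until the behaviours are aligned, one possible workaround is to pass avg_factor explicitly as the number of non-ignored targets, since weight_reduce_loss divides the summed loss by avg_factor when one is given. A minimal sketch in plain Python, with hypothetical labels and made-up loss values (in real code the count would be computed from a label tensor, e.g. with a comparison against ignore_index):

```python
# Hypothetical labels for a handful of voxels; 255 is the ignore label.
labels = [3, 255, 7, 255, 255, 1]
ignore_index = 255

# Number of targets that actually contribute to the loss.
avg_factor = sum(1 for lab in labels if lab != ignore_index)

# Element-wise losses (ignored positions are 0, as F.cross_entropy returns
# with reduction='none'); the values here are made up.
losses = [1.0, 0.0, 0.5, 0.0, 0.0, 1.5]

# Dividing the summed loss by avg_factor reproduces PyTorch's
# mean-over-non-ignored-targets behaviour.
loss_mean = sum(losses) / avg_factor

print(avg_factor)  # 3
print(loss_mean)   # 1.0
```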
This tiny, sort-of 'bug' troubled me for days.
Please fix it and keep the behaviour consistent with the corresponding torch functions.
Best,
Singal