'mean' reduction result in CrossEntropyLoss mismatches with manually computing mean #40560

@vincentwen1995

🐛 Bug

Hi, during training I noticed that when class weights are specified for CrossEntropyLoss, the 'mean' reduction produces a different loss value than using the 'none' reduction and computing the mean manually.

To Reproduce

Steps to reproduce the behavior:

import torch
import torch.nn as nn

logits = torch.randn((16, 5))
targets = torch.empty(16, dtype=torch.long).random_(5)

# Per-class weights
weights = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float)

# Built-in 'mean' reduction
cross_ent_mean = nn.CrossEntropyLoss(weight=weights, reduction='mean')
loss_a = cross_ent_mean(logits, targets)
print(loss_a)

# 'none' reduction, followed by a manual mean over the batch
cross_ent = nn.CrossEntropyLoss(weight=weights, reduction='none')
loss_b = cross_ent(logits, targets).mean()
print(loss_b)

# This assertion fails: the two values differ
assert torch.equal(loss_a, loss_b)

Expected behavior

The losses computed by the two methods should be equal, i.e. torch.equal(loss_a, loss_b) should return True.

Environment

  • PyTorch Version: 1.6.0.dev20200505+cu101
  • OS: Windows 10
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.7.6
  • CUDA/cuDNN version: CUDA 10.1 (irrelevant)
  • GPU models and configuration: GTX970M (irrelevant)
  • Any other relevant information:

Additional context

The problem seems to be related to the weights; I also tried assigning the same weight to all the classes. Theoretically, however, the weighted loss should be computed before the mean() operation is applied.
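A possible explanation, offered as a sketch and not as a confirmed diagnosis: the NLLLoss documentation describes the weighted 'mean' reduction as a weighted average, meaning the sum of the per-sample losses is divided by the sum of the weights of the target classes, not by the batch size. Assuming that is what happens here, the snippet below checks the 'mean' result against both normalizations. The names `weighted_mean` and `plain_mean` are illustrative and do not come from the original report.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the run is repeatable

logits = torch.randn(16, 5)
targets = torch.randint(0, 5, (16,))
weights = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

loss_mean = nn.CrossEntropyLoss(weight=weights, reduction='mean')(logits, targets)
per_sample = nn.CrossEntropyLoss(weight=weights, reduction='none')(logits, targets)

# Weighted average: normalize by the summed weights of the target classes
weighted_mean = per_sample.sum() / weights[targets].sum()

# Plain average: normalize by the batch size (what .mean() does)
plain_mean = per_sample.mean()

print(loss_mean, weighted_mean, plain_mean)
```

If the assumption holds, `loss_mean` agrees with `weighted_mean` up to floating-point tolerance, while `plain_mean` differs whenever the target weights do not sum to the batch size. A comparison such as torch.allclose is a safer check than torch.equal here, since the two paths may round differently.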

cc @ezyang @gchanan @zou3519 @jlin27 @albanD @mruberry

Labels

high priority, module: docs, module: nn, triaged
