Description
The CrossEntropyLoss class and function take inputs (unnormalized logits), targets, and class weights to calculate the loss.
The purpose of the class weights is to help with imbalanced datasets.
However, this setup does not allow masking, which is a core need in time-series (RNN, NLP) training where sequences have variable lengths and must be padded.
I propose two alternative approaches that would solve this problem, while still allowing weighting class imbalances.
1. Accept the weights as a function argument on every call, as TensorFlow does. The weights would have the same shape as the target tensor, giving a weight for each sample. The user could still up-weight samples belonging to imbalanced classes, while masking becomes putting ones at the wanted samples and zeros everywhere else.
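A minimal sketch of proposal (1). The helper name `weighted_cross_entropy` and its signature are hypothetical, not an existing PyTorch API; it computes the per-sample negative log-likelihood and weights it before averaging, so a zero weight masks a sample out entirely:

```python
import torch
import torch.nn.functional as F

def weighted_cross_entropy(logits, targets, sample_weights):
    """Hypothetical helper: cross entropy with a weight per sample.

    logits:         (batch, num_classes) unnormalized scores
    targets:        (batch,) class indices
    sample_weights: (batch,) per-sample weights; 0 masks a sample out
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Negative log-likelihood of the target class for each sample.
    nll = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    # Weighted mean; the clamp guards against an all-zero mask.
    return (nll * sample_weights).sum() / sample_weights.sum().clamp(min=1e-8)
```

With this shape, class weighting and masking are the same mechanism: to up-weight a rare class, set larger weights at its indices; to mask padded timesteps, set their weights to zero.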
2. A perhaps more elegant solution: make CrossEntropyLoss behave exactly like TensorFlow's cross entropy loss function, which appears to compute the same quantity as PyTorch's, except without averaging the per-sample losses. The user could then average however they see fit, and reproduce the behavior of proposal (1) on top of it.
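A sketch of proposal (2): obtain the unreduced per-sample loss and reduce it manually. Current PyTorch exposes this via `reduction='none'` on `F.cross_entropy`; the padding layout and masking logic below are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.4],
                       [0.1, 0.2, 2.2],
                       [0.0, 0.0, 0.0]])   # last row is a padded timestep
targets = torch.tensor([0, 1, 2, 0])
mask = torch.tensor([1.0, 1.0, 1.0, 0.0])  # zeros mask out padding

# Per-sample losses, shape (4,) -- no averaging is done here.
per_sample = F.cross_entropy(logits, targets, reduction='none')
# Average only over the unmasked samples.
loss = (per_sample * mask).sum() / mask.sum()
```

Because the reduction is left to the caller, the same primitive supports masking, per-sample weighting, and length-normalized sequence losses.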