Gaussian NLL loss #48520
Closed
Labels: feature, module: loss, module: nn, triaged
🚀 Feature
A Gaussian negative log-likelihood loss, similar to issue #1774 (and its solution, pull request #1779).
Motivation
The homoscedastic Gaussian loss is described in Equation 1 of this paper, and the heteroscedastic version in Equation 2 here (ignoring the final anchoring loss term). Both are key to the uncertainty quantification techniques those papers describe.
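For reference, the standard Gaussian negative log-likelihood of targets $y_i$ under predicted means $\mu_i$ and variances $\sigma_i^2$, averaged over $N$ samples, is written out below. This assumes the cited equations agree with it up to additive constants and the choice of sum versus mean; the homoscedastic case is the special case $\sigma_i^2 = \sigma^2$ for all $i$.

```latex
% Heteroscedastic: one variance per sample
\mathcal{L}_{\text{hetero}} = \frac{1}{N}\sum_{i=1}^{N} \frac{1}{2}\left( \log \sigma_i^2 + \frac{(y_i - \mu_i)^2}{\sigma_i^2} \right) + \frac{1}{2}\log 2\pi

% Homoscedastic: a single shared variance \sigma^2
\mathcal{L}_{\text{homo}} = \frac{1}{2}\log \sigma^2 + \frac{1}{2N\sigma^2}\sum_{i=1}^{N} (y_i - \mu_i)^2 + \frac{1}{2}\log 2\pi
```

The constant $\frac{1}{2}\log 2\pi$ does not affect gradients, so an implementation could reasonably make it optional.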
Pitch
I'm happy to implement this, using pull request #1779 as a template. The implementation will support both homoscedastic and heteroscedastic losses.
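A minimal sketch of how one function could cover both cases, written in plain Python rather than against PyTorch. The name `gaussian_nll` and its parameters (`full`, `eps`, `reduction`) are hypothetical illustrations, not an existing API: a scalar `var` gives the homoscedastic loss, and a per-sample sequence gives the heteroscedastic loss.

```python
import math
from typing import List, Sequence, Union


def gaussian_nll(
    mean: Sequence[float],
    target: Sequence[float],
    var: Union[float, Sequence[float]],
    full: bool = False,
    eps: float = 1e-6,
    reduction: str = "mean",
) -> Union[float, List[float]]:
    """Hypothetical Gaussian NLL loss over paired predictions and targets.

    A scalar `var` is the homoscedastic case (one variance shared by every
    sample); a sequence is the heteroscedastic case (one variance per sample).
    """
    if len(mean) != len(target):
        raise ValueError("mean and target must have the same length")

    # Broadcast a scalar variance to every sample (homoscedastic case).
    if isinstance(var, (int, float)):
        variances = [float(var)] * len(mean)
    else:
        variances = list(var)
        if len(variances) != len(mean):
            raise ValueError("var must be a scalar or match the length of mean")
    if any(v < 0 for v in variances):
        raise ValueError("variances must be non-negative")

    losses = []
    for mu, y, v in zip(mean, target, variances):
        v = max(v, eps)  # clamp tiny variances to avoid log(0) and division by zero
        loss = 0.5 * (math.log(v) + (y - mu) ** 2 / v)
        if full:
            loss += 0.5 * math.log(2 * math.pi)  # optional constant term
        losses.append(loss)

    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    if reduction == "none":
        return losses
    raise ValueError(f"unknown reduction: {reduction!r}")


# Homoscedastic: one shared variance.
print(gaussian_nll([0.0, 1.0], [1.0, 1.0], var=1.0))  # 0.25
# Heteroscedastic: one variance per output-target pair.
print(gaussian_nll([0.0, 1.0], [1.0, 1.0], var=[4.0, 1.0]))
```

Taking the variance as a separate argument keeps a single code path: the homoscedastic loss is simply the heteroscedastic loss with every variance equal.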
Alternatives
An alternative would be to instantiate a Gaussian (https://pytorch.org/docs/stable/distributions.html#normal) and evaluate its log-probability. However, this seems wasteful: a new Gaussian would have to be instantiated on every function call in the best case (homoscedastic), and for every output-target pair in the worst case (heteroscedastic).
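To illustrate the trade-off, the sketch below uses Python's standard-library `statistics.NormalDist` as a stand-in for a distribution object. This is an analogy under that assumption, not `torch.distributions`. It compares building one distribution per output-target pair (the heteroscedastic worst case) with computing the closed-form NLL directly. The helper names `nll_closed_form` and `nll_via_distribution` are hypothetical.

```python
import math
from statistics import NormalDist


def nll_closed_form(mu: float, y: float, var: float) -> float:
    # -log N(y; mu, var) written out directly, including the 0.5*log(2*pi) constant.
    return 0.5 * (math.log(2 * math.pi * var) + (y - mu) ** 2 / var)


def nll_via_distribution(mu: float, y: float, var: float) -> float:
    # The alternative: construct a distribution object, then take -log of its density.
    # NormalDist is parameterised by the standard deviation, not the variance.
    return -math.log(NormalDist(mu, math.sqrt(var)).pdf(y))


means = [0.0, 1.5, -2.0]
targets = [0.3, 1.0, -1.0]
variances = [0.5, 2.0, 1.0]  # heteroscedastic: one variance per pair

# A fresh NormalDist is constructed for every output-target pair here.
via_dist = [nll_via_distribution(m, y, v) for m, y, v in zip(means, targets, variances)]
closed = [nll_closed_form(m, y, v) for m, y, v in zip(means, targets, variances)]

for a, b in zip(via_dist, closed):
    print(f"{a:.6f}  {b:.6f}")
```

Both routes give the same values. The closed form avoids per-element object construction and does not round-trip through `exp` and `log`, which also sidesteps underflow when the density is very small.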
Additional context
Definitions: homo/heteroscedasticity (wiki)
cc @albanD @mruberry