Loss functions for complex tensors #46642
Status: Open
Labels: complex_autograd, module: complex, module: nn, triaged
🚀 Feature
Loss functions in the torch.nn module should support complex tensors wherever the operations make sense for complex numbers.
Motivation
Complex neural networks are an active area of research, and several GitHub issues (for example, #46546 (comment)) suggest that loss functions should support complex numbers.
Pitch
NOTE: For now, we have decided to add complex support only for real-valued loss functions, so please check that your chosen loss function is real-valued before you start working on a PR to add complex support.
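To make "real-valued" concrete: complex numbers have no ordering, so a loss must return a real number to be minimized. L1- and MSE-style losses can accept complex inputs and still return a real scalar by working on the modulus of the difference. The sketch below uses plain Python complex numbers rather than PyTorch tensors; the helper names (l1_loss_complex, mse_loss_complex, mse_grad_complex) are hypothetical, it assumes mean reduction, and for the gradient it assumes the convention of packing dL/dRe(x) + i·dL/dIm(x) into one complex number.

```python
from typing import List, Sequence


def _check_lengths(pred: Sequence[complex], target: Sequence[complex]) -> None:
    # Both sequences must be non-empty and aligned element by element.
    if not pred or len(pred) != len(target):
        raise ValueError("pred and target must be non-empty and the same length")


def l1_loss_complex(pred: Sequence[complex], target: Sequence[complex]) -> float:
    """Mean of |pred - target|; abs() of a complex number is a real float."""
    _check_lengths(pred, target)
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def mse_loss_complex(pred: Sequence[complex], target: Sequence[complex]) -> float:
    """Mean of |pred - target| ** 2 (squared modulus, not (pred - target) ** 2,
    which would itself be complex and could not be minimized)."""
    _check_lengths(pred, target)
    return sum(abs(p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mse_grad_complex(pred: Sequence[complex], target: Sequence[complex]) -> List[complex]:
    """Gradient of mse_loss_complex w.r.t. each pred element, packed as
    dL/dRe(pred) + 1j * dL/dIm(pred) (an assumed convention for this sketch)."""
    _check_lengths(pred, target)
    n = len(pred)
    return [2 * (p - t) / n for p, t in zip(pred, target)]


pred = [1 + 2j, 3 - 1j]
target = [1 + 0j, 0 + 3j]
print(l1_loss_complex(pred, target))   # 3.5
print(mse_loss_complex(pred, target))  # 14.5
print(mse_grad_complex(pred, target))  # [2j, (3-4j)]
```

Both losses reduce to their usual real definitions when the imaginary parts are zero, which is one reasonable test of whether a complex extension "makes sense".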
These loss functions should be updated to support complex numbers in both the forward and backward passes. If a loss function doesn't make sense for complex numbers, it should raise an error that clearly says so. The list below contains the loss functions available when this issue was written; we still need to decide which ones to support and which should raise errors.
nn.L1Loss (PR: Add complex support for torch.nn.L1Loss #49912)
nn.MSELoss
nn.CrossEntropyLoss
nn.CTCLoss
nn.NLLLoss
nn.PoissonNLLLoss
nn.KLDivLoss
nn.BCELoss
nn.BCEWithLogitsLoss
nn.MarginRankingLoss
nn.HingeEmbeddingLoss
nn.MultiLabelMarginLoss
nn.SmoothL1Loss
nn.SoftMarginLoss
nn.MultiLabelSoftMarginLoss
nn.CosineEmbeddingLoss
nn.MultiMarginLoss
nn.TripletMarginLoss
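For losses that should not accept complex input, the requirement to raise an error that clearly says so could look roughly like the check below. This is a plain-Python sketch, not PyTorch's actual dispatch mechanism; REJECTS_COMPLEX and check_complex_support are hypothetical names, and the two losses placed in the set are illustrative only (MarginRankingLoss compares values through max(0, ...), which needs an ordering, and BCELoss takes logs of probabilities in [0, 1]). Which losses actually reject complex input is still open, as noted above the list.

```python
from typing import Sequence

# Hypothetical, illustrative set of losses that would reject complex input.
# Which losses belong here is still undecided; these two are examples only.
REJECTS_COMPLEX = {"MarginRankingLoss", "BCELoss"}


def check_complex_support(loss_name: str, *inputs: Sequence) -> None:
    """Raise a TypeError naming the loss if it cannot handle complex values."""
    has_complex = any(isinstance(v, complex) for seq in inputs for v in seq)
    if has_complex and loss_name in REJECTS_COMPLEX:
        raise TypeError(
            f"{loss_name} does not support complex inputs: the loss relies on "
            "ordering or probabilities, which complex numbers do not provide"
        )


check_complex_support("L1Loss", [1 + 2j], [0j])           # complex input accepted
check_complex_support("BCELoss", [0.2, 0.7], [0.0, 1.0])  # real input accepted
try:
    check_complex_support("BCELoss", [0.2 + 0.1j], [1.0])
except TypeError as err:
    print(err)
```

Putting the loss name in the message keeps the error actionable when the failing call is buried inside a larger model.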
If a loss function uses an operation that is feasible for complex numbers but not yet supported, we should prioritize adding complex support for that operation.
cc @ezyang @anjali411 @dylanbespalko @mruberry @albanD