torch.autograd.gradcheck support for Tensor-like types (__torch_function__). #42942
Closed
Labels: feature, module: autograd, module: numpy, module: tests, triaged
🚀 Feature
gradcheck.py#L264 explicitly requires its inputs to be of type torch.Tensor and thus doesn't work for Tensor-like objects.
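To illustrate the difference between a strict type check and a duck-typed one, here is a minimal, torch-free sketch. It assumes the current check behaves like an `isinstance(..., torch.Tensor)` test. `FakeTensor`, `MyTensorLike`, `strict_check` and `is_tensor_like` are hypothetical stand-ins, not PyTorch code. (Later PyTorch releases document `torch.overrides.is_tensor_like` as performing this same `__torch_function__` check.)

```python
# Hypothetical sketch: strict type check vs. duck-typed check.
# FakeTensor and MyTensorLike are stand-ins, not real PyTorch classes.

class FakeTensor:
    """Stand-in for torch.Tensor, which also defines __torch_function__."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented


class MyTensorLike:
    """Stand-in for a user-defined Tensor-like type (e.g. a NestedTensor)."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented


def strict_check(inp):
    # Accepts only the "real" tensor type, rejecting Tensor-likes.
    return isinstance(inp, FakeTensor)


def is_tensor_like(inp):
    # Accepts any object whose type defines __torch_function__.
    return hasattr(type(inp), "__torch_function__")


inputs = [FakeTensor(), MyTensorLike(), 3.0]
print([strict_check(x) for x in inputs])    # [True, False, False]
print([is_tensor_like(x) for x in inputs])  # [True, True, False]
```

Checking the type rather than the instance matches the convention of looking up `__torch_function__` on the class.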
Motivation
Tensor-like objects may implement operations that support autograd, and gradcheck helps verify that their gradients are correct. Extending torch.autograd.gradcheck to support Tensor-like types would make it easier for users to test their code.
Pitch
This feature request asks to extend gradcheck to also accept Tensor-like types, and proposes using the presence of __torch_function__ as the determinant, as we usually do. The rest of the code should then work fine, assuming the user has reasonable implementations of the features it uses. We could also guard the use of methods such as "is_sparse" on whether the given type is Tensor-like and actually defines that method. This might ease the implementation of Tensor-like objects, because users would not be required to implement those methods. We could additionally signal this with a warning, as we already do when the user doesn't use float64 as the input dtype, etc.
Alternatives
An alternative, for users, is to copy and modify this code for their own purposes.
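As an illustration of the guarding idea from the Pitch (whether done inside gradcheck or in a user-side copy), here is a minimal sketch. It assumes an optional attribute such as "is_sparse" is read only if the input's type defines it, and otherwise falls back to a default and emits a warning. `optional_attr`, `DenseTensorLike` and `MinimalTensorLike` are hypothetical names, not PyTorch APIs.

```python
import warnings


def optional_attr(inp, name, default):
    """Hypothetical helper: return inp.<name> if the input's type defines it;
    otherwise emit a UserWarning and fall back to `default`."""
    if hasattr(type(inp), name):
        return getattr(inp, name)
    warnings.warn(
        f"{type(inp).__name__} does not define '{name}'; assuming {default!r}",
        UserWarning,
    )
    return default


class DenseTensorLike:
    # Stand-in Tensor-like that implements the optional property.
    @property
    def is_sparse(self):
        return False

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented


class MinimalTensorLike:
    # Stand-in Tensor-like that only implements __torch_function__.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented


print(optional_attr(DenseTensorLike(), "is_sparse", False))    # False, no warning
print(optional_attr(MinimalTensorLike(), "is_sparse", False))  # False, plus a UserWarning
```

Falling back with a warning, rather than raising, mirrors how gradcheck already warns about non-float64 inputs instead of refusing them.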
Additional context
I'm using this in the context of the NestedTensor project. I labeled this as "module: numpy" based on the labels used in #24015.
cc @ezyang @ssnl @albanD @zou3519 @gqchen @mruberry @rgommers @VitalyFedyunin @hameerabbasi