torch.autograd.gradcheck support for Tensor-like types (__torch_function__). #42942

@cpuhrsch

Description

🚀 Feature

gradcheck.py#L264 explicitly requires its inputs to be of type torch.Tensor and thus doesn't work for Tensor-like objects.

Motivation

Tensor-like objects might implement operations with autograd support, and gradcheck helps verify that the gradients of those operations are correct. Extending torch.autograd.gradcheck to support Tensor-like types would make it easier for users to test their code.
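As background on what gradcheck verifies: it compares analytically computed gradients against finite-difference estimates. The pure-Python sketch below shows that idea for a scalar function of a few float64 inputs. The names numerical_grad and check_grad are hypothetical, and the eps/atol/rtol defaults mirror gradcheck's documented defaults (1e-6, 1e-5, 1e-3). The real gradcheck compares full Jacobians of Tensor-valued functions rather than Python lists.

```python
import math
from typing import Callable, List, Sequence


def numerical_grad(
    f: Callable[[Sequence[float]], float], x: Sequence[float], eps: float = 1e-6
) -> List[float]:
    """Central finite-difference estimate of the gradient of scalar f at x."""
    grad = []
    for i in range(len(x)):
        plus = list(x)
        minus = list(x)
        plus[i] += eps
        minus[i] -= eps
        grad.append((f(plus) - f(minus)) / (2 * eps))
    return grad


def check_grad(
    f: Callable[[Sequence[float]], float],
    analytic_grad: Callable[[Sequence[float]], Sequence[float]],
    x: Sequence[float],
    eps: float = 1e-6,
    atol: float = 1e-5,
    rtol: float = 1e-3,
) -> bool:
    """Hypothetical gradcheck-style test: does the analytical gradient
    match the finite-difference estimate within tolerance?"""
    numeric = numerical_grad(f, x, eps)
    analytic = list(analytic_grad(x))
    if len(numeric) != len(analytic):
        return False
    return all(
        math.isclose(n, a, rel_tol=rtol, abs_tol=atol)
        for n, a in zip(numeric, analytic)
    )


# Example: f(x0, x1) = x0^2 * x1 + sin(x1) and its hand-written gradient.
def f(x: Sequence[float]) -> float:
    return x[0] ** 2 * x[1] + math.sin(x[1])


def f_grad(x: Sequence[float]) -> List[float]:
    return [2 * x[0] * x[1], x[0] ** 2 + math.cos(x[1])]
```

Python floats are already double precision, which is the same reason gradcheck prefers float64 inputs: finite differences with a small eps lose too much precision in float32.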

Pitch

This feature request asks to extend gradcheck to also accept Tensor-like types, and proposes using the presence of __torch_function__ to decide whether an input is Tensor-like, as we usually do. The rest of the code should then work, assuming the user has reasonable implementations of the features gradcheck uses. Potentially, we could also guard uses of attributes such as "is_sparse" depending on whether the input is Tensor-like and actually defines that attribute. This might ease the implementation of Tensor-like objects, because it wouldn't require users to implement those attributes. We could signal this with a warning, as we already do, for example, when the user doesn't use float64 as the input dtype.
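A minimal sketch of the proposed input handling, written in pure Python so it runs without PyTorch installed. All names here (is_tensor_like, guarded_is_sparse, check_inputs, and the two toy classes) are hypothetical illustrations, not gradcheck's actual code. Recent PyTorch releases expose a comparable check as torch.overrides.is_tensor_like, and torch.Tensor itself defines __torch_function__, so real Tensors would also pass this test.

```python
import warnings
from typing import Any, List, Sequence


def is_tensor_like(obj: Any) -> bool:
    """Hypothetical helper: treat any object whose type defines
    __torch_function__ as Tensor-like."""
    return hasattr(type(obj), "__torch_function__")


def guarded_is_sparse(obj: Any) -> bool:
    """Hypothetical helper: read is_sparse only if the object defines it.

    On torch.Tensor, is_sparse is an attribute, so it is read, not called.
    Tensor-likes that omit it are assumed dense, with a warning, similar in
    spirit to gradcheck's warning for non-float64 inputs.
    """
    if hasattr(obj, "is_sparse"):
        return bool(obj.is_sparse)
    warnings.warn(
        f"{type(obj).__name__} does not define is_sparse; assuming a dense input."
    )
    return False


def check_inputs(inputs: Sequence[Any]) -> List[bool]:
    """Hypothetical relaxed input validation: accept Tensors and Tensor-likes,
    reject everything else, and report which inputs are sparse."""
    for inp in inputs:
        if not is_tensor_like(inp):
            raise TypeError(
                f"expected a Tensor or Tensor-like input, got {type(inp).__name__}"
            )
    return [guarded_is_sparse(inp) for inp in inputs]


# Toy Tensor-like types for illustration only; they implement no real operations.
class DenseTensorLike:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented


class SparseTensorLike(DenseTensorLike):
    is_sparse = True
```

Checking the type rather than the instance for __torch_function__ matches how the protocol is dispatched (it is looked up as a class-level method), and warning instead of raising keeps minimal Tensor-like implementations usable.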

Alternatives

The alternative for users is to copy-paste this code and modify it for their own purposes.

Additional context

I'm using this in the context of the NestedTensor project. I labeled this as "module: numpy" based on the labels used in #24015.

cc @ezyang @ssnl @albanD @zou3519 @gqchen @mruberry @rgommers @VitalyFedyunin @hameerabbasi

Metadata

Labels

feature (A request for a proper, new feature.)
module: autograd (Related to torch.autograd, and the autograd engine in general.)
module: numpy (Related to numpy support, and also numpy compatibility of our operators.)
module: tests (Issues related to tests, not the torch.testing module.)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.)
