Support CUDA and backpropagation in torch.orgqr #50104

@toshas

🚀 Feature

As of PyTorch 1.7, torch.orgqr supports only CPU tensors and does not implement gradients.
Feature proposal: add support for CUDA tensors and for gradients with respect to the inputs.

Motivation

This function is required to perform orthogonal parameterization of matrices (not to be confused with orthogonal initialization). Given an input matrix of shape d x r with d >= r, orgqr produces a matrix of the same shape d x r with orthonormal columns. The output spans the manifold of Stiefel frames of size d x r and hence enforces exact orthogonality constraints by construction. This can be used further, e.g. as a weight matrix in an RNN transition layer or anywhere else that requires strict orthogonality. More motivation can be found here and in the references therein: #42243
Without GPU support, the function is slow. Without gradients, it cannot be used for parameterization at all. Once the requested feature is implemented and PR #33344 is merged, it will become possible to perform exact orthogonal parameterization of weight matrices in Linear, all flavors of Convolutional, and Embedding layers.
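To make the construction concrete, here is a minimal, dependency-free sketch of the product of Householder reflectors that orgqr evaluates. It is not the PyTorch, LAPACK, or torch-householder implementation. The names `householder_product` and `gram_error` are hypothetical, the reflector vectors follow the LAPACK-style convention of an implicit unit diagonal, and the choice tau_k = 2 / ||v_k||^2 is an assumption made here so that every reflector is exactly orthogonal.

```python
import random


def householder_product(V):
    """First r columns of H_1 @ ... @ H_r, with H_k = I - tau_k * v_k v_k^T.

    V is a d x r list of lists (d >= r). Only the entries strictly below the
    diagonal are read; the diagonal is treated as 1 and the entries above it
    as 0. tau_k = 2 / ||v_k||^2 is an assumption of this sketch.
    """
    d, r = len(V), len(V[0])
    # Start from the first r columns of the d x d identity matrix.
    Q = [[1.0 if i == j else 0.0 for j in range(r)] for i in range(d)]
    # Apply the reflectors right to left: Q <- H_k Q for k = r-1, ..., 0.
    for k in reversed(range(r)):
        v = [0.0] * k + [1.0] + [V[i][k] for i in range(k + 1, d)]
        tau = 2.0 / sum(x * x for x in v)
        for j in range(r):
            dot = sum(v[i] * Q[i][j] for i in range(d))
            for i in range(d):
                Q[i][j] -= tau * dot * v[i]
    return Q


def gram_error(Q):
    """Largest absolute entry of Q^T Q - I, i.e. the deviation from orthonormality."""
    d, r = len(Q), len(Q[0])
    return max(
        abs(sum(Q[i][a] * Q[i][b] for i in range(d)) - (1.0 if a == b else 0.0))
        for a in range(r)
        for b in range(r)
    )


random.seed(0)
d, r = 5, 3
V = [[random.gauss(0.0, 1.0) for _ in range(r)] for _ in range(d)]
Q = householder_product(V)
print("max |Q^T Q - I| =", gram_error(Q))
```

Any unconstrained d x r input yields orthonormal columns, which is what makes the construction usable as a parameterization. In PyTorch itself the corresponding call is torch.orgqr(input, tau), with both arguments typically obtained from torch.geqrf.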

Pitch

torch.orgqr should support the two requested properties: CUDA tensors and gradients with respect to the inputs.
As hinted by @lezcano, MAGMA already implements a GPU version of the function: https://www.icl.utk.edu/~mgates3/docs/magma-proto.html
The gradients can be taken from my package https://github.com/toshas/torch-householder/blob/master/torch_householder/householder.cpp#L58

Alternatives

The code from my package https://github.com/toshas/torch-householder implements an efficient drop-in replacement for torch.orgqr that satisfies the requested properties, and I don't mind donating this code. The forward pass might be slower than MAGMA's; no comparison has been done.

Additional context

None so far

cc @ngimel @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk

Labels

module: cuda (Related to torch.cuda, and CUDA support in general)
module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
