Support CUDA and backpropagation in torch.orgqr #50104
Status: Closed
Labels: module: cuda (related to torch.cuda, and CUDA support in general), module: linear algebra (issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul), triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
🚀 Feature
As of PyTorch 1.7, torch.orgqr only supports CPU tensors and does not implement gradients.
Feature proposal: add support for CUDA tensors and for gradients with respect to the inputs.
Motivation
This function is required to perform orthogonal parameterization of matrices (not to be confused with orthogonal initialization). Given an input matrix of shape d x r with d >= r, orgqr produces an orthogonal matrix of the same shape d x r. The output spans the manifold of Stiefel frames of size d x r, and hence provides precise orthogonal constraints by construction. This can be used further, e.g. as a weight matrix in an RNN transition layer, or anywhere else requiring strict orthogonality. More motivation can be found here and in the references therein: #42243
Without GPU support, the function is slow. Without gradients, parameterization becomes impossible. Once the requested feature is implemented, and once PR #33344 is merged, it will become possible to perform precise orthogonal parameterization of weight matrices in Linear, all flavors of Convolutional, and Embedding layers.
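To illustrate the parameterization described above, here is a minimal CPU-only sketch (assuming PyTorch 1.7, where torch.geqrf and torch.orgqr are available but CPU-only and without autograd support): an arbitrary d x r matrix is turned into an orthonormal d x r frame via its packed Householder reflectors.

```python
import torch

# Sketch: build an orthonormal d x r Stiefel frame from an arbitrary matrix.
d, r = 5, 3
a = torch.randn(d, r)

# geqrf returns the QR factorization in packed form: the reflectors in the
# lower triangle of `packed` and the scalar factors in `tau`.
packed, tau = torch.geqrf(a)

# orgqr reconstructs the explicit orthogonal factor Q of shape d x r.
q = torch.orgqr(packed, tau)

# The columns of Q are orthonormal: Q^T Q = I_r (up to float32 tolerance).
print(torch.allclose(q.t() @ q, torch.eye(r), atol=1e-5))
```

In the proposed feature, the same call would additionally work for CUDA tensors and propagate gradients back to the input, so `q` could serve directly as an orthogonally constrained weight matrix.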
Pitch
torch.orgqr should support the two requested properties. As hinted by @lezcano, MAGMA already implements a GPU version of this function: https://www.icl.utk.edu/~mgates3/docs/magma-proto.html
The gradients can be taken from my package https://github.com/toshas/torch-householder/blob/master/torch_householder/householder.cpp#L58
Alternatives
The code from my package https://github.com/toshas/torch-householder implements an efficient drop-in replacement for torch.orgqr satisfying the requested properties, and I don't mind donating this code. The forward pass might be slower than that of MAGMA -- no comparison has been done.
Additional context
None so far
cc @ngimel @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk