Move lu, lu_solve, and lu_unpack into torch.linalg #61657
Closed
Labels
enhancement, module: linear algebra, triaged
The plan to move these is to follow SciPy's API.

For `torch.lu`, we would like to split it into `torch.linalg.lu_factor` and `torch.linalg.lu_factor_ex` (the former implemented as a call to the latter) with the following signatures:

If `pivot=False`, `lu_pivots` should be an empty tensor.

For `torch.lu_unpack`, we want to combine it with `torch.linalg.lu_factor` and have it as `torch.linalg.lu` with signature

where `pivot` controls whether we perform an LU with partial pivoting or not. If `pivot=False`, `P` should be an empty tensor.

For `lu_solve`, we want to update it to have signature

`adjoint` will dispatch to `trans='T'` or `trans='C'` depending on the dtype. The side should be implemented by transposing the inputs and the output. If `left=False` and `adjoint=True`, then this is equivalent to solving for `A.conj()`. This is fine, because the LU decomposition of a conjugate matrix is given by conjugating the `lu_factor` matrix.

cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk @xwang233 @lezcano
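To make the proposal concrete, here is a minimal sketch of how the three functions fit together. It assumes a PyTorch release in which `torch.linalg.lu_factor`, `torch.linalg.lu`, and `torch.linalg.lu_solve` (with keyword arguments `left` and `adjoint`) are available; the signatures proposed in this issue may differ in detail from what such a release ships. The last step checks the claim about `left=False, adjoint=True`: transposing `X @ A.mH = C` gives `A.conj() @ X.mT = C.mT`, which can be solved with the conjugated LU factors and the same pivots. The variable names are illustrative only.

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 3, dtype=torch.complex128)
B = torch.randn(3, 2, dtype=torch.complex128)  # right-hand side for A X = B
C = torch.randn(2, 3, dtype=torch.complex128)  # right-hand side for X A = C

# Compact factorization: L and U packed into one matrix, plus pivot indices.
LU, pivots = torch.linalg.lu_factor(A)

# Explicit factors, combining the old lu + lu_unpack steps: A = P @ L @ U.
P, L, U = torch.linalg.lu(A)

# Defaults (left=True, adjoint=False): solve A X = B.
X = torch.linalg.lu_solve(LU, pivots, B)

# adjoint=True: solve A^H X = B, reusing the factorization of A.
X_adj = torch.linalg.lu_solve(LU, pivots, B, adjoint=True)

# left=False: solve X A = C.
X_right = torch.linalg.lu_solve(LU, pivots, C, left=False)

# left=False, adjoint=True: solve X A^H = C.
X_right_adj = torch.linalg.lu_solve(LU, pivots, C, left=False, adjoint=True)

# Same system written as conj(A) X^T = C^T, solved with the conjugated factors.
LU_conj = LU.conj().resolve_conj()
X_via_conj = torch.linalg.lu_solve(LU_conj, pivots, C.mT.contiguous()).mT
```

Note that `pivot=False` is left out of the sketch on purpose: support for LU without pivoting can depend on the device, so it is not exercised here.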