Enable broadcasting of batch dimensions RHS and LHS tensors for lu_solve#24333
Enable broadcasting of batch dimensions RHS and LHS tensors for lu_solve#24333vishwakftw wants to merge 15 commits intopytorch:masterfrom
Conversation
Changelog: - Enable broadcasting of RHS and LHS tensors for lu_solve. This means that you can now have RHS with size `3 x 2` and LHS with size `4 x 3 x 3` for instance - Remove deprecated behavior of having 2D tensors for RHS. Now all tensors have to have a last dimension which equals the number of right hand sides - Modified docs Test Plan: - Add tests for new behavior in test_torch.py with a port to test_cuda.py
|
@pytorchbot rebase this please |
|
@pytorchbot rebase this please |
zou3519
left a comment
There was a problem hiding this comment.
Code looks correct. I had some comments on style and cleaning up the testing code
…olve-new-version
…into lu_solve-new-version
|
@zou3519 I actually missed out on adding tests for broadcasting behavior earlier. I've added them now, just FYI. |
|
@zou3519 except for test refactoring and specifying sizes in error message (which will be addressed in follow-up PRs), the PR should be good to review again. |
| return torch.stack(all_matrices).reshape(*(batches + (l, l))) | ||
|
|
||
|
|
||
| def random_linalg_solve_processed_inputs(A_dims, b_dims, gen_fn, transform_fn, cast_fn): |
There was a problem hiding this comment.
Once this PR is merged, I will use this function in other places in the test suite for *solve methods. This would reduce duplication of code.
|
@pytorchbot rebase this please |
zou3519
left a comment
There was a problem hiding this comment.
Thank you, the code looks a lot better now :)
I had some last comments about repetitiveness in the testing code
…into lu_solve-new-version
…olve-new-version
|
@pytorchbot rebase this please |
facebook-github-bot
left a comment
There was a problem hiding this comment.
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
…lve (#24333) Summary: Changelog: - Enable broadcasting of RHS and LHS tensors for lu_solve. This means that you can now have RHS with size `3 x 2` and LHS with size `4 x 3 x 3` for instance - Remove deprecated behavior of having 2D tensors for RHS. Now all tensors have to have a last dimension which equals the number of right hand sides - Modified docs Pull Request resolved: pytorch/pytorch#24333 Test Plan: - Add tests for new behavior in test_torch.py with a port to test_cuda.py Differential Revision: D17165463 Pulled By: zou3519 fbshipit-source-id: cda5d5496ddb29ed0182bab250b5d90f8f454aa6
Summary: Changelog: - De-duplicate the code in tests for torch.solve, torch.cholesky_solve, torch.triangular_solve - Skip tests explicitly if requirements aren't met for e.g., if NumPy / SciPy aren't available in the environment - Add generic helpers for these tests in test/common_utils.py Pull Request resolved: #25733 Test Plan: - All tests should pass to confirm that the change is not erroneous Clears one point specified in the discussion in #24333. Differential Revision: D17315330 Pulled By: zou3519 fbshipit-source-id: c72a793e89af7e2cdb163521816d56747fd70a0e
…lve (pytorch#24333) Summary: Changelog: - Enable broadcasting of RHS and LHS tensors for lu_solve. This means that you can now have RHS with size `3 x 2` and LHS with size `4 x 3 x 3` for instance - Remove deprecated behavior of having 2D tensors for RHS. Now all tensors have to have a last dimension which equals the number of right hand sides - Modified docs Pull Request resolved: pytorch#24333 Test Plan: - Add tests for new behavior in test_torch.py with a port to test_cuda.py Differential Revision: D17165463 Pulled By: zou3519 fbshipit-source-id: cda5d5496ddb29ed0182bab250b5d90f8f454aa6
Summary: Changelog: - De-duplicate the code in tests for torch.solve, torch.cholesky_solve, torch.triangular_solve - Skip tests explicitly if requirements aren't met for e.g., if NumPy / SciPy aren't available in the environment - Add generic helpers for these tests in test/common_utils.py Pull Request resolved: pytorch#25733 Test Plan: - All tests should pass to confirm that the change is not erroneous Clears one point specified in the discussion in pytorch#24333. Differential Revision: D17315330 Pulled By: zou3519 fbshipit-source-id: c72a793e89af7e2cdb163521816d56747fd70a0e
Changelog:
3 x 2and LHS with size4 x 3 x 3for instanceTest Plan: