[Ready] Make potrs batched#13453
Conversation
00d213d to
b4ff163
Compare
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
b4ff163 to
314f875
Compare
- This is straightforward PR, building up on the batch inverse PR, except for one change:
- The GENERATE_LINALG_HELPER_n_ARGS macro has been removed, since it is not very general
and the resulting code is actually not very copy-pasty.
314f875 to
551360a
Compare
|
@zou3519 This is ready for review, just for your information. |
|
@vishwakftw I'll take a look later today or tomorrow |
|
Thank you, appreciate it. :-) |
|
No, thank you for your contribution :) |
|
|
||
| template<class scalar_t> | ||
| void lapackGetri(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, int lwork, int *info) { | ||
| void lapackGetri(int n, scalar_t* a, int lda, int* ipiv, scalar_t* work, int lwork, int* info) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
facebook-github-bot
left a comment
There was a problem hiding this comment.
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
| '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*', '_th_.*', | ||
| 'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*', 'slice', | ||
| 'randint(_out)?', | ||
| 'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*', '_potrs.*', |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| b = cast(torch.randn(2, 1, 3, 4, 6)) | ||
| L = get_cholesky(A, upper) | ||
| x = torch.potrs(b, L, upper=upper) | ||
| x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| self.assertEqual(x.data, cast(x_exp)) | ||
|
|
||
| # broadcasting A | ||
| A = cast(random_symmetric_pd_matrix(4)) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
zou3519
left a comment
There was a problem hiding this comment.
lgtm, thank you @vishwakftw!
facebook-github-bot
left a comment
There was a problem hiding this comment.
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@zou3519 is there anything that I need to do? |
facebook-github-bot
left a comment
There was a problem hiding this comment.
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@vishwakftw I think it should be good, I'll let you know if any action is required |
|
@zou3519 just a notification: there were merge conflicts after the recent changes to |
facebook-github-bot
left a comment
There was a problem hiding this comment.
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: - This is a straightforward PR, building up on the batch inverse PR, except for one change: - The GENERATE_LINALG_HELPER_n_ARGS macro has been removed, since it is not very general and the resulting code is actually not very copy-pasty. Billing of changes: - Add batching for `potrs` - Add relevant tests - Modify doc string Minor changes: - Remove `_gesv_single`, `_getri_single` from `aten_interned_strings.h`. - Add test for CUDA `potrs` (2D Tensor op) - Move the batched shape checking to `LinearAlgebraUtils.h` Pull Request resolved: pytorch/pytorch#13453 Reviewed By: soumith Differential Revision: D12942039 Pulled By: zou3519 fbshipit-source-id: 1b8007f00218e61593fc415865b51c1dac0b6a35
Summary: - This is a straightforward PR, building up on the batch inverse PR, except for one change: - The GENERATE_LINALG_HELPER_n_ARGS macro has been removed, since it is not very general and the resulting code is actually not very copy-pasty. Billing of changes: - Add batching for `potrs` - Add relevant tests - Modify doc string Minor changes: - Remove `_gesv_single`, `_getri_single` from `aten_interned_strings.h`. - Add test for CUDA `potrs` (2D Tensor op) - Move the batched shape checking to `LinearAlgebraUtils.h` Pull Request resolved: pytorch#13453 Reviewed By: soumith Differential Revision: D12942039 Pulled By: zou3519 fbshipit-source-id: 1b8007f00218e61593fc415865b51c1dac0b6a35
Billing of changes:
potrsMinor changes:
_gesv_single,_getri_singlefromaten_interned_strings.h.potrs(2D Tensor op)LinearAlgebraUtils.h