Skip to content

Remove deprecated torch.solve#70986

Closed
IvanYashchuk wants to merge 16 commits intopytorch:masterfrom
IvanYashchuk:remove-solve
Closed

Remove deprecated torch.solve#70986
IvanYashchuk wants to merge 16 commits intopytorch:masterfrom
IvanYashchuk:remove-solve

Conversation

@IvanYashchuk
Copy link
Copy Markdown
Collaborator

@IvanYashchuk IvanYashchuk commented Jan 7, 2022

The time has come to remove deprecated linear algebra related functions. This PR removes torch.solve.

cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano

@IvanYashchuk IvanYashchuk added module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul module: deprecation labels Jan 7, 2022
@pytorch-probot pytorch-probot Bot assigned pytorchbot and unassigned pytorchbot Jan 7, 2022
@pytorch-probot
Copy link
Copy Markdown

pytorch-probot Bot commented Jan 7, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/IvanYashchuk/pytorch/blob/a00e3738b89b7a517c33be31a3380c4f95e43c25/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/cuda,ciflow/all

Workflows Labels (bold enabled) Status
Triggered Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk ✅ triggered
docker-builds ciflow/all, ciflow/trunk ✅ triggered
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk ✅ triggered
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk ✅ triggered
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk ✅ triggered
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk ✅ triggered
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk ✅ triggered
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk ✅ triggered
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk ✅ triggered
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-libtorch-linux-xenial-cuda11.1-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck ✅ triggered
periodic-linux-xenial-cuda11.1-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win ✅ triggered
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
Skipped Workflows

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

@facebook-github-bot
Copy link
Copy Markdown
Contributor

facebook-github-bot commented Jan 7, 2022

🔗 Helpful links

❌ 1 New Failures

As of commit 155b7f5 (more details on the Dr. CI page):

Expand to see more
  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 1, 4, linux.4xlarge.nvidia.gpu) (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-05-10T10:36:11.0295683Z RuntimeError: test_ops failed! Received signal: SIGIOT
2022-05-10T10:36:09.1721708Z   test_variant_consistency_eager_dsplit_cuda_complex64 (__main__.TestCommonCUDA) ... ok (0.012s)
2022-05-10T10:36:09.1803889Z   test_variant_consistency_eager_dsplit_cuda_float32 (__main__.TestCommonCUDA) ... ok (0.008s)
2022-05-10T10:36:09.1880596Z   test_variant_consistency_eager_dstack_cuda_complex64 (__main__.TestCommonCUDA) ... ok (0.008s)
2022-05-10T10:36:09.1940639Z   test_variant_consistency_eager_dstack_cuda_float32 (__main__.TestCommonCUDA) ... ok (0.006s)
2022-05-10T10:36:09.2143477Z   test_variant_consistency_eager_eig_cuda_complex64 (__main__.TestCommonCUDA) ... python: /opt/conda/conda-bld/magma-cuda113_1619629459349/work/interface_cuda/interface.cpp:806: void magma_queue_create_internal(magma_device_t, magma_queue**, const char*, const char*, int): Assertion `queue->dAarray__ != __null' failed.
2022-05-10T10:36:11.0285750Z Traceback (most recent call last):
2022-05-10T10:36:11.0286437Z   File "test/run_test.py", line 1072, in <module>
2022-05-10T10:36:11.0289835Z     main()
2022-05-10T10:36:11.0290355Z   File "test/run_test.py", line 1050, in main
2022-05-10T10:36:11.0295021Z     raise RuntimeError(err_message)
2022-05-10T10:36:11.0295683Z RuntimeError: test_ops failed! Received signal: SIGIOT
2022-05-10T10:36:12.2458102Z + cleanup
2022-05-10T10:36:12.2458396Z + retcode=1
2022-05-10T10:36:12.2458640Z + set +x
2022-05-10T10:36:12.2502864Z ##[error]Process completed with exit code 1.
2022-05-10T10:36:12.2551165Z ##[group]Run pytorch/pytorch/.github/actions/get-workflow-job-id@master
2022-05-10T10:36:12.2551540Z with:
2022-05-10T10:36:12.2552107Z   github-token: ***
2022-05-10T10:36:12.2552335Z env:
2022-05-10T10:36:12.2552555Z   IN_CI: 1
2022-05-10T10:36:12.2552777Z   IS_GHA: 1

This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

Copy link
Copy Markdown
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I won't cry when this one's gone either. Now, I left a few comments regarding the MAGMA / LAPACK functions that we're removing.

}

template<> void lapackSolve<float>(int n, int nrhs, float *a, int lda, int *ipiv, float *b, int ldb, int *info) {
sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we want to keep these kernels? Shouldn't they be more efficient than lu + lu_solve?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, they are not more efficient. gesv = getrf + getrs but done in an external library. We should leave only the code that is used.

Comment on lines -166 to -173
template<>
void magmaSolve<double>(
magma_int_t n, magma_int_t nrhs, double* dA, magma_int_t ldda,
magma_int_t* ipiv, double* dB, magma_int_t lddb, magma_int_t* info) {
MagmaStreamSyncGuard guard;
magma_dgesv_gpu(n, nrhs, dA, ldda, ipiv, dB, lddb, info);
AT_CUDA_CHECK(cudaGetLastError());
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question for these.

@anjali411 anjali411 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jan 25, 2022
Copy link
Copy Markdown
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR Ivan.
Do you think it would be possible to add to torch/__init__.py something like:

# This function was depreacted and removed.
# This nice error message can be removed in version 1.3+
def solve(input, A, out=None):
    raise RuntimeError("This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.solve` function instead")

To make sure that users get a nice error message for a couple versions?

@IvanYashchuk
Copy link
Copy Markdown
Collaborator Author

Sure, I will add that.

@IvanYashchuk
Copy link
Copy Markdown
Collaborator Author

Alright, I added a solve function and a method that raises an error:

In [1]: import torch
In [2]: a = torch.randn(3, 3)
In [3]: b = torch.randn(3, 3)
In [4]: torch.solve(b, a)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-6f9b9974b5b1> in <module>
----> 1 torch.solve(b, a)

RuntimeError: This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.solve` function instead.

In [5]: b.solve(a)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-fd669f2275c3> in <module>
----> 1 b.solve(a)

RuntimeError: This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.solve` function instead.

@lezcano
Copy link
Copy Markdown
Collaborator

lezcano commented May 5, 2022

If you ever manage to land this, remind me to remove in #74046 the backward functions for torch.solve. Big if though.

@albanD albanD added the ciflow/trunk Trigger trunk jobs on your pull request label May 5, 2022
Copy link
Copy Markdown
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

Comment thread test/test_linalg.py
# triangular_solve goes through specific backend dispatch (CPU/CUDA) and hit auto-generated device check first
generic_backend_dispatch_err_str = "Expected b and A to be on the same device"
specific_backend_dispatch_err_str = "Expected all tensors to be on the same device"
with self.assertRaisesRegex(RuntimeError, generic_backend_dispatch_err_str):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could you leave a small test here (or in test_torch.py) that the error you added is working as expected?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I added the test in 68928e7.

@IvanYashchuk
Copy link
Copy Markdown
Collaborator Author

I added a test checking that the error is raised as expected and resolved accumulated merge conflicts.

Copy link
Copy Markdown
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

@albanD
Copy link
Copy Markdown
Collaborator

albanD commented May 10, 2022

@pytorchbot merge this please

@github-actions
Copy link
Copy Markdown
Contributor

Hey @IvanYashchuk.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@IvanYashchuk IvanYashchuk added topic: deprecation topic category release notes: linalg_frontend release notes category labels May 10, 2022
facebook-github-bot pushed a commit that referenced this pull request May 13, 2022
Summary:
The time has come to remove deprecated linear algebra related functions. This PR removes `torch.solve`.

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano

Pull Request resolved: #70986
Approved by: https://github.com/Lezcano, https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/890bdf13e17adcd0ede255afbf5abf29b7d8d6ee

Reviewed By: malfet

Differential Revision: D36299674

Pulled By: malfet

fbshipit-source-id: fd01095c72787103d16c442595f67f5020f393fb
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
The time has come to remove deprecated linear algebra related functions. This PR removes `torch.solve`.

cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano
Pull Request resolved: pytorch#70986
Approved by: https://github.com/Lezcano, https://github.com/albanD
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request cla signed Merged module: deprecation module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul oncall: jit Add this issue/PR to JIT oncall triage queue open source release notes: linalg_frontend release notes category topic: deprecation topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants