
Simplify and optimize linalg.solve #74046

Closed
lezcano wants to merge 43 commits into gh/Lezcano/54/base from gh/Lezcano/54/head

Conversation

@lezcano
Collaborator

@lezcano lezcano commented Mar 10, 2022

Stack from ghstack:

This PR heavily simplifies the code of linalg.solve. At the same time,
the new implementation saves quite a few copies of the input data in some
cases (e.g. when A is contiguous).

We also implement it in such a way that the derivative goes from
computing two LU decompositions and two LU solves to no LU
decompositions and one LU solve. It also avoids a number of copies
that the derivative was unnecessarily performing (at least the copies
of two matrices).
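For reference, the algebra behind the cheaper backward: if X = A⁻¹B, the vector-Jacobian products are grad_B = A⁻ᵀG and grad_A = −grad_B Xᵀ, so the backward only needs solves against Aᵀ and no new factorization of A. A NumPy sketch of these formulas, written for this summary (it is not the PR's actual C++ implementation), checked against a finite difference:

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 4, 3
A = rng.standard_normal((n, n)) + n * np.eye(n)  # shift by n*I to keep A well-conditioned
B = rng.standard_normal((n, k))
G = rng.standard_normal((n, k))  # upstream gradient of f(A, B) = sum(G * solve(A, B))

X = np.linalg.solve(A, B)

# Analytic VJPs: only solves with A^T, no new factorization of A
grad_B = np.linalg.solve(A.T, G)
grad_A = -grad_B @ X.T

# Finite-difference check of grad_A on one entry
eps = 1e-6
dA = np.zeros_like(A)
dA[1, 2] = eps
num = (np.sum(G * np.linalg.solve(A + dA, B))
       - np.sum(G * np.linalg.solve(A - dA, B))) / (2 * eps)
assert abs(num - grad_A[1, 2]) < 1e-5
```

The point is that both solves in the backward can share the LU factors already computed in the forward pass.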

On top of this, we add a left kw-only arg that allows the user to
solve XA = B rather concisely.
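The left=False case reduces to the identity XA = B ⟺ AᵀXᵀ = Bᵀ, so the same solver can be reused on adjointed inputs. A small NumPy illustration of that identity (this is the underlying algebra, not the PR's code):

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((4, 4)) + 4 * np.eye(4)  # well-conditioned square matrix
B = rng.standard_normal((3, 4))

# Solve X A = B for X by transposing both sides: A^T X^T = B^T
X = np.linalg.solve(A.T, B.T).T
assert np.allclose(X @ A, B)
```

With this PR, the same computation would read torch.linalg.solve(A, B, left=False).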

This PR also makes torch.solve an alias of torch.linalg.solve.

Note: This used to be the last PR of the stack. It is here now because some tests were failing in a PR earlier in the stack, and reshuffling the stack fixed them. The benchmarks below are performed with respect to the last PR of this stack.

We compare the performance of linalg.solve against master (before merging #67833, but already with a few PRs of the LU stack merged). We see speed-ups of between 2.5x and 10x in linalg.solve.

Benchmark Results
[--------------------- linalg.solve + backward --------------------]
                                    |  master |  This PR
1 threads: ----------------------------------------------
      torch.Size([1, 1, 1])         |     1280  |    267
      torch.Size([2, 1, 1])         |     1300  |    200
      torch.Size([4, 1, 1])         |      500  |    231
      torch.Size([8, 1, 1])         |      600  |    232
      torch.Size([16, 1, 1])        |     1200  |    234
      torch.Size([32, 1, 1])        |     1300  |    239
      torch.Size([64, 1, 1])        |     1300  |    300
      torch.Size([128, 1, 1])       |     1340  |    331
      torch.Size([512, 1, 1])       |     1664  |    380
      torch.Size([1024, 1, 1])      |     2000  |    430
      torch.Size([1, 2, 2])         |     1200  |    300
      torch.Size([2, 2, 2])         |     1250  |    237
      torch.Size([4, 2, 2])         |      479  |    240
      torch.Size([8, 2, 2])         |      600  |    239
      torch.Size([16, 2, 2])        |     1300  |    242
      torch.Size([32, 2, 2])        |     1300  |    245
      torch.Size([64, 2, 2])        |     1300  |    260
      torch.Size([128, 2, 2])       |     1400  |    340
      torch.Size([512, 2, 2])       |     1680  |    380
      torch.Size([1024, 2, 2])      |     2100  |    430
      torch.Size([1, 8, 8])         |     1200  |    250
      torch.Size([2, 8, 8])         |     1240  |    238
      torch.Size([4, 8, 8])         |      480  |    240
      torch.Size([8, 8, 8])         |      600  |    240
      torch.Size([16, 8, 8])        |     1330  |    243
      torch.Size([32, 8, 8])        |     1340  |    250
      torch.Size([64, 8, 8])        |     1370  |    257
      torch.Size([128, 8, 8])       |     1400  |    280
      torch.Size([512, 8, 8])       |     1720  |    346
      torch.Size([1024, 8, 8])      |     2300  |    390
      torch.Size([1, 16, 16])       |     1380  |    245
      torch.Size([2, 16, 16])       |     1000  |    300
      torch.Size([4, 16, 16])       |      610  |    260
      torch.Size([8, 16, 16])       |      862  |    260
      torch.Size([16, 16, 16])      |     1350  |    260
      torch.Size([32, 16, 16])      |     1370  |    260
      torch.Size([64, 16, 16])      |     1440  |    273
      torch.Size([128, 16, 16])     |     1520  |    289
      torch.Size([512, 16, 16])     |     1880  |    350
      torch.Size([1024, 16, 16])    |     2540  |    530
      torch.Size([1, 32, 32])       |     1500  |    290
      torch.Size([2, 32, 32])       |     2100  |    287
      torch.Size([4, 32, 32])       |     1370  |    288
      torch.Size([8, 32, 32])       |     1389  |    290
      torch.Size([16, 32, 32])      |     1400  |    290
      torch.Size([32, 32, 32])      |     1500  |    476
      torch.Size([64, 32, 32])      |     1600  |    468
      torch.Size([128, 32, 32])     |     1700  |    479
      torch.Size([512, 32, 32])     |     2300  |    696
      torch.Size([1024, 32, 32])    |     3200  |   1200
      torch.Size([1, 64, 64])       |     1700  |    340
      torch.Size([2, 64, 64])       |     2800  |    353
      torch.Size([4, 64, 64])       |     1990  |    328
      torch.Size([8, 64, 64])       |     2040  |    330
      torch.Size([16, 64, 64])      |     2100  |    350
      torch.Size([32, 64, 64])      |     2300  |    680
      torch.Size([64, 64, 64])      |     2430  |    725
      torch.Size([128, 64, 64])     |     2600  |    845
      torch.Size([512, 64, 64])     |     4700  |   1900
      torch.Size([1024, 64, 64])    |     9200  |   4280
      torch.Size([1, 128, 128])     |     2300  |    497
      torch.Size([2, 128, 128])     |     4000  |    562
      torch.Size([4, 128, 128])     |     3140  |    669
      torch.Size([8, 128, 128])     |     3200  |    698
      torch.Size([16, 128, 128])    |     3400  |    810
      torch.Size([32, 128, 128])    |     3866  |   1410
      torch.Size([64, 128, 128])    |     4200  |   1670
      torch.Size([128, 128, 128])   |     5050  |   2170
      torch.Size([512, 128, 128])   |    14000  |   6417
      torch.Size([1024, 128, 128])  |    28900  |  14700
      torch.Size([1, 256, 256])     |     4100  |   1559
      torch.Size([2, 256, 256])     |     6800  |   1792
      torch.Size([4, 256, 256])     |     7000  |   2000
      torch.Size([8, 256, 256])     |     7300  |   2200
      torch.Size([16, 256, 256])    |     7730  |   2540
      torch.Size([32, 256, 256])    |     8500  |   3390
      torch.Size([64, 256, 256])    |    11000  |   4470
      torch.Size([128, 256, 256])   |    15900  |   6757
      torch.Size([512, 256, 256])   |    50000  |  30000
      torch.Size([1024, 256, 256])  |   102600  |  56400
      torch.Size([1, 512, 512])     |     8793  |   3230
      torch.Size([2, 512, 512])     |    13000  |   3920
      torch.Size([4, 512, 512])     |    14000  |   4531
      torch.Size([8, 512, 512])     |    15000  |   5114
      torch.Size([16, 512, 512])    |    16700  |   6280
      torch.Size([32, 512, 512])    |    22400  |   9530
      torch.Size([64, 512, 512])    |    33700  |  14260
      torch.Size([128, 512, 512])   |    56500  |  20000

Times are in microseconds (us).
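The speed-ups can be read directly off the table; for a few sampled rows (times in microseconds, copied from above):

```python
# (master_us, this_pr_us) pairs copied from selected rows of the table above
rows = {
    "[1, 1, 1]":       (1280, 267),
    "[2, 64, 64]":     (2800, 353),
    "[128, 512, 512]": (56500, 20000),
}
speedups = {shape: master / pr for shape, (master, pr) in rows.items()}
for shape, s in speedups.items():
    print(f"torch.Size({shape}): {s:.1f}x")  # ~4.8x, ~7.9x, ~2.8x
```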
Benchmarking Script
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.solve"
label = "master"
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda", requires_grad=True)

for n, batch in itertools.product(shapes, batches):
    if n == 512 and batch[0] >= 512:
        continue
    A = make_arg(batch + (n, n))
    B = make_arg(batch + (n, 16))
    ones = torch.ones(B.shape, device=B.device)
    print(A.shape)
    for adjoint in (True, False):  # NB: adjoint is unused below, so each shape is simply timed twice
        timer = Timer("torch.linalg.solve(A, B).backward(gradient=ones, inputs=[A, B])",
                      globals=globals(),
                      label=benchmark_name,
                      description=label,
                      sub_label=f"{A.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}.pickle".format(label), 'wb') as f:
    pickle.dump(results, f)

See #72935 (comment) for the script to join the results.
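For completeness, the joining step referenced above amounts to unpickling each run's result list and concatenating before handing it to Compare. A runnable sketch with stand-in data (real lists hold torch.utils.benchmark Measurement objects; the file names here are hypothetical):

```python
import os
import pickle
import tempfile

# Stand-in for two runs' pickled result lists; plain data keeps this runnable
tmp = tempfile.mkdtemp()
for label, data in [("master", [1280, 2000]), ("this-pr", [267, 430])]:
    with open(os.path.join(tmp, f"{label}.pickle"), "wb") as f:
        pickle.dump(data, f)

# Join: unpickle each run and concatenate into one list
results = []
for label in ["master", "this-pr"]:
    with open(os.path.join(tmp, f"{label}.pickle"), "rb") as f:
        results.extend(pickle.load(f))

# With real Measurement objects, finish with:
#   from torch.utils.benchmark import Compare
#   Compare(results).print()
assert results == [1280, 2000, 267, 430]
```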

@pytorch-bot

pytorch-bot Bot commented Mar 10, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/819e94462d5b20dc01d80f61c9487ae8d93057a0/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default
Add ciflow labels to this PR to trigger more builds:

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
linux-binary-libtorch-cxx11-abi ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
linux-binary-libtorch-pre-cxx11 ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
linux-binary-manywheel ciflow/all, ciflow/binaries, ciflow/binaries_wheel, ciflow/default, ciflow/trunk ✅ triggered
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk ✅ triggered
linux-bionic-rocm4.5-py3.7 ciflow/all, ciflow/default, ciflow/linux, ciflow/rocm, ciflow/trunk ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build ciflow/all, ciflow/cpu, ciflow/default, ciflow/libtorch, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
macos-arm64-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-arm64-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
macos-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
windows-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
windows-binary-libtorch-debug ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
windows-binary-libtorch-release ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
windows-binary-wheel ciflow/all, ciflow/binaries, ciflow/binaries_wheel, ciflow/default, ciflow/trunk ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
docker-builds ciflow/all, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.3-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
pytorch-xla-linux-bionic-py3.7-clang8 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk, ciflow/xla 🚫 skipped

@facebook-github-bot
Contributor

facebook-github-bot commented Mar 10, 2022

🔗 Helpful links

❌ 1 New Failures

As of commit fd02ba5 (more details on the Dr. CI page):

  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build periodic / linux-xenial-cuda11.3-py3.7-gcc7-debug / test (default, 1, 2, linux.4xlarge.nvidia.gpu) (1/1)

Step: "Test"

2022-06-10T22:49:08.2919287Z test_RNN_change_...as.so.11: undefined symbol: cublasGetSmCountTarget
2022-06-10T22:49:08.1356056Z   test_PoissonNLLLoss_no_full_loss_no_log_input_cuda_double (__main__.TestNN) ... ok (0.004s)
2022-06-10T22:49:08.1394346Z   test_PoissonNLLLoss_no_full_loss_no_log_input_cuda_float (__main__.TestNN) ... ok (0.004s)
2022-06-10T22:49:08.1433225Z   test_PoissonNLLLoss_no_full_loss_no_log_input_cuda_half (__main__.TestNN) ... ok (0.004s)
2022-06-10T22:49:08.1588928Z   test_PoissonNLLLoss_no_reduce (__main__.TestNN) ... ok (0.015s)
2022-06-10T22:49:08.1754958Z   test_PoissonNLLLoss_no_reduce_cuda (__main__.TestNN) ... ok (0.017s)
2022-06-10T22:49:08.1839950Z   test_RNN_cell (__main__.TestNN) ... ok (0.008s)
2022-06-10T22:49:08.1952913Z   test_RNN_cell_forward_hidden_size (__main__.TestNN) ... ok (0.011s)
2022-06-10T22:49:08.2030329Z   test_RNN_cell_forward_input_size (__main__.TestNN) ... ok (0.008s)
2022-06-10T22:49:08.2053743Z   test_RNN_cell_forward_zero_hidden_size (__main__.TestNN) ... ok (0.002s)
2022-06-10T22:49:08.2497748Z   test_RNN_cell_no_broadcasting (__main__.TestNN) ... ok (0.044s)
2022-06-10T22:49:08.2919287Z   test_RNN_change_dropout (__main__.TestNN) ... Could not load symbol cublasGetSmCountTarget from libcublas.so.11. Error: /usr/local/cuda/lib64/libcublas.so.11: undefined symbol: cublasGetSmCountTarget
2022-06-10T22:49:08.3742482Z ok (0.124s)
2022-06-10T22:49:11.7910687Z   test_RNN_cpu_vs_cudnn_no_dropout (__main__.TestNN) ... ok (3.417s)
2022-06-10T22:49:15.2137756Z   test_RNN_cpu_vs_cudnn_with_dropout (__main__.TestNN) ... ok (3.423s)
2022-06-10T22:49:15.2188486Z   test_RNN_cudnn_weight_norm (__main__.TestNN) ... /opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py:770: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). (Triggered internally at  /var/lib/jenkins/workspace/aten/src/ATen/native/cudnn/RNN.cpp:968.)
2022-06-10T22:49:15.2189433Z   self.dropout, self.training, self.bidirectional, self.batch_first)
2022-06-10T22:49:15.2205005Z /opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py:770: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). (Triggered internally at  /var/lib/jenkins/workspace/aten/src/ATen/native/cudnn/RNN.cpp:968.)
2022-06-10T22:49:15.2205837Z   self.dropout, self.training, self.bidirectional, self.batch_first)
2022-06-10T22:49:15.2249483Z /opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py:770: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). (Triggered internally at  /var/lib/jenkins/workspace/aten/src/ATen/native/cudnn/RNN.cpp:968.)
2022-06-10T22:49:15.2250302Z   self.dropout, self.training, self.bidirectional, self.batch_first)
2022-06-10T22:49:15.2267331Z /opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py:770: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). (Triggered internally at  /var/lib/jenkins/workspace/aten/src/ATen/native/cudnn/RNN.cpp:968.)

This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@lezcano
Collaborator Author

lezcano commented Mar 10, 2022

Leaving @albanD to have a look at the derivative, to see if it can be implemented any better. In particular, at the moment I need to save both A and its LU decomposition in case the user wants to compute the Hessian.

lezcano added a commit that referenced this pull request Mar 10, 2022
This PR heavily simplifies the code of `linalg.solve`. At the same time,
this implementation saves quite a few copies of the input data in some
cases (e.g. A is contiguous)

We also implement it in such a way that the derivative goes from
computing two LU decompositions and two LU solves to no LU
decompositions and one LU solves. It also avoids a number of unnecessary
copies the derivative was unnecessarily performing (at least the copy of
two matrices).

On top of this, we add a `left` kw-only arg that allows the user to
solve `XA = B` rather concisely.

ghstack-source-id: 2ae45d2
Pull Request resolved: #74046
Collaborator

@nikitaved nikitaved left a comment


LGTM! Thanks!

This PR heavily simplifies the code of `linalg.solve`. At the same time,
this implementation saves quite a some copies of the input data in some
cases (e.g. A is contiguous)

We also implement it in such a way that the derivative goes from
computing two LU decompositions and two LU solves to no LU
decompositions and one LU solves. It also avoids a number of unnecessary
copies the derivative was unnecessarily performing (at least the copy of
two matrices).

On top of this, we add a `left` kw-only arg that allows the user to
solve `XA = B` rather concisely.

This PR also makes `torch.solve` an alias of `torch.linalg.solve`.

**Note:** This used to be the last PR of the stack. Now it's here because some tests were not passing in a PR that was before this one in the stack, and reshuffling the stack solved those problems. The benchmarks below are performed wrt the last PR fo this stack.

We compare the performance of `linalg.solve` against master (before merging #67833, but already with a few PRs of the LU stack merged). We see that we got between  **x2.5 and x10 speed-ups in `linalg.solve`**.


<details>
<summary>
Benchmark Results
</summary>

```
[--------------------- linalg.solve + backward --------------------]
                                    |  master |  This PR
1 threads: ----------------------------------------------
      torch.Size([1, 1, 1])         |     1280  |    267
      torch.Size([2, 1, 1])         |     1300  |    200
      torch.Size([4, 1, 1])         |      500  |    231
      torch.Size([8, 1, 1])         |      600  |    232
      torch.Size([16, 1, 1])        |     1200  |    234
      torch.Size([32, 1, 1])        |     1300  |    239
      torch.Size([64, 1, 1])        |     1300  |    300
      torch.Size([128, 1, 1])       |     1340  |    331
      torch.Size([512, 1, 1])       |     1664  |    380
      torch.Size([1024, 1, 1])      |     2000  |    430
      torch.Size([1, 2, 2])         |     1200  |    300
      torch.Size([2, 2, 2])         |     1250  |    237
      torch.Size([4, 2, 2])         |      479  |    240
      torch.Size([8, 2, 2])         |      600  |    239
      torch.Size([16, 2, 2])        |     1300  |    242
      torch.Size([32, 2, 2])        |     1300  |    245
      torch.Size([64, 2, 2])        |     1300  |    260
      torch.Size([128, 2, 2])       |     1400  |    340
      torch.Size([512, 2, 2])       |     1680  |    380
      torch.Size([1024, 2, 2])      |     2100  |    430
      torch.Size([1, 8, 8])         |     1200  |    250
      torch.Size([2, 8, 8])         |     1240  |    238
      torch.Size([4, 8, 8])         |      480  |    240
      torch.Size([8, 8, 8])         |      600  |    240
      torch.Size([16, 8, 8])        |     1330  |    243
      torch.Size([32, 8, 8])        |     1340  |    250
      torch.Size([64, 8, 8])        |     1370  |    257
      torch.Size([128, 8, 8])       |     1400  |    280
      torch.Size([512, 8, 8])       |     1720  |    346
      torch.Size([1024, 8, 8])      |     2300  |    390
      torch.Size([1, 16, 16])       |     1380  |    245
      torch.Size([2, 16, 16])       |     1000  |    300
      torch.Size([4, 16, 16])       |      610  |    260
      torch.Size([8, 16, 16])       |      862  |    260
      torch.Size([16, 16, 16])      |     1350  |    260
      torch.Size([32, 16, 16])      |     1370  |    260
      torch.Size([64, 16, 16])      |     1440  |    273
      torch.Size([128, 16, 16])     |     1520  |    289
      torch.Size([512, 16, 16])     |     1880  |    350
      torch.Size([1024, 16, 16])    |     2540  |    530
      torch.Size([1, 32, 32])       |     1500  |    290
      torch.Size([2, 32, 32])       |     2100  |    287
      torch.Size([4, 32, 32])       |     1370  |    288
      torch.Size([8, 32, 32])       |     1389  |    290
      torch.Size([16, 32, 32])      |     1400  |    290
      torch.Size([32, 32, 32])      |     1500  |    476
      torch.Size([64, 32, 32])      |     1600  |    468
      torch.Size([128, 32, 32])     |     1700  |    479
      torch.Size([512, 32, 32])     |     2300  |    696
      torch.Size([1024, 32, 32])    |     3200  |   1200
      torch.Size([1, 64, 64])       |     1700  |    340
      torch.Size([2, 64, 64])       |     2800  |    353
      torch.Size([4, 64, 64])       |     1990  |    328
      torch.Size([8, 64, 64])       |     2040  |    330
      torch.Size([16, 64, 64])      |     2100  |    350
      torch.Size([32, 64, 64])      |     2300  |    680
      torch.Size([64, 64, 64])      |     2430  |    725
      torch.Size([128, 64, 64])     |     2600  |    845
      torch.Size([512, 64, 64])     |     4700  |   1900
      torch.Size([1024, 64, 64])    |     9200  |   4280
      torch.Size([1, 128, 128])     |     2300  |    497
      torch.Size([2, 128, 128])     |     4000  |    562
      torch.Size([4, 128, 128])     |     3140  |    669
      torch.Size([8, 128, 128])     |     3200  |    698
      torch.Size([16, 128, 128])    |     3400  |    810
      torch.Size([32, 128, 128])    |     3866  |   1410
      torch.Size([64, 128, 128])    |     4200  |   1670
      torch.Size([128, 128, 128])   |     5050  |   2170
      torch.Size([512, 128, 128])   |    14000  |   6417
      torch.Size([1024, 128, 128])  |    28900  |  14700
      torch.Size([1, 256, 256])     |     4100  |   1559
      torch.Size([2, 256, 256])     |     6800  |   1792
      torch.Size([4, 256, 256])     |     7000  |   2000
      torch.Size([8, 256, 256])     |     7300  |   2200
      torch.Size([16, 256, 256])    |     7730  |   2540
      torch.Size([32, 256, 256])    |     8500  |   3390
      torch.Size([64, 256, 256])    |    11000  |   4470
      torch.Size([128, 256, 256])   |    15900  |   6757
      torch.Size([512, 256, 256])   |    50000  |  30000
      torch.Size([1024, 256, 256])  |   102600  |  56400
      torch.Size([1, 512, 512])     |     8793  |   3230
      torch.Size([2, 512, 512])     |    13000  |   3920
      torch.Size([4, 512, 512])     |    14000  |   4531
      torch.Size([8, 512, 512])     |    15000  |   5114
      torch.Size([16, 512, 512])    |    16700  |   6280
      torch.Size([32, 512, 512])    |    22400  |   9530
      torch.Size([64, 512, 512])    |    33700  |  14260
      torch.Size([128, 512, 512])   |    56500  |  20000

Times are in microseconds (us).
```

</details>

<details>
<summary>
Benchmarking Script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.solve"
label = "master"
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda", requires_grad=True)

for n, batch in itertools.product(shapes, batches):
    if n == 512 and batch[0] >= 512:
        continue
    A = make_arg(batch + (n, n))
    B = make_arg(batch + (n, 16))
    ones = torch.ones(B.shape, device=B.device)
    print(A.shape)
    for adjoint in (True, False):
        timer = Timer("torch.linalg.solve(A, B).backward(gradient=ones, inputs=[A, B])",
                      globals=globals(),
                      label=benchmark_name,
                      description=label,
                      sub_label=f"{A.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}.pickle".format(label), 'wb') as f:
    pickle.dump(results, f)
```
</details>

See #72935 (comment) for the script to join the results.


[ghstack-poisoned]
This PR heavily simplifies the code of `linalg.solve`. At the same time,
this implementation saves quite a some copies of the input data in some
cases (e.g. A is contiguous)

We also implement it in such a way that the derivative goes from
computing two LU decompositions and two LU solves to no LU
decompositions and one LU solves. It also avoids a number of unnecessary
copies the derivative was unnecessarily performing (at least the copy of
two matrices).

On top of this, we add a `left` kw-only arg that allows the user to
solve `XA = B` rather concisely.

This PR also makes `torch.solve` an alias of `torch.linalg.solve`.

**Note:** This used to be the last PR of the stack. Now it's here because some tests were not passing in a PR that was before this one in the stack, and reshuffling the stack solved those problems. The benchmarks below are performed wrt the last PR fo this stack.

We compare the performance of `linalg.solve` against master (before merging #67833, but already with a few PRs of the LU stack merged). We get between **2.5x and 10x speed-ups in `linalg.solve`**.


<details>
<summary>
Benchmark Results
</summary>

```
[--------------------- linalg.solve + backward --------------------]
                                    |  master |  This PR
1 threads: ----------------------------------------------
      torch.Size([1, 1, 1])         |     1280  |    267
      torch.Size([2, 1, 1])         |     1300  |    200
      torch.Size([4, 1, 1])         |      500  |    231
      torch.Size([8, 1, 1])         |      600  |    232
      torch.Size([16, 1, 1])        |     1200  |    234
      torch.Size([32, 1, 1])        |     1300  |    239
      torch.Size([64, 1, 1])        |     1300  |    300
      torch.Size([128, 1, 1])       |     1340  |    331
      torch.Size([512, 1, 1])       |     1664  |    380
      torch.Size([1024, 1, 1])      |     2000  |    430
      torch.Size([1, 2, 2])         |     1200  |    300
      torch.Size([2, 2, 2])         |     1250  |    237
      torch.Size([4, 2, 2])         |      479  |    240
      torch.Size([8, 2, 2])         |      600  |    239
      torch.Size([16, 2, 2])        |     1300  |    242
      torch.Size([32, 2, 2])        |     1300  |    245
      torch.Size([64, 2, 2])        |     1300  |    260
      torch.Size([128, 2, 2])       |     1400  |    340
      torch.Size([512, 2, 2])       |     1680  |    380
      torch.Size([1024, 2, 2])      |     2100  |    430
      torch.Size([1, 8, 8])         |     1200  |    250
      torch.Size([2, 8, 8])         |     1240  |    238
      torch.Size([4, 8, 8])         |      480  |    240
      torch.Size([8, 8, 8])         |      600  |    240
      torch.Size([16, 8, 8])        |     1330  |    243
      torch.Size([32, 8, 8])        |     1340  |    250
      torch.Size([64, 8, 8])        |     1370  |    257
      torch.Size([128, 8, 8])       |     1400  |    280
      torch.Size([512, 8, 8])       |     1720  |    346
      torch.Size([1024, 8, 8])      |     2300  |    390
      torch.Size([1, 16, 16])       |     1380  |    245
      torch.Size([2, 16, 16])       |     1000  |    300
      torch.Size([4, 16, 16])       |      610  |    260
      torch.Size([8, 16, 16])       |      862  |    260
      torch.Size([16, 16, 16])      |     1350  |    260
      torch.Size([32, 16, 16])      |     1370  |    260
      torch.Size([64, 16, 16])      |     1440  |    273
      torch.Size([128, 16, 16])     |     1520  |    289
      torch.Size([512, 16, 16])     |     1880  |    350
      torch.Size([1024, 16, 16])    |     2540  |    530
      torch.Size([1, 32, 32])       |     1500  |    290
      torch.Size([2, 32, 32])       |     2100  |    287
      torch.Size([4, 32, 32])       |     1370  |    288
      torch.Size([8, 32, 32])       |     1389  |    290
      torch.Size([16, 32, 32])      |     1400  |    290
      torch.Size([32, 32, 32])      |     1500  |    476
      torch.Size([64, 32, 32])      |     1600  |    468
      torch.Size([128, 32, 32])     |     1700  |    479
      torch.Size([512, 32, 32])     |     2300  |    696
      torch.Size([1024, 32, 32])    |     3200  |   1200
      torch.Size([1, 64, 64])       |     1700  |    340
      torch.Size([2, 64, 64])       |     2800  |    353
      torch.Size([4, 64, 64])       |     1990  |    328
      torch.Size([8, 64, 64])       |     2040  |    330
      torch.Size([16, 64, 64])      |     2100  |    350
      torch.Size([32, 64, 64])      |     2300  |    680
      torch.Size([64, 64, 64])      |     2430  |    725
      torch.Size([128, 64, 64])     |     2600  |    845
      torch.Size([512, 64, 64])     |     4700  |   1900
      torch.Size([1024, 64, 64])    |     9200  |   4280
      torch.Size([1, 128, 128])     |     2300  |    497
      torch.Size([2, 128, 128])     |     4000  |    562
      torch.Size([4, 128, 128])     |     3140  |    669
      torch.Size([8, 128, 128])     |     3200  |    698
      torch.Size([16, 128, 128])    |     3400  |    810
      torch.Size([32, 128, 128])    |     3866  |   1410
      torch.Size([64, 128, 128])    |     4200  |   1670
      torch.Size([128, 128, 128])   |     5050  |   2170
      torch.Size([512, 128, 128])   |    14000  |   6417
      torch.Size([1024, 128, 128])  |    28900  |  14700
      torch.Size([1, 256, 256])     |     4100  |   1559
      torch.Size([2, 256, 256])     |     6800  |   1792
      torch.Size([4, 256, 256])     |     7000  |   2000
      torch.Size([8, 256, 256])     |     7300  |   2200
      torch.Size([16, 256, 256])    |     7730  |   2540
      torch.Size([32, 256, 256])    |     8500  |   3390
      torch.Size([64, 256, 256])    |    11000  |   4470
      torch.Size([128, 256, 256])   |    15900  |   6757
      torch.Size([512, 256, 256])   |    50000  |  30000
      torch.Size([1024, 256, 256])  |   102600  |  56400
      torch.Size([1, 512, 512])     |     8793  |   3230
      torch.Size([2, 512, 512])     |    13000  |   3920
      torch.Size([4, 512, 512])     |    14000  |   4531
      torch.Size([8, 512, 512])     |    15000  |   5114
      torch.Size([16, 512, 512])    |    16700  |   6280
      torch.Size([32, 512, 512])    |    22400  |   9530
      torch.Size([64, 512, 512])    |    33700  |  14260
      torch.Size([128, 512, 512])   |    56500  |  20000

Times are in microseconds (us).
```

</details>

<details>
<summary>
Benchmarking Script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.solve"
label = "master"
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda", requires_grad=True)

for n, batch in itertools.product(shapes, batches):
    if n == 512 and batch[0] >= 512:
        continue
    A = make_arg(batch + (n, n))
    B = make_arg(batch + (n, 16))
    ones = torch.ones(B.shape, device=B.device)
    print(A.shape)
    # Note: the timed statement does not depend on `adjoint`; the original
    # loop over (True, False) merely repeated each measurement, so time once.
    timer = Timer("torch.linalg.solve(A, B).backward(gradient=ones, inputs=[A, B])",
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"{A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}.pickle".format(label), 'wb') as f:
    pickle.dump(results, f)
```
</details>

See #72935 (comment) for the script to join the results.


[ghstack-poisoned]
lezcano added 2 commits May 17, 2022 15:55
Comment thread on `torch/return_types.py` (outdated)

Collaborator

@IvanYashchuk left a comment

There's an unintended update of the third_party/cudnn_frontend submodule. This change must be reverted.

Other than that it's good to be merged. Thank you for making linalg_solve structured and removing the torch.solve changes!

@lezcano lezcano requested a review from anjali411 as a code owner May 18, 2022 19:03
@lezcano
Collaborator Author

lezcano commented May 18, 2022

Thank you for the catch @IvanYashchuk!

<details>
<summary>
Benchmark Results
</summary>

```
[--------------------- linalg.solve + backward --------------------]
                                    |  master |  This PR
1 threads: ----------------------------------------------
      torch.Size([1, 1, 1])         |     1280  |    267
      torch.Size([2, 1, 1])         |     1300  |    200
      torch.Size([4, 1, 1])         |      500  |    231
      torch.Size([8, 1, 1])         |      600  |    232
      torch.Size([16, 1, 1])        |     1200  |    234
      torch.Size([32, 1, 1])        |     1300  |    239
      torch.Size([64, 1, 1])        |     1300  |    300
      torch.Size([128, 1, 1])       |     1340  |    331
      torch.Size([512, 1, 1])       |     1664  |    380
      torch.Size([1024, 1, 1])      |     2000  |    430
      torch.Size([1, 2, 2])         |     1200  |    300
      torch.Size([2, 2, 2])         |     1250  |    237
      torch.Size([4, 2, 2])         |      479  |    240
      torch.Size([8, 2, 2])         |      600  |    239
      torch.Size([16, 2, 2])        |     1300  |    242
      torch.Size([32, 2, 2])        |     1300  |    245
      torch.Size([64, 2, 2])        |     1300  |    260
      torch.Size([128, 2, 2])       |     1400  |    340
      torch.Size([512, 2, 2])       |     1680  |    380
      torch.Size([1024, 2, 2])      |     2100  |    430
      torch.Size([1, 8, 8])         |     1200  |    250
      torch.Size([2, 8, 8])         |     1240  |    238
      torch.Size([4, 8, 8])         |      480  |    240
      torch.Size([8, 8, 8])         |      600  |    240
      torch.Size([16, 8, 8])        |     1330  |    243
      torch.Size([32, 8, 8])        |     1340  |    250
      torch.Size([64, 8, 8])        |     1370  |    257
      torch.Size([128, 8, 8])       |     1400  |    280
      torch.Size([512, 8, 8])       |     1720  |    346
      torch.Size([1024, 8, 8])      |     2300  |    390
      torch.Size([1, 16, 16])       |     1380  |    245
      torch.Size([2, 16, 16])       |     1000  |    300
      torch.Size([4, 16, 16])       |      610  |    260
      torch.Size([8, 16, 16])       |      862  |    260
      torch.Size([16, 16, 16])      |     1350  |    260
      torch.Size([32, 16, 16])      |     1370  |    260
      torch.Size([64, 16, 16])      |     1440  |    273
      torch.Size([128, 16, 16])     |     1520  |    289
      torch.Size([512, 16, 16])     |     1880  |    350
      torch.Size([1024, 16, 16])    |     2540  |    530
      torch.Size([1, 32, 32])       |     1500  |    290
      torch.Size([2, 32, 32])       |     2100  |    287
      torch.Size([4, 32, 32])       |     1370  |    288
      torch.Size([8, 32, 32])       |     1389  |    290
      torch.Size([16, 32, 32])      |     1400  |    290
      torch.Size([32, 32, 32])      |     1500  |    476
      torch.Size([64, 32, 32])      |     1600  |    468
      torch.Size([128, 32, 32])     |     1700  |    479
      torch.Size([512, 32, 32])     |     2300  |    696
      torch.Size([1024, 32, 32])    |     3200  |   1200
      torch.Size([1, 64, 64])       |     1700  |    340
      torch.Size([2, 64, 64])       |     2800  |    353
      torch.Size([4, 64, 64])       |     1990  |    328
      torch.Size([8, 64, 64])       |     2040  |    330
      torch.Size([16, 64, 64])      |     2100  |    350
      torch.Size([32, 64, 64])      |     2300  |    680
      torch.Size([64, 64, 64])      |     2430  |    725
      torch.Size([128, 64, 64])     |     2600  |    845
      torch.Size([512, 64, 64])     |     4700  |   1900
      torch.Size([1024, 64, 64])    |     9200  |   4280
      torch.Size([1, 128, 128])     |     2300  |    497
      torch.Size([2, 128, 128])     |     4000  |    562
      torch.Size([4, 128, 128])     |     3140  |    669
      torch.Size([8, 128, 128])     |     3200  |    698
      torch.Size([16, 128, 128])    |     3400  |    810
      torch.Size([32, 128, 128])    |     3866  |   1410
      torch.Size([64, 128, 128])    |     4200  |   1670
      torch.Size([128, 128, 128])   |     5050  |   2170
      torch.Size([512, 128, 128])   |    14000  |   6417
      torch.Size([1024, 128, 128])  |    28900  |  14700
      torch.Size([1, 256, 256])     |     4100  |   1559
      torch.Size([2, 256, 256])     |     6800  |   1792
      torch.Size([4, 256, 256])     |     7000  |   2000
      torch.Size([8, 256, 256])     |     7300  |   2200
      torch.Size([16, 256, 256])    |     7730  |   2540
      torch.Size([32, 256, 256])    |     8500  |   3390
      torch.Size([64, 256, 256])    |    11000  |   4470
      torch.Size([128, 256, 256])   |    15900  |   6757
      torch.Size([512, 256, 256])   |    50000  |  30000
      torch.Size([1024, 256, 256])  |   102600  |  56400
      torch.Size([1, 512, 512])     |     8793  |   3230
      torch.Size([2, 512, 512])     |    13000  |   3920
      torch.Size([4, 512, 512])     |    14000  |   4531
      torch.Size([8, 512, 512])     |    15000  |   5114
      torch.Size([16, 512, 512])    |    16700  |   6280
      torch.Size([32, 512, 512])    |    22400  |   9530
      torch.Size([64, 512, 512])    |    33700  |  14260
      torch.Size([128, 512, 512])   |    56500  |  20000

Times are in microseconds (us).
```

</details>

<details>
<summary>
Benchmarking Script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.solve"
label = "master"
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda", requires_grad=True)

for n, batch in itertools.product(shapes, batches):
    if n == 512 and batch[0] >= 512:
        continue
    A = make_arg(batch + (n, n))
    B = make_arg(batch + (n, 16))
    ones = torch.ones(B.shape, device=B.device)
    print(A.shape)
    timer = Timer("torch.linalg.solve(A, B).backward(gradient=ones, inputs=[A, B])",
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"{A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}.pickle".format(label), 'wb') as f:
    pickle.dump(results, f)
```
</details>

See #72935 (comment) for the script to join the results.
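
The linked comment has the actual joining script; the idea can be sketched as follows (the file layout is an assumption: one `<label>.pickle` per run of the benchmark script above).

```python
import glob
import pickle

# Gather every pickled result list produced by runs of the benchmark
# script above (e.g. "master.pickle" plus one per PR -- hypothetical names).
results = []
for path in sorted(glob.glob("*.pickle")):
    with open(path, "rb") as f:
        results.extend(pickle.load(f))

# With torch available, the merged measurements print as one combined table,
# one column per `description` label:
#     from torch.utils.benchmark import Compare
#     compare = Compare(results)
#     compare.trim_significant_figures()
#     compare.print()
```

Because `Compare` groups rows by `label`/`sub_label` and columns by `description`, runs saved with different `label` strings (e.g. `"master"` vs `"this PR"`) line up side by side.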


[ghstack-poisoned]
@tianyikillua

Really hope that this could be merged for the 1.12 release 😃

@IvanYashchuk
Collaborator

> Really hope that this could be merged for the 1.12 release 😃

Hey Tianyi, I'm curious what kind of applications you are working on?
Unfortunately, it seems that this PR won't be included in the 1.12 release: the release candidate branch has already been cut (https://github.com/pytorch/pytorch/tree/release/1.12). However, if you really need this functionality, you can use it from the PyTorch nightly releases (see https://pytorch.org/get-started/locally/ for instructions).

lezcano added 6 commits May 25, 2022 16:30
This PR heavily simplifies the code of `linalg.solve`. At the same time,
this implementation saves quite a few copies of the input data in some
cases (e.g., when `A` is contiguous).

We also implement it in such a way that the derivative goes from
computing two LU decompositions and two LU solves to no LU
decompositions and one LU solve. It also avoids a number of unnecessary
copies that the derivative was performing (at least the copies of
two matrices).

On top of this, we add a `left` kw-only arg that allows the user to
solve `XA = B` rather concisely.

This PR also makes `torch.solve` an alias of `torch.linalg.solve`.

**Note:** This used to be the last PR of the stack. It is here now because some tests were failing in a PR earlier in the stack, and reshuffling the stack solved those problems. The benchmarks below are performed with respect to the last PR of this stack.

We compare the performance of `linalg.solve` against master (before merging #67833, but already with a few PRs of the LU stack merged). We see **2.5x to 10x speed-ups in `linalg.solve`**.


<details>
<summary>
Benchmark Results
</summary>

```
[--------------------- linalg.solve + backward --------------------]
                                    |  master |  This PR
1 threads: ----------------------------------------------
      torch.Size([1, 1, 1])         |     1280  |    267
      torch.Size([2, 1, 1])         |     1300  |    200
      torch.Size([4, 1, 1])         |      500  |    231
      torch.Size([8, 1, 1])         |      600  |    232
      torch.Size([16, 1, 1])        |     1200  |    234
      torch.Size([32, 1, 1])        |     1300  |    239
      torch.Size([64, 1, 1])        |     1300  |    300
      torch.Size([128, 1, 1])       |     1340  |    331
      torch.Size([512, 1, 1])       |     1664  |    380
      torch.Size([1024, 1, 1])      |     2000  |    430
      torch.Size([1, 2, 2])         |     1200  |    300
      torch.Size([2, 2, 2])         |     1250  |    237
      torch.Size([4, 2, 2])         |      479  |    240
      torch.Size([8, 2, 2])         |      600  |    239
      torch.Size([16, 2, 2])        |     1300  |    242
      torch.Size([32, 2, 2])        |     1300  |    245
      torch.Size([64, 2, 2])        |     1300  |    260
      torch.Size([128, 2, 2])       |     1400  |    340
      torch.Size([512, 2, 2])       |     1680  |    380
      torch.Size([1024, 2, 2])      |     2100  |    430
      torch.Size([1, 8, 8])         |     1200  |    250
      torch.Size([2, 8, 8])         |     1240  |    238
      torch.Size([4, 8, 8])         |      480  |    240
      torch.Size([8, 8, 8])         |      600  |    240
      torch.Size([16, 8, 8])        |     1330  |    243
      torch.Size([32, 8, 8])        |     1340  |    250
      torch.Size([64, 8, 8])        |     1370  |    257
      torch.Size([128, 8, 8])       |     1400  |    280
      torch.Size([512, 8, 8])       |     1720  |    346
      torch.Size([1024, 8, 8])      |     2300  |    390
      torch.Size([1, 16, 16])       |     1380  |    245
      torch.Size([2, 16, 16])       |     1000  |    300
      torch.Size([4, 16, 16])       |      610  |    260
      torch.Size([8, 16, 16])       |      862  |    260
      torch.Size([16, 16, 16])      |     1350  |    260
      torch.Size([32, 16, 16])      |     1370  |    260
      torch.Size([64, 16, 16])      |     1440  |    273
      torch.Size([128, 16, 16])     |     1520  |    289
      torch.Size([512, 16, 16])     |     1880  |    350
      torch.Size([1024, 16, 16])    |     2540  |    530
      torch.Size([1, 32, 32])       |     1500  |    290
      torch.Size([2, 32, 32])       |     2100  |    287
      torch.Size([4, 32, 32])       |     1370  |    288
      torch.Size([8, 32, 32])       |     1389  |    290
      torch.Size([16, 32, 32])      |     1400  |    290
      torch.Size([32, 32, 32])      |     1500  |    476
      torch.Size([64, 32, 32])      |     1600  |    468
      torch.Size([128, 32, 32])     |     1700  |    479
      torch.Size([512, 32, 32])     |     2300  |    696
      torch.Size([1024, 32, 32])    |     3200  |   1200
      torch.Size([1, 64, 64])       |     1700  |    340
      torch.Size([2, 64, 64])       |     2800  |    353
      torch.Size([4, 64, 64])       |     1990  |    328
      torch.Size([8, 64, 64])       |     2040  |    330
      torch.Size([16, 64, 64])      |     2100  |    350
      torch.Size([32, 64, 64])      |     2300  |    680
      torch.Size([64, 64, 64])      |     2430  |    725
      torch.Size([128, 64, 64])     |     2600  |    845
      torch.Size([512, 64, 64])     |     4700  |   1900
      torch.Size([1024, 64, 64])    |     9200  |   4280
      torch.Size([1, 128, 128])     |     2300  |    497
      torch.Size([2, 128, 128])     |     4000  |    562
      torch.Size([4, 128, 128])     |     3140  |    669
      torch.Size([8, 128, 128])     |     3200  |    698
      torch.Size([16, 128, 128])    |     3400  |    810
      torch.Size([32, 128, 128])    |     3866  |   1410
      torch.Size([64, 128, 128])    |     4200  |   1670
      torch.Size([128, 128, 128])   |     5050  |   2170
      torch.Size([512, 128, 128])   |    14000  |   6417
      torch.Size([1024, 128, 128])  |    28900  |  14700
      torch.Size([1, 256, 256])     |     4100  |   1559
      torch.Size([2, 256, 256])     |     6800  |   1792
      torch.Size([4, 256, 256])     |     7000  |   2000
      torch.Size([8, 256, 256])     |     7300  |   2200
      torch.Size([16, 256, 256])    |     7730  |   2540
      torch.Size([32, 256, 256])    |     8500  |   3390
      torch.Size([64, 256, 256])    |    11000  |   4470
      torch.Size([128, 256, 256])   |    15900  |   6757
      torch.Size([512, 256, 256])   |    50000  |  30000
      torch.Size([1024, 256, 256])  |   102600  |  56400
      torch.Size([1, 512, 512])     |     8793  |   3230
      torch.Size([2, 512, 512])     |    13000  |   3920
      torch.Size([4, 512, 512])     |    14000  |   4531
      torch.Size([8, 512, 512])     |    15000  |   5114
      torch.Size([16, 512, 512])    |    16700  |   6280
      torch.Size([32, 512, 512])    |    22400  |   9530
      torch.Size([64, 512, 512])    |    33700  |  14260
      torch.Size([128, 512, 512])   |    56500  |  20000

Times are in microseconds (us).
```

</details>

<details>
<summary>
Benchmarking Script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.solve"
label = "master"
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda", requires_grad=True)

for n, batch in itertools.product(shapes, batches):
    if n == 512 and batch[0] >= 512:
        continue
    A = make_arg(batch + (n, n))
    B = make_arg(batch + (n, 16))
    ones = torch.ones(B.shape, device=B.device)
    print(A.shape)
    # NOTE: `adjoint` is unused in the loop body, so each shape is timed twice.
    for adjoint in (True, False):
        timer = Timer("torch.linalg.solve(A, B).backward(gradient=ones, inputs=[A, B])",
                      globals=globals(),
                      label=benchmark_name,
                      description=label,
                      sub_label=f"{A.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}.pickle".format(label), 'wb') as f:
    pickle.dump(results, f)
```
</details>

See #72935 (comment) for the script to join the results.


[ghstack-poisoned]
@lezcano lezcano added ciflow/trunk Trigger trunk jobs on your pull request ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR labels Jun 10, 2022
Collaborator

@mruberry mruberry left a comment


Stamped!

@facebook-github-bot facebook-github-bot deleted the gh/Lezcano/54/head branch June 14, 2022 14:16
facebook-github-bot pushed a commit that referenced this pull request Jun 14, 2022
Summary:
This PR heavily simplifies the code of `linalg.solve`. At the same time,
this implementation saves quite a few copies of the input data in some
cases (e.g. A is contiguous)

We also implement it in such a way that the derivative goes from
computing two LU decompositions and two LU solves to no LU
decompositions and one LU solve. It also avoids a number of unnecessary
copies the derivative was performing (at least the copy of
two matrices).

On top of this, we add a `left` kw-only arg that allows the user to
solve `XA = B` rather concisely.

Pull Request resolved: #74046

Approved by: https://github.com/nikitaved, https://github.com/IvanYashchuk, https://github.com/mruberry

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/54949a5abc9890143de4b5dd2f13ff98446376a3

Reviewed By: osalpekar

Differential Revision: D37089130

Pulled By: osalpekar

fbshipit-source-id: ca444fe7127bb3faf1de717100a4ad21e4d0f681
@astroboylrx

I think this PR introduced some possibly unintended errors in extreme cases (see the example and discussion in #90453).

I understand that errors in the correctness of the solution are expected for an ill-conditioned matrix. However, one still expects the numerical residual, e.g. `(A @ X - B) / B`, to be close to machine precision (as torch prior to this PR, NumPy, and CuPy would produce).
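The check described above can be sketched with NumPy (the tolerance below is my own loose choice; a backward-stable solver on a well-conditioned matrix typically lands much closer to machine epsilon):

```python
import numpy as np

# Relative residual of a linear solve: for a well-conditioned A and a
# backward-stable solver this should be near float64 machine precision.
rng = np.random.default_rng(1)
A = rng.standard_normal((64, 64))
B = rng.standard_normal((64, 8))
X = np.linalg.solve(A, B)

rel_res = np.linalg.norm(A @ X - B) / np.linalg.norm(B)
assert rel_res < 1e-10  # loose bound; typical values are around 1e-14
```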

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026

Pull Request resolved: pytorch#74046

Approved by: https://github.com/nikitaved, https://github.com/IvanYashchuk, https://github.com/mruberry

Labels

ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR ciflow/trunk Trigger trunk jobs on your pull request cla signed module: derivatives Related to derivatives of operators module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul open source release notes: linalg_frontend release notes category topic: performance topic category
