Simplify and optimize linalg.solve #74046
lezcano wants to merge 43 commits into gh/Lezcano/54/base
Conversation
This PR heavily simplifies the code of `linalg.solve`. At the same time, this implementation saves quite a few copies of the input data in some cases (e.g. when A is contiguous). We also implement it in such a way that the derivative goes from computing two LU decompositions and two LU solves to no LU decompositions and one LU solve. It also avoids a number of copies that the derivative was unnecessarily performing (at least the copy of two matrices). On top of this, we add a `left` kw-only arg that allows the user to solve `XA = B` rather than `AX = B` concisely.

[ghstack-poisoned]
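A minimal usage sketch of the `left` kw-only argument described above (this is an illustration only, assuming a PyTorch build where this change has landed):

```python
# Minimal usage sketch of the `left` kw-only argument described above
# (assumes a PyTorch build where this change has landed).
import torch

A = torch.randn(5, 5, dtype=torch.float64)
B = torch.randn(3, 5, dtype=torch.float64)

# Solve X A = B directly, without transposing by hand.
X = torch.linalg.solve(A, B, left=False)
print(torch.allclose(X @ A, B))  # expected: True (up to numerical error)

# Equivalent formulation without the keyword: solve A^T Y = B^T, then X = Y^T.
X_ref = torch.linalg.solve(A.mT, B.mT).mT
print(torch.allclose(X, X_ref))  # expected: True
```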
❌ 1 new failure as of commit fd02ba5 (more details on the Dr. CI page).
🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
Leaving @albanD to have a look at the derivative, to see if it can be implemented any better. In particular, at the moment I need to save both the …
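As a reference for the formulas being discussed, here is a small Python sketch (not the PR's actual implementation) checking the standard adjoint of `solve`, which needs only a single additional solve against `A^H` plus one matmul in the backward:

```python
# Numerical check (float64) of the closed-form backward of solve:
# given X = A^{-1} B,  grad_B = A^{-H} grad_X  and  grad_A = -grad_B X^H.
# This is only a sketch of the math, not the PR's implementation.
import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
B = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
G = torch.randn(4, 3, dtype=torch.float64)  # incoming gradient w.r.t. X

X = torch.linalg.solve(A, B)
X.backward(G)

with torch.no_grad():
    grad_B = torch.linalg.solve(A.mH, G)  # the single solve in the backward
    grad_A = -grad_B @ X.mH

print(torch.allclose(B.grad, grad_B))  # expected: True
print(torch.allclose(A.grad, grad_A))  # expected: True
```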
This PR heavily simplifies the code of `linalg.solve`. At the same time, this implementation saves quite a few copies of the input data in some cases (e.g. when A is contiguous). We also implement it in such a way that the derivative goes from computing two LU decompositions and two LU solves to no LU decompositions and one LU solve. It also avoids a number of copies that the derivative was unnecessarily performing (at least the copy of two matrices). On top of this, we add a `left` kw-only arg that allows the user to solve `XA = B` rather than `AX = B` concisely.

This PR also makes `torch.solve` an alias of `torch.linalg.solve`.

**Note:** This used to be the last PR of the stack. It is here now because some tests were not passing in a PR that came before this one in the stack, and reshuffling the stack solved those problems.

The benchmarks below are performed with respect to the last PR of this stack. We compare the performance of `linalg.solve` against master (before merging #67833, but already with a few PRs of the LU stack merged). We see that we got between **x2.5 and x10 speed-ups in `linalg.solve`**.

<details>
<summary> Benchmark Results </summary>

```
[--------------------- linalg.solve + backward --------------------]
                             |  master  |  This PR
1 threads: ----------------------------------------------
torch.Size([1, 1, 1])        |   1280   |    267
torch.Size([2, 1, 1])        |   1300   |    200
torch.Size([4, 1, 1])        |    500   |    231
torch.Size([8, 1, 1])        |    600   |    232
torch.Size([16, 1, 1])       |   1200   |    234
torch.Size([32, 1, 1])       |   1300   |    239
torch.Size([64, 1, 1])       |   1300   |    300
torch.Size([128, 1, 1])      |   1340   |    331
torch.Size([512, 1, 1])      |   1664   |    380
torch.Size([1024, 1, 1])     |   2000   |    430
torch.Size([1, 2, 2])        |   1200   |    300
torch.Size([2, 2, 2])        |   1250   |    237
torch.Size([4, 2, 2])        |    479   |    240
torch.Size([8, 2, 2])        |    600   |    239
torch.Size([16, 2, 2])       |   1300   |    242
torch.Size([32, 2, 2])       |   1300   |    245
torch.Size([64, 2, 2])       |   1300   |    260
torch.Size([128, 2, 2])      |   1400   |    340
torch.Size([512, 2, 2])      |   1680   |    380
torch.Size([1024, 2, 2])     |   2100   |    430
torch.Size([1, 8, 8])        |   1200   |    250
torch.Size([2, 8, 8])        |   1240   |    238
torch.Size([4, 8, 8])        |    480   |    240
torch.Size([8, 8, 8])        |    600   |    240
torch.Size([16, 8, 8])       |   1330   |    243
torch.Size([32, 8, 8])       |   1340   |    250
torch.Size([64, 8, 8])       |   1370   |    257
torch.Size([128, 8, 8])      |   1400   |    280
torch.Size([512, 8, 8])      |   1720   |    346
torch.Size([1024, 8, 8])     |   2300   |    390
torch.Size([1, 16, 16])      |   1380   |    245
torch.Size([2, 16, 16])      |   1000   |    300
torch.Size([4, 16, 16])      |    610   |    260
torch.Size([8, 16, 16])      |    862   |    260
torch.Size([16, 16, 16])     |   1350   |    260
torch.Size([32, 16, 16])     |   1370   |    260
torch.Size([64, 16, 16])     |   1440   |    273
torch.Size([128, 16, 16])    |   1520   |    289
torch.Size([512, 16, 16])    |   1880   |    350
torch.Size([1024, 16, 16])   |   2540   |    530
torch.Size([1, 32, 32])      |   1500   |    290
torch.Size([2, 32, 32])      |   2100   |    287
torch.Size([4, 32, 32])      |   1370   |    288
torch.Size([8, 32, 32])      |   1389   |    290
torch.Size([16, 32, 32])     |   1400   |    290
torch.Size([32, 32, 32])     |   1500   |    476
torch.Size([64, 32, 32])     |   1600   |    468
torch.Size([128, 32, 32])    |   1700   |    479
torch.Size([512, 32, 32])    |   2300   |    696
torch.Size([1024, 32, 32])   |   3200   |   1200
torch.Size([1, 64, 64])      |   1700   |    340
torch.Size([2, 64, 64])      |   2800   |    353
torch.Size([4, 64, 64])      |   1990   |    328
torch.Size([8, 64, 64])      |   2040   |    330
torch.Size([16, 64, 64])     |   2100   |    350
torch.Size([32, 64, 64])     |   2300   |    680
torch.Size([64, 64, 64])     |   2430   |    725
torch.Size([128, 64, 64])    |   2600   |    845
torch.Size([512, 64, 64])    |   4700   |   1900
torch.Size([1024, 64, 64])   |   9200   |   4280
torch.Size([1, 128, 128])    |   2300   |    497
torch.Size([2, 128, 128])    |   4000   |    562
torch.Size([4, 128, 128])    |   3140   |    669
torch.Size([8, 128, 128])    |   3200   |    698
torch.Size([16, 128, 128])   |   3400   |    810
torch.Size([32, 128, 128])   |   3866   |   1410
torch.Size([64, 128, 128])   |   4200   |   1670
torch.Size([128, 128, 128])  |   5050   |   2170
torch.Size([512, 128, 128])  |  14000   |   6417
torch.Size([1024, 128, 128]) |  28900   |  14700
torch.Size([1, 256, 256])    |   4100   |   1559
torch.Size([2, 256, 256])    |   6800   |   1792
torch.Size([4, 256, 256])    |   7000   |   2000
torch.Size([8, 256, 256])    |   7300   |   2200
torch.Size([16, 256, 256])   |   7730   |   2540
torch.Size([32, 256, 256])   |   8500   |   3390
torch.Size([64, 256, 256])   |  11000   |   4470
torch.Size([128, 256, 256])  |  15900   |   6757
torch.Size([512, 256, 256])  |  50000   |  30000
torch.Size([1024, 256, 256]) | 102600   |  56400
torch.Size([1, 512, 512])    |   8793   |   3230
torch.Size([2, 512, 512])    |  13000   |   3920
torch.Size([4, 512, 512])    |  14000   |   4531
torch.Size([8, 512, 512])    |  15000   |   5114
torch.Size([16, 512, 512])   |  16700   |   6280
torch.Size([32, 512, 512])   |  22400   |   9530
torch.Size([64, 512, 512])   |  33700   |  14260
torch.Size([128, 512, 512])  |  56500   |  20000

Times are in microseconds (us).
```

</details>

<details>
<summary> Benchmarking Script </summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.solve"
label = "master"

shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda", requires_grad=True)

for n, batch in itertools.product(shapes, batches):
    if n == 512 and batch[0] >= 512:
        continue
    A = make_arg(batch + (n, n))
    B = make_arg(batch + (n, 16))
    ones = torch.ones(B.shape, device=B.device)
    print(A.shape)
    for adjoint in (True, False):
        timer = Timer("torch.linalg.solve(A, B).backward(gradient=ones, inputs=[A, B])",
                      globals=globals(),
                      label=benchmark_name,
                      description=label,
                      sub_label=f"{A.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}.pickle".format(label), 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
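For completeness, a possible sketch of the kind of script used to join the pickled results of the two runs into one table (the actual join script lives in #72935; the labels and filenames below are assumptions based on the benchmarking script above):

```python
# Possible sketch for joining the pickled results of two benchmark runs
# (labels/filenames are assumptions; see #72935 for the actual script).
import pickle
from torch.utils.benchmark import Compare

results = []
for label in ("master", "This PR"):  # one pickle per run, saved by the script above
    with open("{}.pickle".format(label), "rb") as f:
        results.extend(pickle.load(f))

compare = Compare(results)
compare.trim_significant_figures()
compare.print()
```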
IvanYashchuk left a comment:
There's an unintended update of the third_party/cudnn_frontend submodule. This change must be reverted.
Other than that, it's good to be merged. Thank you for making linalg_solve structured and removing the torch.solve changes!
Thank you for the catch @IvanYashchuk!
Really hope that this could be merged for the 1.12 release 😃
Hey Tianyi, I'm curious: what kind of applications are you working on?
This PR heavily simplifies the code of `linalg.solve`. At the same time, this implementation saves quite a some copies of the input data in some cases (e.g. A is contiguous) We also implement it in such a way that the derivative goes from computing two LU decompositions and two LU solves to no LU decompositions and one LU solves. It also avoids a number of unnecessary copies the derivative was unnecessarily performing (at least the copy of two matrices). On top of this, we add a `left` kw-only arg that allows the user to solve `XA = B` rather concisely. This PR also makes `torch.solve` an alias of `torch.linalg.solve`. **Note:** This used to be the last PR of the stack. Now it's here because some tests were not passing in a PR that was before this one in the stack, and reshuffling the stack solved those problems. The benchmarks below are performed wrt the last PR fo this stack. We compare the performance of `linalg.solve` against master (before merging #67833, but already with a few PRs of the LU stack merged). We see that we got between **x2.5 and x10 speed-ups in `linalg.solve`**. <details> <summary> Benchmark Results </summary> ``` [--------------------- linalg.solve + backward --------------------] | master | This PR 1 threads: ---------------------------------------------- torch.Size([1, 1, 1]) | 1280 | 267 torch.Size([2, 1, 1]) | 1300 | 200 torch.Size([4, 1, 1]) | 500 | 231 torch.Size([8, 1, 1]) | 600 | 232 torch.Size([16, 1, 1]) | 1200 | 234 torch.Size([32, 1, 1]) | 1300 | 239 torch.Size([64, 1, 1]) | 1300 | 300 torch.Size([128, 1, 1]) | 1340 | 331 torch.Size([512, 1, 1]) | 1664 | 380 torch.Size([1024, 1, 1]) | 2000 | 430 torch.Size([1, 2, 2]) | 1200 | 300 torch.Size([2, 2, 2]) | 1250 | 237 torch.Size([4, 2, 2]) | 479 | 240 torch.Size([8, 2, 2]) | 600 | 239 torch.Size([16, 2, 2]) | 1300 | 242 torch.Size([32, 2, 2]) | 1300 | 245 torch.Size([64, 2, 2]) | 1300 | 260 torch.Size([128, 2, 2]) | 1400 | 340 torch.Size([512, 2, 2]) | 1680 | 380 torch.Size([1024, 2, 2]) | 2100 | 430 torch.Size([1, 8, 8]) | 1200 | 250 torch.Size([2, 8, 8]) | 1240 | 238 torch.Size([4, 8, 8]) | 480 | 240 torch.Size([8, 8, 8]) | 600 | 240 torch.Size([16, 8, 8]) | 1330 | 243 torch.Size([32, 8, 8]) | 1340 | 250 torch.Size([64, 8, 8]) | 1370 | 257 torch.Size([128, 8, 8]) | 1400 | 280 torch.Size([512, 8, 8]) | 1720 | 346 torch.Size([1024, 8, 8]) | 2300 | 390 torch.Size([1, 16, 16]) | 1380 | 245 torch.Size([2, 16, 16]) | 1000 | 300 torch.Size([4, 16, 16]) | 610 | 260 torch.Size([8, 16, 16]) | 862 | 260 torch.Size([16, 16, 16]) | 1350 | 260 torch.Size([32, 16, 16]) | 1370 | 260 torch.Size([64, 16, 16]) | 1440 | 273 torch.Size([128, 16, 16]) | 1520 | 289 torch.Size([512, 16, 16]) | 1880 | 350 torch.Size([1024, 16, 16]) | 2540 | 530 torch.Size([1, 32, 32]) | 1500 | 290 torch.Size([2, 32, 32]) | 2100 | 287 torch.Size([4, 32, 32]) | 1370 | 288 torch.Size([8, 32, 32]) | 1389 | 290 torch.Size([16, 32, 32]) | 1400 | 290 torch.Size([32, 32, 32]) | 1500 | 476 torch.Size([64, 32, 32]) | 1600 | 468 torch.Size([128, 32, 32]) | 1700 | 479 torch.Size([512, 32, 32]) | 2300 | 696 torch.Size([1024, 32, 32]) | 3200 | 1200 torch.Size([1, 64, 64]) | 1700 | 340 torch.Size([2, 64, 64]) | 2800 | 353 torch.Size([4, 64, 64]) | 1990 | 328 torch.Size([8, 64, 64]) | 2040 | 330 torch.Size([16, 64, 64]) | 2100 | 350 torch.Size([32, 64, 64]) | 2300 | 680 torch.Size([64, 64, 64]) | 2430 | 725 torch.Size([128, 64, 64]) | 2600 | 845 torch.Size([512, 64, 64]) | 4700 | 1900 torch.Size([1024, 64, 64]) | 9200 | 4280 torch.Size([1, 128, 128]) | 2300 | 497 
torch.Size([2, 128, 128]) | 4000 | 562 torch.Size([4, 128, 128]) | 3140 | 669 torch.Size([8, 128, 128]) | 3200 | 698 torch.Size([16, 128, 128]) | 3400 | 810 torch.Size([32, 128, 128]) | 3866 | 1410 torch.Size([64, 128, 128]) | 4200 | 1670 torch.Size([128, 128, 128]) | 5050 | 2170 torch.Size([512, 128, 128]) | 14000 | 6417 torch.Size([1024, 128, 128]) | 28900 | 14700 torch.Size([1, 256, 256]) | 4100 | 1559 torch.Size([2, 256, 256]) | 6800 | 1792 torch.Size([4, 256, 256]) | 7000 | 2000 torch.Size([8, 256, 256]) | 7300 | 2200 torch.Size([16, 256, 256]) | 7730 | 2540 torch.Size([32, 256, 256]) | 8500 | 3390 torch.Size([64, 256, 256]) | 11000 | 4470 torch.Size([128, 256, 256]) | 15900 | 6757 torch.Size([512, 256, 256]) | 50000 | 30000 torch.Size([1024, 256, 256]) | 102600 | 56400 torch.Size([1, 512, 512]) | 8793 | 3230 torch.Size([2, 512, 512]) | 13000 | 3920 torch.Size([4, 512, 512]) | 14000 | 4531 torch.Size([8, 512, 512]) | 15000 | 5114 torch.Size([16, 512, 512]) | 16700 | 6280 torch.Size([32, 512, 512]) | 22400 | 9530 torch.Size([64, 512, 512]) | 33700 | 14260 torch.Size([128, 512, 512]) | 56500 | 20000 Times are in microseconds (us). ``` </details> <details> <summary> Benchmarking Script </summary> ```python import torch import pickle import itertools from functools import partial from torch.utils.benchmark import Timer, Compare benchmark_name = "linalg.solve" label = "master" shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512] batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)] results = [] make_arg = partial(torch.randn, dtype=torch.float32, device="cuda", requires_grad=True) for n, batch in itertools.product(shapes, batches): if n == 512 and batch[0] >= 512: continue A = make_arg(batch + (n, n)) B = make_arg(batch + (n, 16)) ones = torch.ones(B.shape, device=B.device) print(A.shape) for adjoint in (True, False): timer = Timer("torch.linalg.solve(A, B).backward(gradient=ones, inputs=[A, B])", globals=globals(), label=benchmark_name, description=label, sub_label=f"{A.shape}", num_threads=1) results.append(timer.blocked_autorange()) compare = Compare(results) compare.trim_significant_figures() compare.print() with open("{}.pickle".format(label), 'wb') as f: pickle.dump(results, f) ``` </details> See #72935 (comment) for the script to join the results. [ghstack-poisoned]
This PR heavily simplifies the code of `linalg.solve`. At the same time, this implementation saves quite a some copies of the input data in some cases (e.g. A is contiguous) We also implement it in such a way that the derivative goes from computing two LU decompositions and two LU solves to no LU decompositions and one LU solves. It also avoids a number of unnecessary copies the derivative was unnecessarily performing (at least the copy of two matrices). On top of this, we add a `left` kw-only arg that allows the user to solve `XA = B` rather concisely. This PR also makes `torch.solve` an alias of `torch.linalg.solve`. **Note:** This used to be the last PR of the stack. Now it's here because some tests were not passing in a PR that was before this one in the stack, and reshuffling the stack solved those problems. The benchmarks below are performed wrt the last PR fo this stack. We compare the performance of `linalg.solve` against master (before merging #67833, but already with a few PRs of the LU stack merged). We see that we got between **x2.5 and x10 speed-ups in `linalg.solve`**. <details> <summary> Benchmark Results </summary> ``` [--------------------- linalg.solve + backward --------------------] | master | This PR 1 threads: ---------------------------------------------- torch.Size([1, 1, 1]) | 1280 | 267 torch.Size([2, 1, 1]) | 1300 | 200 torch.Size([4, 1, 1]) | 500 | 231 torch.Size([8, 1, 1]) | 600 | 232 torch.Size([16, 1, 1]) | 1200 | 234 torch.Size([32, 1, 1]) | 1300 | 239 torch.Size([64, 1, 1]) | 1300 | 300 torch.Size([128, 1, 1]) | 1340 | 331 torch.Size([512, 1, 1]) | 1664 | 380 torch.Size([1024, 1, 1]) | 2000 | 430 torch.Size([1, 2, 2]) | 1200 | 300 torch.Size([2, 2, 2]) | 1250 | 237 torch.Size([4, 2, 2]) | 479 | 240 torch.Size([8, 2, 2]) | 600 | 239 torch.Size([16, 2, 2]) | 1300 | 242 torch.Size([32, 2, 2]) | 1300 | 245 torch.Size([64, 2, 2]) | 1300 | 260 torch.Size([128, 2, 2]) | 1400 | 340 torch.Size([512, 2, 2]) | 1680 | 380 torch.Size([1024, 2, 2]) | 2100 | 430 torch.Size([1, 8, 8]) | 1200 | 250 torch.Size([2, 8, 8]) | 1240 | 238 torch.Size([4, 8, 8]) | 480 | 240 torch.Size([8, 8, 8]) | 600 | 240 torch.Size([16, 8, 8]) | 1330 | 243 torch.Size([32, 8, 8]) | 1340 | 250 torch.Size([64, 8, 8]) | 1370 | 257 torch.Size([128, 8, 8]) | 1400 | 280 torch.Size([512, 8, 8]) | 1720 | 346 torch.Size([1024, 8, 8]) | 2300 | 390 torch.Size([1, 16, 16]) | 1380 | 245 torch.Size([2, 16, 16]) | 1000 | 300 torch.Size([4, 16, 16]) | 610 | 260 torch.Size([8, 16, 16]) | 862 | 260 torch.Size([16, 16, 16]) | 1350 | 260 torch.Size([32, 16, 16]) | 1370 | 260 torch.Size([64, 16, 16]) | 1440 | 273 torch.Size([128, 16, 16]) | 1520 | 289 torch.Size([512, 16, 16]) | 1880 | 350 torch.Size([1024, 16, 16]) | 2540 | 530 torch.Size([1, 32, 32]) | 1500 | 290 torch.Size([2, 32, 32]) | 2100 | 287 torch.Size([4, 32, 32]) | 1370 | 288 torch.Size([8, 32, 32]) | 1389 | 290 torch.Size([16, 32, 32]) | 1400 | 290 torch.Size([32, 32, 32]) | 1500 | 476 torch.Size([64, 32, 32]) | 1600 | 468 torch.Size([128, 32, 32]) | 1700 | 479 torch.Size([512, 32, 32]) | 2300 | 696 torch.Size([1024, 32, 32]) | 3200 | 1200 torch.Size([1, 64, 64]) | 1700 | 340 torch.Size([2, 64, 64]) | 2800 | 353 torch.Size([4, 64, 64]) | 1990 | 328 torch.Size([8, 64, 64]) | 2040 | 330 torch.Size([16, 64, 64]) | 2100 | 350 torch.Size([32, 64, 64]) | 2300 | 680 torch.Size([64, 64, 64]) | 2430 | 725 torch.Size([128, 64, 64]) | 2600 | 845 torch.Size([512, 64, 64]) | 4700 | 1900 torch.Size([1024, 64, 64]) | 9200 | 4280 torch.Size([1, 128, 128]) | 2300 | 497 
torch.Size([2, 128, 128]) | 4000 | 562 torch.Size([4, 128, 128]) | 3140 | 669 torch.Size([8, 128, 128]) | 3200 | 698 torch.Size([16, 128, 128]) | 3400 | 810 torch.Size([32, 128, 128]) | 3866 | 1410 torch.Size([64, 128, 128]) | 4200 | 1670 torch.Size([128, 128, 128]) | 5050 | 2170 torch.Size([512, 128, 128]) | 14000 | 6417 torch.Size([1024, 128, 128]) | 28900 | 14700 torch.Size([1, 256, 256]) | 4100 | 1559 torch.Size([2, 256, 256]) | 6800 | 1792 torch.Size([4, 256, 256]) | 7000 | 2000 torch.Size([8, 256, 256]) | 7300 | 2200 torch.Size([16, 256, 256]) | 7730 | 2540 torch.Size([32, 256, 256]) | 8500 | 3390 torch.Size([64, 256, 256]) | 11000 | 4470 torch.Size([128, 256, 256]) | 15900 | 6757 torch.Size([512, 256, 256]) | 50000 | 30000 torch.Size([1024, 256, 256]) | 102600 | 56400 torch.Size([1, 512, 512]) | 8793 | 3230 torch.Size([2, 512, 512]) | 13000 | 3920 torch.Size([4, 512, 512]) | 14000 | 4531 torch.Size([8, 512, 512]) | 15000 | 5114 torch.Size([16, 512, 512]) | 16700 | 6280 torch.Size([32, 512, 512]) | 22400 | 9530 torch.Size([64, 512, 512]) | 33700 | 14260 torch.Size([128, 512, 512]) | 56500 | 20000 Times are in microseconds (us). ``` </details> <details> <summary> Benchmarking Script </summary> ```python import torch import pickle import itertools from functools import partial from torch.utils.benchmark import Timer, Compare benchmark_name = "linalg.solve" label = "master" shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512] batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)] results = [] make_arg = partial(torch.randn, dtype=torch.float32, device="cuda", requires_grad=True) for n, batch in itertools.product(shapes, batches): if n == 512 and batch[0] >= 512: continue A = make_arg(batch + (n, n)) B = make_arg(batch + (n, 16)) ones = torch.ones(B.shape, device=B.device) print(A.shape) for adjoint in (True, False): timer = Timer("torch.linalg.solve(A, B).backward(gradient=ones, inputs=[A, B])", globals=globals(), label=benchmark_name, description=label, sub_label=f"{A.shape}", num_threads=1) results.append(timer.blocked_autorange()) compare = Compare(results) compare.trim_significant_figures() compare.print() with open("{}.pickle".format(label), 'wb') as f: pickle.dump(results, f) ``` </details> See #72935 (comment) for the script to join the results. [ghstack-poisoned]
This PR heavily simplifies the code of `linalg.solve`. At the same time, this implementation saves quite a some copies of the input data in some cases (e.g. A is contiguous) We also implement it in such a way that the derivative goes from computing two LU decompositions and two LU solves to no LU decompositions and one LU solves. It also avoids a number of unnecessary copies the derivative was unnecessarily performing (at least the copy of two matrices). On top of this, we add a `left` kw-only arg that allows the user to solve `XA = B` rather concisely. This PR also makes `torch.solve` an alias of `torch.linalg.solve`. **Note:** This used to be the last PR of the stack. Now it's here because some tests were not passing in a PR that was before this one in the stack, and reshuffling the stack solved those problems. The benchmarks below are performed wrt the last PR fo this stack. We compare the performance of `linalg.solve` against master (before merging #67833, but already with a few PRs of the LU stack merged). We see that we got between **x2.5 and x10 speed-ups in `linalg.solve`**. <details> <summary> Benchmark Results </summary> ``` [--------------------- linalg.solve + backward --------------------] | master | This PR 1 threads: ---------------------------------------------- torch.Size([1, 1, 1]) | 1280 | 267 torch.Size([2, 1, 1]) | 1300 | 200 torch.Size([4, 1, 1]) | 500 | 231 torch.Size([8, 1, 1]) | 600 | 232 torch.Size([16, 1, 1]) | 1200 | 234 torch.Size([32, 1, 1]) | 1300 | 239 torch.Size([64, 1, 1]) | 1300 | 300 torch.Size([128, 1, 1]) | 1340 | 331 torch.Size([512, 1, 1]) | 1664 | 380 torch.Size([1024, 1, 1]) | 2000 | 430 torch.Size([1, 2, 2]) | 1200 | 300 torch.Size([2, 2, 2]) | 1250 | 237 torch.Size([4, 2, 2]) | 479 | 240 torch.Size([8, 2, 2]) | 600 | 239 torch.Size([16, 2, 2]) | 1300 | 242 torch.Size([32, 2, 2]) | 1300 | 245 torch.Size([64, 2, 2]) | 1300 | 260 torch.Size([128, 2, 2]) | 1400 | 340 torch.Size([512, 2, 2]) | 1680 | 380 torch.Size([1024, 2, 2]) | 2100 | 430 torch.Size([1, 8, 8]) | 1200 | 250 torch.Size([2, 8, 8]) | 1240 | 238 torch.Size([4, 8, 8]) | 480 | 240 torch.Size([8, 8, 8]) | 600 | 240 torch.Size([16, 8, 8]) | 1330 | 243 torch.Size([32, 8, 8]) | 1340 | 250 torch.Size([64, 8, 8]) | 1370 | 257 torch.Size([128, 8, 8]) | 1400 | 280 torch.Size([512, 8, 8]) | 1720 | 346 torch.Size([1024, 8, 8]) | 2300 | 390 torch.Size([1, 16, 16]) | 1380 | 245 torch.Size([2, 16, 16]) | 1000 | 300 torch.Size([4, 16, 16]) | 610 | 260 torch.Size([8, 16, 16]) | 862 | 260 torch.Size([16, 16, 16]) | 1350 | 260 torch.Size([32, 16, 16]) | 1370 | 260 torch.Size([64, 16, 16]) | 1440 | 273 torch.Size([128, 16, 16]) | 1520 | 289 torch.Size([512, 16, 16]) | 1880 | 350 torch.Size([1024, 16, 16]) | 2540 | 530 torch.Size([1, 32, 32]) | 1500 | 290 torch.Size([2, 32, 32]) | 2100 | 287 torch.Size([4, 32, 32]) | 1370 | 288 torch.Size([8, 32, 32]) | 1389 | 290 torch.Size([16, 32, 32]) | 1400 | 290 torch.Size([32, 32, 32]) | 1500 | 476 torch.Size([64, 32, 32]) | 1600 | 468 torch.Size([128, 32, 32]) | 1700 | 479 torch.Size([512, 32, 32]) | 2300 | 696 torch.Size([1024, 32, 32]) | 3200 | 1200 torch.Size([1, 64, 64]) | 1700 | 340 torch.Size([2, 64, 64]) | 2800 | 353 torch.Size([4, 64, 64]) | 1990 | 328 torch.Size([8, 64, 64]) | 2040 | 330 torch.Size([16, 64, 64]) | 2100 | 350 torch.Size([32, 64, 64]) | 2300 | 680 torch.Size([64, 64, 64]) | 2430 | 725 torch.Size([128, 64, 64]) | 2600 | 845 torch.Size([512, 64, 64]) | 4700 | 1900 torch.Size([1024, 64, 64]) | 9200 | 4280 torch.Size([1, 128, 128]) | 2300 | 497 
torch.Size([2, 128, 128]) | 4000 | 562 torch.Size([4, 128, 128]) | 3140 | 669 torch.Size([8, 128, 128]) | 3200 | 698 torch.Size([16, 128, 128]) | 3400 | 810 torch.Size([32, 128, 128]) | 3866 | 1410 torch.Size([64, 128, 128]) | 4200 | 1670 torch.Size([128, 128, 128]) | 5050 | 2170 torch.Size([512, 128, 128]) | 14000 | 6417 torch.Size([1024, 128, 128]) | 28900 | 14700 torch.Size([1, 256, 256]) | 4100 | 1559 torch.Size([2, 256, 256]) | 6800 | 1792 torch.Size([4, 256, 256]) | 7000 | 2000 torch.Size([8, 256, 256]) | 7300 | 2200 torch.Size([16, 256, 256]) | 7730 | 2540 torch.Size([32, 256, 256]) | 8500 | 3390 torch.Size([64, 256, 256]) | 11000 | 4470 torch.Size([128, 256, 256]) | 15900 | 6757 torch.Size([512, 256, 256]) | 50000 | 30000 torch.Size([1024, 256, 256]) | 102600 | 56400 torch.Size([1, 512, 512]) | 8793 | 3230 torch.Size([2, 512, 512]) | 13000 | 3920 torch.Size([4, 512, 512]) | 14000 | 4531 torch.Size([8, 512, 512]) | 15000 | 5114 torch.Size([16, 512, 512]) | 16700 | 6280 torch.Size([32, 512, 512]) | 22400 | 9530 torch.Size([64, 512, 512]) | 33700 | 14260 torch.Size([128, 512, 512]) | 56500 | 20000 Times are in microseconds (us). ``` </details> <details> <summary> Benchmarking Script </summary> ```python import torch import pickle import itertools from functools import partial from torch.utils.benchmark import Timer, Compare benchmark_name = "linalg.solve" label = "master" shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512] batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)] results = [] make_arg = partial(torch.randn, dtype=torch.float32, device="cuda", requires_grad=True) for n, batch in itertools.product(shapes, batches): if n == 512 and batch[0] >= 512: continue A = make_arg(batch + (n, n)) B = make_arg(batch + (n, 16)) ones = torch.ones(B.shape, device=B.device) print(A.shape) for adjoint in (True, False): timer = Timer("torch.linalg.solve(A, B).backward(gradient=ones, inputs=[A, B])", globals=globals(), label=benchmark_name, description=label, sub_label=f"{A.shape}", num_threads=1) results.append(timer.blocked_autorange()) compare = Compare(results) compare.trim_significant_figures() compare.print() with open("{}.pickle".format(label), 'wb') as f: pickle.dump(results, f) ``` </details> See #72935 (comment) for the script to join the results. [ghstack-poisoned]
Summary: This PR heavily simplifies the code of `linalg.solve`. At the same time, this implementation saves quite a few copies of the input data in some cases (e.g. A is contiguous).

We also implement it in such a way that the derivative goes from computing two LU decompositions and two LU solves to no LU decompositions and one LU solve. It also avoids a number of copies the derivative was unnecessarily performing (at least the copy of two matrices).

On top of this, we add a `left` kw-only arg that allows the user to solve `XA = B` rather concisely.

Pull Request resolved: #74046
Approved by: https://github.com/nikitaved, https://github.com/IvanYashchuk, https://github.com/mruberry

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/54949a5abc9890143de4b5dd2f13ff98446376a3

Reviewed By: osalpekar
Differential Revision: D37089130
Pulled By: osalpekar
fbshipit-source-id: ca444fe7127bb3faf1de717100a4ad21e4d0f681
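For context on the derivative claim, here is a sketch of the standard adjoint computation for the solve (a textbook derivation, not code lifted from this PR). It shows why the backward pass can reuse the LU factorization of `A` from the forward pass: computing the gradient w.r.t. `B` costs one LU solve, and the gradient w.r.t. `A` is then just a matrix multiply.

```latex
% Forward: factor A once (LU) and solve A X = B, i.e. X = A^{-1} B.
% Backward: with G = \partial L / \partial X,
%   \partial L / \partial B requires one LU solve, reusing the LU factors of A;
%   \partial L / \partial A is a matrix product of already-computed quantities.
\[
X = A^{-1} B, \qquad
\frac{\partial L}{\partial B} = A^{-H}\,\frac{\partial L}{\partial X}, \qquad
\frac{\partial L}{\partial A} = -\,\frac{\partial L}{\partial B}\, X^{H}.
\]
```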
I think this PR introduced some maybe-unintended errors in extreme cases (see the example and discussion in #90453). I understand that errors in the correctness of the solution are expected for an ill-conditioned matrix. However, one still expects the numerical error, e.g.,
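For readers running into this, one way to sanity-check a solve on a possibly ill-conditioned matrix is to look at the condition number and the residual explicitly. This is a minimal sketch using public `torch.linalg` APIs; the tensors are made up for illustration and are unrelated to the reproducer in #90453:

```python
import torch

torch.manual_seed(0)
A = torch.randn(100, 100, dtype=torch.float64)
b = torch.randn(100, dtype=torch.float64)

x = torch.linalg.solve(A, b)

cond = torch.linalg.cond(A)                  # condition number in the 2-norm
residual = torch.linalg.norm(A @ x - b)      # backward error of the computed solution
rel_residual = residual / torch.linalg.norm(b)

print(f"cond(A) = {cond:.3e}, relative residual = {rel_residual:.3e}")
# A small residual does not guarantee a small error in x itself:
# the forward error can be roughly cond(A) times larger than the residual.
```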
Stack from ghstack:
This PR heavily simplifies the code of `linalg.solve`. At the same time, this implementation saves quite a few copies of the input data in some cases (e.g. A is contiguous).

We also implement it in such a way that the derivative goes from computing two LU decompositions and two LU solves to no LU decompositions and one LU solve. It also avoids a number of copies the derivative was unnecessarily performing (at least the copy of two matrices).

On top of this, we add a `left` kw-only arg that allows the user to solve `XA = B` rather concisely.

This PR also makes `torch.solve` an alias of `torch.linalg.solve`.

**Note:** This used to be the last PR of the stack. Now it's here because some tests were not passing in a PR that was before this one in the stack, and reshuffling the stack solved those problems. The benchmarks below are performed wrt the last PR of this stack.

We compare the performance of `linalg.solve` against master (before merging #67833, but already with a few PRs of the LU stack merged). We see that we got between **x2.5 and x10 speed-ups in `linalg.solve`**.

Benchmark Results

Benchmarking Script
See #72935 (comment) for the script to join the results.
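The joining script itself lives in the linked comment. As a hedged sketch of what it could look like (the pickle file names here are assumptions for illustration), two result lists produced by the benchmarking script above under different `label`s can simply be concatenated and fed back into `Compare`:

```python
import pickle
from torch.utils.benchmark import Compare

results = []
for fname in ("master.pickle", "this_pr.pickle"):  # assumed file names
    with open(fname, "rb") as f:
        results.extend(pickle.load(f))

compare = Compare(results)
compare.trim_significant_figures()
compare.print()
```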