Update and improve the heuristics for linalg.lu_factor #73878
lezcano wants to merge 43 commits into gh/Lezcano/52/base
Conversation
This PR adds `getrf_cublas` to the functions considered in the heuristics for `lu_factor`. It also updates the heuristics of the function.
## Benchmark
I'm omitting from the benchmarks the looped versions of the functions, as they are much slower than the non-looped ones. The only exception is cuSOLVER's looped variant, which is faster when applied to a batch of size one.
<details>
<summary>
Benchmark Results
</summary>
```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]
| lu_factor_heuristic | lu_factor_magma_batched | lu_factor_cusolver_batched
1 threads: ----------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 26 | 47 | 26
shape torch.Size([2, 1, 1]) | 17 | 38 | 17
shape torch.Size([4, 1, 1]) | 17 | 38 | 17
shape torch.Size([8, 1, 1]) | 20 | 38 | 18
shape torch.Size([16, 1, 1]) | 20 | 38 | 17
shape torch.Size([32, 1, 1]) | 18 | 38 | 17
shape torch.Size([64, 1, 1]) | 18 | 39 | 17
shape torch.Size([128, 1, 1]) | 17 | 38 | 17
shape torch.Size([512, 1, 1]) | 18 | 39 | 18
shape torch.Size([1024, 1, 1]) | 18 | 40 | 18
shape torch.Size([1, 2, 2]) | 18 | 38 | 17
shape torch.Size([2, 2, 2]) | 17 | 37 | 17
shape torch.Size([4, 2, 2]) | 17 | 38 | 17
shape torch.Size([8, 2, 2]) | 17 | 38 | 17
shape torch.Size([16, 2, 2]) | 17 | 38 | 17
shape torch.Size([32, 2, 2]) | 17 | 38 | 17
shape torch.Size([64, 2, 2]) | 17 | 38 | 17
shape torch.Size([128, 2, 2]) | 17 | 38 | 17
shape torch.Size([512, 2, 2]) | 17 | 39 | 17
shape torch.Size([1024, 2, 2]) | 17 | 40 | 17
shape torch.Size([1, 8, 8]) | 17 | 40 | 17
shape torch.Size([2, 8, 8]) | 17 | 40 | 17
shape torch.Size([4, 8, 8]) | 17 | 40 | 17
shape torch.Size([8, 8, 8]) | 17 | 40 | 17
shape torch.Size([16, 8, 8]) | 17 | 41 | 17
shape torch.Size([32, 8, 8]) | 17 | 40 | 17
shape torch.Size([64, 8, 8]) | 17 | 40 | 17
shape torch.Size([128, 8, 8]) | 17 | 40 | 17
shape torch.Size([512, 8, 8]) | 17 | 42 | 17
shape torch.Size([1024, 8, 8]) | 17 | 44 | 17
shape torch.Size([1, 16, 16]) | 24 | 44 | 18
shape torch.Size([2, 16, 16]) | 18 | 44 | 18
shape torch.Size([4, 16, 16]) | 18 | 45 | 18
shape torch.Size([8, 16, 16]) | 19 | 44 | 19
shape torch.Size([16, 16, 16]) | 20 | 44 | 20
shape torch.Size([32, 16, 16]) | 20 | 45 | 20
shape torch.Size([64, 16, 16]) | 20 | 44 | 20
shape torch.Size([128, 16, 16]) | 20 | 45 | 20
shape torch.Size([512, 16, 16]) | 28 | 50 | 28
shape torch.Size([1024, 16, 16]) | 41 | 59 | 41
shape torch.Size([1, 32, 32]) | 58 | 50 | 56
shape torch.Size([2, 32, 32]) | 56 | 50 | 56
shape torch.Size([4, 32, 32]) | 56 | 50 | 57
shape torch.Size([8, 32, 32]) | 60 | 50 | 60
shape torch.Size([16, 32, 32]) | 60 | 51 | 60
shape torch.Size([32, 32, 32]) | 247 | 51 | 61
shape torch.Size([64, 32, 32]) | 233 | 51 | 63
shape torch.Size([128, 32, 32]) | 236 | 53 | 66
shape torch.Size([512, 32, 32]) | 268 | 97 | 193
shape torch.Size([1024, 32, 32]) | 317 | 167 | 333
shape torch.Size([1, 64, 64]) | 131 | 216 | 99
shape torch.Size([2, 64, 64]) | 99 | 220 | 99
shape torch.Size([4, 64, 64]) | 99 | 225 | 101
shape torch.Size([8, 64, 64]) | 101 | 225 | 102
shape torch.Size([16, 64, 64]) | 107 | 230 | 108
shape torch.Size([32, 64, 64]) | 440 | 235 | 126
shape torch.Size([64, 64, 64]) | 447 | 240 | 155
shape torch.Size([128, 64, 64]) | 470 | 289 | 240
shape torch.Size([512, 64, 64]) | 793 | 678 | 1180
shape torch.Size([1024, 64, 64]) | 1000 | 1300 | 2112
shape torch.Size([1, 128, 128]) | 296 | 482 | 309
shape torch.Size([2, 128, 128]) | 308 | 499 | 307
shape torch.Size([4, 128, 128]) | 311 | 510 | 310
shape torch.Size([8, 128, 128]) | 314 | 522 | 314
shape torch.Size([16, 128, 128]) | 334 | 541 | 334
shape torch.Size([32, 128, 128]) | 770 | 591 | 467
shape torch.Size([64, 128, 128]) | 860 | 694 | 733
shape torch.Size([128, 128, 128]) | 1040 | 925 | 1980
shape torch.Size([512, 128, 128]) | 2883 | 2809 | 11000
shape torch.Size([1024, 128, 128]) | 5421 | 5430 | 22360
shape torch.Size([1, 256, 256]) | 1310 | 1109 | 1556
shape torch.Size([2, 256, 256]) | 1360 | 1150 | 1560
shape torch.Size([4, 256, 256]) | 1390 | 1188 | 1569
shape torch.Size([8, 256, 256]) | 1440 | 1250 | 1604
shape torch.Size([16, 256, 256]) | 1550 | 1390 | 1850
shape torch.Size([32, 256, 256]) | 1750 | 1620 | 3332
shape torch.Size([64, 256, 256]) | 2327 | 2246 | 6700
shape torch.Size([128, 256, 256]) | 3697 | 3638 | 19100
shape torch.Size([512, 256, 256]) | 12530 | 12500 | 87300
shape torch.Size([1024, 256, 256]) | 24380 | 24420 | 176000
Times are in microseconds (us).
```
</details>
To generate these results, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I ran the following script, changing the variable `name`. For the `lu_solve unpack+solve_triangular` benchmark, I also changed the `stmt` variable (uncommenting the commented one).
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_batched"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")

for n, batch in itertools.product(shapes, batches):
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())
    # Sanity check: the factorization reconstructs A
    LU, pivots = torch.linalg.lu_factor(A)
    P, L, U = torch.lu_unpack(LU, pivots)
    assert torch.allclose(P @ L @ U, A, rtol=1e-2, atol=1e-3)

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```
</details>
See #72935 (comment) for the script to join the results.
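For reference, joining the per-backend results amounts to loading all the pickles into a single `Compare`. A minimal sketch, assuming pickle files named as produced by the script above (the actual script lives in the linked comment):

```python
import pickle
from torch.utils.benchmark import Compare

# Hypothetical list of labels, one pickle per benchmarked backend
labels = ["lu_factor_heuristic", "lu_factor_magma_batched", "lu_factor_cusolver_batched"]

results = []
for label in labels:
    with open(f"{label}.pickle", "rb") as f:
        # Each file holds the list of Measurement objects dumped above
        results.extend(pickle.load(f))

compare = Compare(results)
compare.trim_significant_figures()
compare.print()
```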
You're running the benchmark only up to 256x256 matrices. It's important to test larger matrices as well; this is the regime where the "looped" variant should be faster.

How large do you want the matrices to be? Do you reckon adding 512 and 1024 would do it?

A comment in `aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp` (lines 293 to 295 at bd13bc6) suggests that 512 is the breaking point where the cuSOLVER looped variant is better. A few other places in the codebase also use 512.
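For illustration, the kind of size-based dispatch being discussed looks roughly like the sketch below. This is a hedged Python sketch, not the actual implementation (which is C++ code this PR modifies); the branch structure and thresholds are only assumptions suggested by the benchmark tables and the 512 comment above.

```python
def choose_lu_factor_backend(batch_size: int, n: int) -> str:
    """Illustrative sketch of a size-based backend heuristic for LU factorization.

    The thresholds are assumptions drawn from the benchmark tables in this
    PR, not the values the PR actually commits.
    """
    if batch_size == 1 and n >= 512:
        # A single large matrix: the looped (one-kernel-per-matrix)
        # cuSOLVER path wins, per the BatchLinearAlgebraLib.cpp comment
        return "cusolver_looped"
    if n <= 16:
        # Tiny matrices: batched cuSOLVER has the lowest overhead
        return "cusolver_batched"
    # Mid-sized and large batched problems: batched MAGMA is fastest
    return "magma_batched"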
<details>
<summary>
Benchmark results for all algorithms up to `n=2048`
</summary>
```
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
| lu_factor_magma_batched | lu_factor_cusolver_batched | lu_factor_cusolver_looped | lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 51 | 30 | 27 | 1390
shape torch.Size([2, 1, 1]) | 42 | 20 | 26 | 2798
shape torch.Size([4, 1, 1]) | 42 | 20 | 42 | 5589
shape torch.Size([8, 1, 1]) | 42 | 20 | 72 | 11000
shape torch.Size([16, 1, 1]) | 42 | 20 | 132 | 22400
shape torch.Size([32, 1, 1]) | 42 | 20 | 253 | 44620
shape torch.Size([64, 1, 1]) | 42 | 20 | 496 | 89200
shape torch.Size([128, 1, 1]) | 42 | 20 | 980 | 180000
shape torch.Size([512, 1, 1]) | 43 | 20 | 3868 | 714100
shape torch.Size([1024, 1, 1]) | 44 | 20 | 7800 | 1430000
shape torch.Size([1, 2, 2]) | 43 | 21 | 19 | 1400
shape torch.Size([2, 2, 2]) | 42 | 21 | 27 | 2898
shape torch.Size([4, 2, 2]) | 43 | 21 | 42 | 5800
shape torch.Size([8, 2, 2]) | 43 | 21 | 73 | 11600
shape torch.Size([16, 2, 2]) | 43 | 21 | 133 | 23170
shape torch.Size([32, 2, 2]) | 43 | 21 | 254 | 46290
shape torch.Size([64, 2, 2]) | 43 | 21 | 500 | 94000
shape torch.Size([128, 2, 2]) | 43 | 21 | 980 | 190000
shape torch.Size([512, 2, 2]) | 44 | 21 | 3860 | 741900
shape torch.Size([1024, 2, 2]) | 44 | 21 | 7640 | 1484000
shape torch.Size([1, 8, 8]) | 45 | 21 | 19 | 1450
shape torch.Size([2, 8, 8]) | 45 | 21 | 27 | 2917
shape torch.Size([4, 8, 8]) | 45 | 21 | 53 | 5800
shape torch.Size([8, 8, 8]) | 45 | 21 | 105 | 11580
shape torch.Size([16, 8, 8]) | 45 | 21 | 207 | 23160
shape torch.Size([32, 8, 8]) | 46 | 21 | 413 | 46400
shape torch.Size([64, 8, 8]) | 46 | 21 | 824 | 93000
shape torch.Size([128, 8, 8]) | 46 | 21 | 1645 | 185000
shape torch.Size([512, 8, 8]) | 47 | 21 | 6574 | 742000
shape torch.Size([1024, 8, 8]) | 49 | 21 | 13150 | 1481000
shape torch.Size([1, 16, 16]) | 49 | 21 | 24 | 1460
shape torch.Size([2, 16, 16]) | 49 | 21 | 46 | 2902
shape torch.Size([4, 16, 16]) | 49 | 21 | 90 | 5800
shape torch.Size([8, 16, 16]) | 49 | 21 | 177 | 11600
shape torch.Size([16, 16, 16]) | 49 | 21 | 352 | 23150
shape torch.Size([32, 16, 16]) | 49 | 21 | 703 | 46300
shape torch.Size([64, 16, 16]) | 49 | 21 | 1404 | 92700
shape torch.Size([128, 16, 16]) | 50 | 21 | 2807 | 185000
shape torch.Size([512, 16, 16]) | 55 | 29 | 11220 | 741700
shape torch.Size([1024, 16, 16]) | 64 | 42 | 22440 | 1480000
shape torch.Size([1, 32, 32]) | 55 | 56 | 58 | 1460
shape torch.Size([2, 32, 32]) | 55 | 57 | 114 | 2920
shape torch.Size([4, 32, 32]) | 55 | 57 | 225 | 5830
shape torch.Size([8, 32, 32]) | 55 | 61 | 449 | 11700
shape torch.Size([16, 32, 32]) | 56 | 61 | 896 | 23300
shape torch.Size([32, 32, 32]) | 56 | 62 | 1791 | 46600
shape torch.Size([64, 32, 32]) | 56 | 63 | 3581 | 93100
shape torch.Size([128, 32, 32]) | 58 | 66 | 7156 | 186000
shape torch.Size([512, 32, 32]) | 100 | 194 | 28700 | 742400
shape torch.Size([1024, 32, 32]) | 169 | 335 | 57620 | 1485000
shape torch.Size([1, 64, 64]) | 224 | 101 | 132 | 1500
shape torch.Size([2, 64, 64]) | 227 | 100 | 262 | 2951
shape torch.Size([4, 64, 64]) | 229 | 101 | 523 | 5890
shape torch.Size([8, 64, 64]) | 231 | 102 | 1040 | 12000
shape torch.Size([16, 64, 64]) | 237 | 109 | 2088 | 23530
shape torch.Size([32, 64, 64]) | 242 | 127 | 4171 | 46900
shape torch.Size([64, 64, 64]) | 247 | 156 | 8330 | 95000
shape torch.Size([128, 64, 64]) | 293 | 244 | 16710 | 189000
shape torch.Size([512, 64, 64]) | 685 | 1180 | 67000 | 750900
shape torch.Size([1024, 64, 64]) | 1300 | 2076 | 134000 | 1505000
shape torch.Size([1, 128, 128]) | 490 | 309 | 298 | 1560
shape torch.Size([2, 128, 128]) | 503 | 309 | 594 | 3120
shape torch.Size([4, 128, 128]) | 515 | 312 | 1185 | 6230
shape torch.Size([8, 128, 128]) | 523 | 317 | 2370 | 12500
shape torch.Size([16, 128, 128]) | 547 | 336 | 4734 | 24890
shape torch.Size([32, 128, 128]) | 596 | 472 | 9491 | 49800
shape torch.Size([64, 128, 128]) | 700 | 741 | 19000 | 100000
shape torch.Size([128, 128, 128]) | 930 | 1770 | 37990 | 199000
shape torch.Size([512, 128, 128]) | 2810 | 11000 | 152000 | 797100
shape torch.Size([1024, 128, 128]) | 5430 | 22430 | 303900 | 1595000
shape torch.Size([1, 256, 256]) | 1120 | 1580 | 666 | 1890
shape torch.Size([2, 256, 256]) | 1160 | 1574 | 1330 | 3784
shape torch.Size([4, 256, 256]) | 1190 | 1580 | 2658 | 7570
shape torch.Size([8, 256, 256]) | 1250 | 1613 | 5325 | 15100
shape torch.Size([16, 256, 256]) | 1394 | 1880 | 10700 | 30260
shape torch.Size([32, 256, 256]) | 1633 | 3360 | 21300 | 61000
shape torch.Size([64, 256, 256]) | 2258 | 6730 | 42600 | 120000
shape torch.Size([128, 256, 256]) | 3639 | 19200 | 85170 | 242200
shape torch.Size([512, 256, 256]) | 12600 | 87200 | 340600 | 969000
shape torch.Size([1024, 256, 256]) | 24530 | 176000 | 681300 | 1943000
shape torch.Size([1, 512, 512]) | 2557 | 9117 | 1724 | 2577
shape torch.Size([2, 512, 512]) | 2691 | 9209 | 3464 | 5200
shape torch.Size([4, 512, 512]) | 2853 | 9860 | 6940 | 10000
shape torch.Size([8, 512, 512]) | 3153 | 11000 | 13900 | 20570
shape torch.Size([16, 512, 512]) | 3765 | 13000 | 27720 | 41360
shape torch.Size([32, 512, 512]) | 5500 | 21400 | 55420 | 82000
shape torch.Size([64, 512, 512]) | 8790 | 44000 | 111000 | 165000
shape torch.Size([128, 512, 512]) | 15300 | 98000 | 221700 | 329800
shape torch.Size([512, 512, 512]) | 55400 | 424100 | 886600 | 1325000
shape torch.Size([1024, 512, 512]) | 110000 | 856200 | 1773000 | 2691000
shape torch.Size([1, 1024, 1024]) | 10350 | 69290 | 5020 | 5327
shape torch.Size([2, 1024, 1024]) | 11200 | 74860 | 10040 | 11000
shape torch.Size([4, 1024, 1024]) | 12200 | 78030 | 20080 | 21290
shape torch.Size([8, 1024, 1024]) | 14000 | 81200 | 40160 | 42850
shape torch.Size([16, 1024, 1024]) | 17700 | 96000 | 80300 | 85500
shape torch.Size([32, 1024, 1024]) | 27740 | 150000 | 160700 | 171000
shape torch.Size([64, 1024, 1024]) | 45940 | 233400 | 321200 | 344100
shape torch.Size([1, 2048, 2048]) | 29860 | 579800 | 12920 | 13500
shape torch.Size([2, 2048, 2048]) | 34000 | 585000 | 25840 | 26840
shape torch.Size([4, 2048, 2048]) | 39770 | 593900 | 51670 | 54000
shape torch.Size([8, 2048, 2048]) | 51720 | 632100 | 103000 | 109000
shape torch.Size([16, 2048, 2048]) | 76900 | 845500 | 206600 | 218400
shape torch.Size([32, 2048, 2048]) | 130000 | 1058000 | 413900 | 437300
Times are in microseconds (us).
```
</details>
To generate these results, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I ran the following script, changing the variable `name`.
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")

for n, batch in itertools.product(shapes, batches):
    # Skip the largest batches for the biggest matrix sizes
    if n == 1024 and batch[0] >= 128:
        continue
    if n == 2048 and batch[0] >= 64:
        continue
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```
</details>
See #72935 (comment) for the script to join the results.
shape torch.Size([8, 1024, 1024]) | 14000 | 81200 | 40160 | 42850
shape torch.Size([16, 1024, 1024]) | 17700 | 96000 | 80300 | 85500
shape torch.Size([32, 1024, 1024]) | 27740 | 150000 | 160700 | 171000
shape torch.Size([64, 1024, 1024]) | 45940 | 233400 | 321200 | 344100
shape torch.Size([1, 2048, 2048]) | 29860 | 579800 | 12920 | 13500
shape torch.Size([2, 2048, 2048]) | 34000 | 585000 | 25840 | 26840
shape torch.Size([4, 2048, 2048]) | 39770 | 593900 | 51670 | 54000
shape torch.Size([8, 2048, 2048]) | 51720 | 632100 | 103000 | 109000
shape torch.Size([16, 2048, 2048]) | 76900 | 845500 | 206600 | 218400
shape torch.Size([32, 2048, 2048]) | 130000 | 1058000 | 413900 | 437300
Times are in microseconds (us).
```
</details>
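The crossovers in the tables above are what the updated heuristic exploits: cusolver's batched kernel dominates for tiny matrices, its looped kernel wins for a single large matrix, and magma's batched kernel scales best for large batches of mid-sized matrices. As a rough illustration only — this is not the dispatch code in this PR, and the function name and thresholds below are made up by eyeballing the tables (they also ignore the cublas path this PR adds, which is not benchmarked above) — such a size/batch selection might look like:

```python
# Illustrative only: hypothetical thresholds read off the tables above,
# NOT the ones implemented in this PR.
def choose_lu_factor_backend(batch: int, n: int) -> str:
    if batch == 1 and n >= 512:
        # A single large matrix: cusolver's looped (non-batched) getrf
        # is the fastest option in the second table.
        return "cusolver_looped"
    if n <= 16:
        # Tiny matrices: cusolver's batched getrf wins at every batch size.
        return "cusolver_batched"
    # Large batches of mid/large matrices: magma's batched getrf scales best.
    return "magma_batched"
```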
To generate the results above, I hard-coded a call to the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. I then ran the script below once per backend, changing the variable `name`.
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"  # set to the backend forced in lu_solve_kernel
label = "lu_factor_{}".format(name)

shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]

results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")
for n, batch in itertools.product(shapes, batches):
    # Skip the largest batches of the biggest matrices
    if n == 1024 and batch[0] >= 128:
        continue
    if n == 2048 and batch[0] >= 64:
        continue
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```
</details>
See #72935 (comment) for the script to join the results.
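For reference, a minimal sketch of what such a join script can look like (the actual one is linked above; the list of backend names below is illustrative and must match the `name` values used in the runs):

```python
import pickle
from torch.utils.benchmark import Compare

# One pickle per benchmarked backend, as produced by the script above.
names = ["heuristic", "magma_batched", "cusolver_batched",
         "cusolver_looped", "magma_looped"]

results = []
for name in names:
    with open(f"lu_factor_{name}.pickle", "rb") as f:
        results.extend(pickle.load(f))

# Compare merges measurements sharing a label into one side-by-side table.
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
```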
[ghstack-poisoned]
This PR adds getrf_cublas to the functions considered in the heuristics
for `lu_factor`. It also updates the heuristics of the function.
## Benchmark
I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.
<details>
<summary>
Benchmark Results
</summary>
```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]
| lu_factor_heuristic | lu_factor_magma_batched | lu_factor_cusolver_batched
1 threads: ----------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 26 | 47 | 26
shape torch.Size([2, 1, 1]) | 17 | 38 | 17
shape torch.Size([4, 1, 1]) | 17 | 38 | 17
shape torch.Size([8, 1, 1]) | 20 | 38 | 18
shape torch.Size([16, 1, 1]) | 20 | 38 | 17
shape torch.Size([32, 1, 1]) | 18 | 38 | 17
shape torch.Size([64, 1, 1]) | 18 | 39 | 17
shape torch.Size([128, 1, 1]) | 17 | 38 | 17
shape torch.Size([512, 1, 1]) | 18 | 39 | 18
shape torch.Size([1024, 1, 1]) | 18 | 40 | 18
shape torch.Size([1, 2, 2]) | 18 | 38 | 17
shape torch.Size([2, 2, 2]) | 17 | 37 | 17
shape torch.Size([4, 2, 2]) | 17 | 38 | 17
shape torch.Size([8, 2, 2]) | 17 | 38 | 17
shape torch.Size([16, 2, 2]) | 17 | 38 | 17
shape torch.Size([32, 2, 2]) | 17 | 38 | 17
shape torch.Size([64, 2, 2]) | 17 | 38 | 17
shape torch.Size([128, 2, 2]) | 17 | 38 | 17
shape torch.Size([512, 2, 2]) | 17 | 39 | 17
shape torch.Size([1024, 2, 2]) | 17 | 40 | 17
shape torch.Size([1, 8, 8]) | 17 | 40 | 17
shape torch.Size([2, 8, 8]) | 17 | 40 | 17
shape torch.Size([4, 8, 8]) | 17 | 40 | 17
shape torch.Size([8, 8, 8]) | 17 | 40 | 17
shape torch.Size([16, 8, 8]) | 17 | 41 | 17
shape torch.Size([32, 8, 8]) | 17 | 40 | 17
shape torch.Size([64, 8, 8]) | 17 | 40 | 17
shape torch.Size([128, 8, 8]) | 17 | 40 | 17
shape torch.Size([512, 8, 8]) | 17 | 42 | 17
shape torch.Size([1024, 8, 8]) | 17 | 44 | 17
shape torch.Size([1, 16, 16]) | 24 | 44 | 18
shape torch.Size([2, 16, 16]) | 18 | 44 | 18
shape torch.Size([4, 16, 16]) | 18 | 45 | 18
shape torch.Size([8, 16, 16]) | 19 | 44 | 19
shape torch.Size([16, 16, 16]) | 20 | 44 | 20
shape torch.Size([32, 16, 16]) | 20 | 45 | 20
shape torch.Size([64, 16, 16]) | 20 | 44 | 20
shape torch.Size([128, 16, 16]) | 20 | 45 | 20
shape torch.Size([512, 16, 16]) | 28 | 50 | 28
shape torch.Size([1024, 16, 16]) | 41 | 59 | 41
shape torch.Size([1, 32, 32]) | 58 | 50 | 56
shape torch.Size([2, 32, 32]) | 56 | 50 | 56
shape torch.Size([4, 32, 32]) | 56 | 50 | 57
shape torch.Size([8, 32, 32]) | 60 | 50 | 60
shape torch.Size([16, 32, 32]) | 60 | 51 | 60
shape torch.Size([32, 32, 32]) | 247 | 51 | 61
shape torch.Size([64, 32, 32]) | 233 | 51 | 63
shape torch.Size([128, 32, 32]) | 236 | 53 | 66
shape torch.Size([512, 32, 32]) | 268 | 97 | 193
shape torch.Size([1024, 32, 32]) | 317 | 167 | 333
shape torch.Size([1, 64, 64]) | 131 | 216 | 99
shape torch.Size([2, 64, 64]) | 99 | 220 | 99
shape torch.Size([4, 64, 64]) | 99 | 225 | 101
shape torch.Size([8, 64, 64]) | 101 | 225 | 102
shape torch.Size([16, 64, 64]) | 107 | 230 | 108
shape torch.Size([32, 64, 64]) | 440 | 235 | 126
shape torch.Size([64, 64, 64]) | 447 | 240 | 155
shape torch.Size([128, 64, 64]) | 470 | 289 | 240
shape torch.Size([512, 64, 64]) | 793 | 678 | 1180
shape torch.Size([1024, 64, 64]) | 1000 | 1300 | 2112
shape torch.Size([1, 128, 128]) | 296 | 482 | 309
shape torch.Size([2, 128, 128]) | 308 | 499 | 307
shape torch.Size([4, 128, 128]) | 311 | 510 | 310
shape torch.Size([8, 128, 128]) | 314 | 522 | 314
shape torch.Size([16, 128, 128]) | 334 | 541 | 334
shape torch.Size([32, 128, 128]) | 770 | 591 | 467
shape torch.Size([64, 128, 128]) | 860 | 694 | 733
shape torch.Size([128, 128, 128]) | 1040 | 925 | 1980
shape torch.Size([512, 128, 128]) | 2883 | 2809 | 11000
shape torch.Size([1024, 128, 128]) | 5421 | 5430 | 22360
shape torch.Size([1, 256, 256]) | 1310 | 1109 | 1556
shape torch.Size([2, 256, 256]) | 1360 | 1150 | 1560
shape torch.Size([4, 256, 256]) | 1390 | 1188 | 1569
shape torch.Size([8, 256, 256]) | 1440 | 1250 | 1604
shape torch.Size([16, 256, 256]) | 1550 | 1390 | 1850
shape torch.Size([32, 256, 256]) | 1750 | 1620 | 3332
shape torch.Size([64, 256, 256]) | 2327 | 2246 | 6700
shape torch.Size([128, 256, 256]) | 3697 | 3638 | 19100
shape torch.Size([512, 256, 256]) | 12530 | 12500 | 87300
shape torch.Size([1024, 256, 256]) | 24380 | 24420 | 176000
```
</details>
<details>
<summary>
Benchmark Results all algorithms up to `n=2048`
</summary>
```
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
| lu_factor_magma_batched | lu_factor_cusolver_batched | lu_factor_cusolver_looped | lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 51 | 30 | 27 | 1390
shape torch.Size([2, 1, 1]) | 42 | 20 | 26 | 2798
shape torch.Size([4, 1, 1]) | 42 | 20 | 42 | 5589
shape torch.Size([8, 1, 1]) | 42 | 20 | 72 | 11000
shape torch.Size([16, 1, 1]) | 42 | 20 | 132 | 22400
shape torch.Size([32, 1, 1]) | 42 | 20 | 253 | 44620
shape torch.Size([64, 1, 1]) | 42 | 20 | 496 | 89200
shape torch.Size([128, 1, 1]) | 42 | 20 | 980 | 180000
shape torch.Size([512, 1, 1]) | 43 | 20 | 3868 | 714100
shape torch.Size([1024, 1, 1]) | 44 | 20 | 7800 | 1430000
shape torch.Size([1, 2, 2]) | 43 | 21 | 19 | 1400
shape torch.Size([2, 2, 2]) | 42 | 21 | 27 | 2898
shape torch.Size([4, 2, 2]) | 43 | 21 | 42 | 5800
shape torch.Size([8, 2, 2]) | 43 | 21 | 73 | 11600
shape torch.Size([16, 2, 2]) | 43 | 21 | 133 | 23170
shape torch.Size([32, 2, 2]) | 43 | 21 | 254 | 46290
shape torch.Size([64, 2, 2]) | 43 | 21 | 500 | 94000
shape torch.Size([128, 2, 2]) | 43 | 21 | 980 | 190000
shape torch.Size([512, 2, 2]) | 44 | 21 | 3860 | 741900
shape torch.Size([1024, 2, 2]) | 44 | 21 | 7640 | 1484000
shape torch.Size([1, 8, 8]) | 45 | 21 | 19 | 1450
shape torch.Size([2, 8, 8]) | 45 | 21 | 27 | 2917
shape torch.Size([4, 8, 8]) | 45 | 21 | 53 | 5800
shape torch.Size([8, 8, 8]) | 45 | 21 | 105 | 11580
shape torch.Size([16, 8, 8]) | 45 | 21 | 207 | 23160
shape torch.Size([32, 8, 8]) | 46 | 21 | 413 | 46400
shape torch.Size([64, 8, 8]) | 46 | 21 | 824 | 93000
shape torch.Size([128, 8, 8]) | 46 | 21 | 1645 | 185000
shape torch.Size([512, 8, 8]) | 47 | 21 | 6574 | 742000
shape torch.Size([1024, 8, 8]) | 49 | 21 | 13150 | 1481000
shape torch.Size([1, 16, 16]) | 49 | 21 | 24 | 1460
shape torch.Size([2, 16, 16]) | 49 | 21 | 46 | 2902
shape torch.Size([4, 16, 16]) | 49 | 21 | 90 | 5800
shape torch.Size([8, 16, 16]) | 49 | 21 | 177 | 11600
shape torch.Size([16, 16, 16]) | 49 | 21 | 352 | 23150
shape torch.Size([32, 16, 16]) | 49 | 21 | 703 | 46300
shape torch.Size([64, 16, 16]) | 49 | 21 | 1404 | 92700
shape torch.Size([128, 16, 16]) | 50 | 21 | 2807 | 185000
shape torch.Size([512, 16, 16]) | 55 | 29 | 11220 | 741700
shape torch.Size([1024, 16, 16]) | 64 | 42 | 22440 | 1480000
shape torch.Size([1, 32, 32]) | 55 | 56 | 58 | 1460
shape torch.Size([2, 32, 32]) | 55 | 57 | 114 | 2920
shape torch.Size([4, 32, 32]) | 55 | 57 | 225 | 5830
shape torch.Size([8, 32, 32]) | 55 | 61 | 449 | 11700
shape torch.Size([16, 32, 32]) | 56 | 61 | 896 | 23300
shape torch.Size([32, 32, 32]) | 56 | 62 | 1791 | 46600
shape torch.Size([64, 32, 32]) | 56 | 63 | 3581 | 93100
shape torch.Size([128, 32, 32]) | 58 | 66 | 7156 | 186000
shape torch.Size([512, 32, 32]) | 100 | 194 | 28700 | 742400
shape torch.Size([1024, 32, 32]) | 169 | 335 | 57620 | 1485000
shape torch.Size([1, 64, 64]) | 224 | 101 | 132 | 1500
shape torch.Size([2, 64, 64]) | 227 | 100 | 262 | 2951
shape torch.Size([4, 64, 64]) | 229 | 101 | 523 | 5890
shape torch.Size([8, 64, 64]) | 231 | 102 | 1040 | 12000
shape torch.Size([16, 64, 64]) | 237 | 109 | 2088 | 23530
shape torch.Size([32, 64, 64]) | 242 | 127 | 4171 | 46900
shape torch.Size([64, 64, 64]) | 247 | 156 | 8330 | 95000
shape torch.Size([128, 64, 64]) | 293 | 244 | 16710 | 189000
shape torch.Size([512, 64, 64]) | 685 | 1180 | 67000 | 750900
shape torch.Size([1024, 64, 64]) | 1300 | 2076 | 134000 | 1505000
shape torch.Size([1, 128, 128]) | 490 | 309 | 298 | 1560
shape torch.Size([2, 128, 128]) | 503 | 309 | 594 | 3120
shape torch.Size([4, 128, 128]) | 515 | 312 | 1185 | 6230
shape torch.Size([8, 128, 128]) | 523 | 317 | 2370 | 12500
shape torch.Size([16, 128, 128]) | 547 | 336 | 4734 | 24890
shape torch.Size([32, 128, 128]) | 596 | 472 | 9491 | 49800
shape torch.Size([64, 128, 128]) | 700 | 741 | 19000 | 100000
shape torch.Size([128, 128, 128]) | 930 | 1770 | 37990 | 199000
shape torch.Size([512, 128, 128]) | 2810 | 11000 | 152000 | 797100
shape torch.Size([1024, 128, 128]) | 5430 | 22430 | 303900 | 1595000
shape torch.Size([1, 256, 256]) | 1120 | 1580 | 666 | 1890
shape torch.Size([2, 256, 256]) | 1160 | 1574 | 1330 | 3784
shape torch.Size([4, 256, 256]) | 1190 | 1580 | 2658 | 7570
shape torch.Size([8, 256, 256]) | 1250 | 1613 | 5325 | 15100
shape torch.Size([16, 256, 256]) | 1394 | 1880 | 10700 | 30260
shape torch.Size([32, 256, 256]) | 1633 | 3360 | 21300 | 61000
shape torch.Size([64, 256, 256]) | 2258 | 6730 | 42600 | 120000
shape torch.Size([128, 256, 256]) | 3639 | 19200 | 85170 | 242200
shape torch.Size([512, 256, 256]) | 12600 | 87200 | 340600 | 969000
shape torch.Size([1024, 256, 256]) | 24530 | 176000 | 681300 | 1943000
shape torch.Size([1, 512, 512]) | 2557 | 9117 | 1724 | 2577
shape torch.Size([2, 512, 512]) | 2691 | 9209 | 3464 | 5200
shape torch.Size([4, 512, 512]) | 2853 | 9860 | 6940 | 10000
shape torch.Size([8, 512, 512]) | 3153 | 11000 | 13900 | 20570
shape torch.Size([16, 512, 512]) | 3765 | 13000 | 27720 | 41360
shape torch.Size([32, 512, 512]) | 5500 | 21400 | 55420 | 82000
shape torch.Size([64, 512, 512]) | 8790 | 44000 | 111000 | 165000
shape torch.Size([128, 512, 512]) | 15300 | 98000 | 221700 | 329800
shape torch.Size([512, 512, 512]) | 55400 | 424100 | 886600 | 1325000
shape torch.Size([1024, 512, 512]) | 110000 | 856200 | 1773000 | 2691000
shape torch.Size([1, 1024, 1024]) | 10350 | 69290 | 5020 | 5327
shape torch.Size([2, 1024, 1024]) | 11200 | 74860 | 10040 | 11000
shape torch.Size([4, 1024, 1024]) | 12200 | 78030 | 20080 | 21290
shape torch.Size([8, 1024, 1024]) | 14000 | 81200 | 40160 | 42850
shape torch.Size([16, 1024, 1024]) | 17700 | 96000 | 80300 | 85500
shape torch.Size([32, 1024, 1024]) | 27740 | 150000 | 160700 | 171000
shape torch.Size([64, 1024, 1024]) | 45940 | 233400 | 321200 | 344100
shape torch.Size([1, 2048, 2048]) | 29860 | 579800 | 12920 | 13500
shape torch.Size([2, 2048, 2048]) | 34000 | 585000 | 25840 | 26840
shape torch.Size([4, 2048, 2048]) | 39770 | 593900 | 51670 | 54000
shape torch.Size([8, 2048, 2048]) | 51720 | 632100 | 103000 | 109000
shape torch.Size([16, 2048, 2048]) | 76900 | 845500 | 206600 | 218400
shape torch.Size([32, 2048, 2048]) | 130000 | 1058000 | 413900 | 437300
Times are in microseconds (us).
```
</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`.
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare
benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")
for n, batch in itertools.product(shapes, batches):
if n == 1024 and batch[0] >= 128:
continue
if n == 2048 and batch[0] >= 64:
continue
A = make_arg(batch + (n, n))
print(A.shape)
stmt = "torch.linalg.lu_factor_ex(A)"
timer = Timer(stmt,
globals=globals(),
label=benchmark_name,
description=label,
sub_label=f"shape {A.shape}",
num_threads=1)
results.append(timer.blocked_autorange())
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
with open(f"{label}.pickle", 'wb') as f:
pickle.dump(results, f)
```
</details>
See #72935 (comment) for the script to join the results.
[ghstack-poisoned]
This PR adds getrf_cublas to the functions considered in the heuristics
for `lu_factor`. It also updates the heuristics of the function.
## Benchmark
I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.
<details>
<summary>
Benchmark Results
</summary>
```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]
| lu_factor_heuristic | lu_factor_magma_batched | lu_factor_cusolver_batched
1 threads: ----------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 26 | 47 | 26
shape torch.Size([2, 1, 1]) | 17 | 38 | 17
shape torch.Size([4, 1, 1]) | 17 | 38 | 17
shape torch.Size([8, 1, 1]) | 20 | 38 | 18
shape torch.Size([16, 1, 1]) | 20 | 38 | 17
shape torch.Size([32, 1, 1]) | 18 | 38 | 17
shape torch.Size([64, 1, 1]) | 18 | 39 | 17
shape torch.Size([128, 1, 1]) | 17 | 38 | 17
shape torch.Size([512, 1, 1]) | 18 | 39 | 18
shape torch.Size([1024, 1, 1]) | 18 | 40 | 18
shape torch.Size([1, 2, 2]) | 18 | 38 | 17
shape torch.Size([2, 2, 2]) | 17 | 37 | 17
shape torch.Size([4, 2, 2]) | 17 | 38 | 17
shape torch.Size([8, 2, 2]) | 17 | 38 | 17
shape torch.Size([16, 2, 2]) | 17 | 38 | 17
shape torch.Size([32, 2, 2]) | 17 | 38 | 17
shape torch.Size([64, 2, 2]) | 17 | 38 | 17
shape torch.Size([128, 2, 2]) | 17 | 38 | 17
shape torch.Size([512, 2, 2]) | 17 | 39 | 17
shape torch.Size([1024, 2, 2]) | 17 | 40 | 17
shape torch.Size([1, 8, 8]) | 17 | 40 | 17
shape torch.Size([2, 8, 8]) | 17 | 40 | 17
shape torch.Size([4, 8, 8]) | 17 | 40 | 17
shape torch.Size([8, 8, 8]) | 17 | 40 | 17
shape torch.Size([16, 8, 8]) | 17 | 41 | 17
shape torch.Size([32, 8, 8]) | 17 | 40 | 17
shape torch.Size([64, 8, 8]) | 17 | 40 | 17
shape torch.Size([128, 8, 8]) | 17 | 40 | 17
shape torch.Size([512, 8, 8]) | 17 | 42 | 17
shape torch.Size([1024, 8, 8]) | 17 | 44 | 17
shape torch.Size([1, 16, 16]) | 24 | 44 | 18
shape torch.Size([2, 16, 16]) | 18 | 44 | 18
shape torch.Size([4, 16, 16]) | 18 | 45 | 18
shape torch.Size([8, 16, 16]) | 19 | 44 | 19
shape torch.Size([16, 16, 16]) | 20 | 44 | 20
shape torch.Size([32, 16, 16]) | 20 | 45 | 20
shape torch.Size([64, 16, 16]) | 20 | 44 | 20
shape torch.Size([128, 16, 16]) | 20 | 45 | 20
shape torch.Size([512, 16, 16]) | 28 | 50 | 28
shape torch.Size([1024, 16, 16]) | 41 | 59 | 41
shape torch.Size([1, 32, 32]) | 58 | 50 | 56
shape torch.Size([2, 32, 32]) | 56 | 50 | 56
shape torch.Size([4, 32, 32]) | 56 | 50 | 57
shape torch.Size([8, 32, 32]) | 60 | 50 | 60
shape torch.Size([16, 32, 32]) | 60 | 51 | 60
shape torch.Size([32, 32, 32]) | 247 | 51 | 61
shape torch.Size([64, 32, 32]) | 233 | 51 | 63
shape torch.Size([128, 32, 32]) | 236 | 53 | 66
shape torch.Size([512, 32, 32]) | 268 | 97 | 193
shape torch.Size([1024, 32, 32]) | 317 | 167 | 333
shape torch.Size([1, 64, 64]) | 131 | 216 | 99
shape torch.Size([2, 64, 64]) | 99 | 220 | 99
shape torch.Size([4, 64, 64]) | 99 | 225 | 101
shape torch.Size([8, 64, 64]) | 101 | 225 | 102
shape torch.Size([16, 64, 64]) | 107 | 230 | 108
shape torch.Size([32, 64, 64]) | 440 | 235 | 126
shape torch.Size([64, 64, 64]) | 447 | 240 | 155
shape torch.Size([128, 64, 64]) | 470 | 289 | 240
shape torch.Size([512, 64, 64]) | 793 | 678 | 1180
shape torch.Size([1024, 64, 64]) | 1000 | 1300 | 2112
shape torch.Size([1, 128, 128]) | 296 | 482 | 309
shape torch.Size([2, 128, 128]) | 308 | 499 | 307
shape torch.Size([4, 128, 128]) | 311 | 510 | 310
shape torch.Size([8, 128, 128]) | 314 | 522 | 314
shape torch.Size([16, 128, 128]) | 334 | 541 | 334
shape torch.Size([32, 128, 128]) | 770 | 591 | 467
shape torch.Size([64, 128, 128]) | 860 | 694 | 733
shape torch.Size([128, 128, 128]) | 1040 | 925 | 1980
shape torch.Size([512, 128, 128]) | 2883 | 2809 | 11000
shape torch.Size([1024, 128, 128]) | 5421 | 5430 | 22360
shape torch.Size([1, 256, 256]) | 1310 | 1109 | 1556
shape torch.Size([2, 256, 256]) | 1360 | 1150 | 1560
shape torch.Size([4, 256, 256]) | 1390 | 1188 | 1569
shape torch.Size([8, 256, 256]) | 1440 | 1250 | 1604
shape torch.Size([16, 256, 256]) | 1550 | 1390 | 1850
shape torch.Size([32, 256, 256]) | 1750 | 1620 | 3332
shape torch.Size([64, 256, 256]) | 2327 | 2246 | 6700
shape torch.Size([128, 256, 256]) | 3697 | 3638 | 19100
shape torch.Size([512, 256, 256]) | 12530 | 12500 | 87300
shape torch.Size([1024, 256, 256]) | 24380 | 24420 | 176000
```
</details>
<details>
<summary>
Benchmark Results all algorithms up to `n=2048`
</summary>
```
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
| lu_factor_magma_batched | lu_factor_cusolver_batched | lu_factor_cusolver_looped | lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 51 | 30 | 27 | 1390
shape torch.Size([2, 1, 1]) | 42 | 20 | 26 | 2798
shape torch.Size([4, 1, 1]) | 42 | 20 | 42 | 5589
shape torch.Size([8, 1, 1]) | 42 | 20 | 72 | 11000
shape torch.Size([16, 1, 1]) | 42 | 20 | 132 | 22400
shape torch.Size([32, 1, 1]) | 42 | 20 | 253 | 44620
shape torch.Size([64, 1, 1]) | 42 | 20 | 496 | 89200
shape torch.Size([128, 1, 1]) | 42 | 20 | 980 | 180000
shape torch.Size([512, 1, 1]) | 43 | 20 | 3868 | 714100
shape torch.Size([1024, 1, 1]) | 44 | 20 | 7800 | 1430000
shape torch.Size([1, 2, 2]) | 43 | 21 | 19 | 1400
shape torch.Size([2, 2, 2]) | 42 | 21 | 27 | 2898
shape torch.Size([4, 2, 2]) | 43 | 21 | 42 | 5800
shape torch.Size([8, 2, 2]) | 43 | 21 | 73 | 11600
shape torch.Size([16, 2, 2]) | 43 | 21 | 133 | 23170
shape torch.Size([32, 2, 2]) | 43 | 21 | 254 | 46290
shape torch.Size([64, 2, 2]) | 43 | 21 | 500 | 94000
shape torch.Size([128, 2, 2]) | 43 | 21 | 980 | 190000
shape torch.Size([512, 2, 2]) | 44 | 21 | 3860 | 741900
shape torch.Size([1024, 2, 2]) | 44 | 21 | 7640 | 1484000
shape torch.Size([1, 8, 8]) | 45 | 21 | 19 | 1450
shape torch.Size([2, 8, 8]) | 45 | 21 | 27 | 2917
shape torch.Size([4, 8, 8]) | 45 | 21 | 53 | 5800
shape torch.Size([8, 8, 8]) | 45 | 21 | 105 | 11580
shape torch.Size([16, 8, 8]) | 45 | 21 | 207 | 23160
shape torch.Size([32, 8, 8]) | 46 | 21 | 413 | 46400
shape torch.Size([64, 8, 8]) | 46 | 21 | 824 | 93000
shape torch.Size([128, 8, 8]) | 46 | 21 | 1645 | 185000
shape torch.Size([512, 8, 8]) | 47 | 21 | 6574 | 742000
shape torch.Size([1024, 8, 8]) | 49 | 21 | 13150 | 1481000
shape torch.Size([1, 16, 16]) | 49 | 21 | 24 | 1460
shape torch.Size([2, 16, 16]) | 49 | 21 | 46 | 2902
shape torch.Size([4, 16, 16]) | 49 | 21 | 90 | 5800
shape torch.Size([8, 16, 16]) | 49 | 21 | 177 | 11600
shape torch.Size([16, 16, 16]) | 49 | 21 | 352 | 23150
shape torch.Size([32, 16, 16]) | 49 | 21 | 703 | 46300
shape torch.Size([64, 16, 16]) | 49 | 21 | 1404 | 92700
shape torch.Size([128, 16, 16]) | 50 | 21 | 2807 | 185000
shape torch.Size([512, 16, 16]) | 55 | 29 | 11220 | 741700
shape torch.Size([1024, 16, 16]) | 64 | 42 | 22440 | 1480000
shape torch.Size([1, 32, 32]) | 55 | 56 | 58 | 1460
shape torch.Size([2, 32, 32]) | 55 | 57 | 114 | 2920
shape torch.Size([4, 32, 32]) | 55 | 57 | 225 | 5830
shape torch.Size([8, 32, 32]) | 55 | 61 | 449 | 11700
shape torch.Size([16, 32, 32]) | 56 | 61 | 896 | 23300
shape torch.Size([32, 32, 32]) | 56 | 62 | 1791 | 46600
shape torch.Size([64, 32, 32]) | 56 | 63 | 3581 | 93100
shape torch.Size([128, 32, 32]) | 58 | 66 | 7156 | 186000
shape torch.Size([512, 32, 32]) | 100 | 194 | 28700 | 742400
shape torch.Size([1024, 32, 32]) | 169 | 335 | 57620 | 1485000
shape torch.Size([1, 64, 64]) | 224 | 101 | 132 | 1500
shape torch.Size([2, 64, 64]) | 227 | 100 | 262 | 2951
shape torch.Size([4, 64, 64]) | 229 | 101 | 523 | 5890
shape torch.Size([8, 64, 64]) | 231 | 102 | 1040 | 12000
shape torch.Size([16, 64, 64]) | 237 | 109 | 2088 | 23530
shape torch.Size([32, 64, 64]) | 242 | 127 | 4171 | 46900
shape torch.Size([64, 64, 64]) | 247 | 156 | 8330 | 95000
shape torch.Size([128, 64, 64]) | 293 | 244 | 16710 | 189000
shape torch.Size([512, 64, 64]) | 685 | 1180 | 67000 | 750900
shape torch.Size([1024, 64, 64]) | 1300 | 2076 | 134000 | 1505000
shape torch.Size([1, 128, 128]) | 490 | 309 | 298 | 1560
shape torch.Size([2, 128, 128]) | 503 | 309 | 594 | 3120
shape torch.Size([4, 128, 128]) | 515 | 312 | 1185 | 6230
shape torch.Size([8, 128, 128]) | 523 | 317 | 2370 | 12500
shape torch.Size([16, 128, 128]) | 547 | 336 | 4734 | 24890
shape torch.Size([32, 128, 128]) | 596 | 472 | 9491 | 49800
shape torch.Size([64, 128, 128]) | 700 | 741 | 19000 | 100000
shape torch.Size([128, 128, 128]) | 930 | 1770 | 37990 | 199000
shape torch.Size([512, 128, 128]) | 2810 | 11000 | 152000 | 797100
shape torch.Size([1024, 128, 128]) | 5430 | 22430 | 303900 | 1595000
shape torch.Size([1, 256, 256]) | 1120 | 1580 | 666 | 1890
shape torch.Size([2, 256, 256]) | 1160 | 1574 | 1330 | 3784
shape torch.Size([4, 256, 256]) | 1190 | 1580 | 2658 | 7570
shape torch.Size([8, 256, 256]) | 1250 | 1613 | 5325 | 15100
shape torch.Size([16, 256, 256]) | 1394 | 1880 | 10700 | 30260
shape torch.Size([32, 256, 256]) | 1633 | 3360 | 21300 | 61000
shape torch.Size([64, 256, 256]) | 2258 | 6730 | 42600 | 120000
shape torch.Size([128, 256, 256]) | 3639 | 19200 | 85170 | 242200
shape torch.Size([512, 256, 256]) | 12600 | 87200 | 340600 | 969000
shape torch.Size([1024, 256, 256]) | 24530 | 176000 | 681300 | 1943000
shape torch.Size([1, 512, 512]) | 2557 | 9117 | 1724 | 2577
shape torch.Size([2, 512, 512]) | 2691 | 9209 | 3464 | 5200
shape torch.Size([4, 512, 512]) | 2853 | 9860 | 6940 | 10000
shape torch.Size([8, 512, 512]) | 3153 | 11000 | 13900 | 20570
shape torch.Size([16, 512, 512]) | 3765 | 13000 | 27720 | 41360
shape torch.Size([32, 512, 512]) | 5500 | 21400 | 55420 | 82000
shape torch.Size([64, 512, 512]) | 8790 | 44000 | 111000 | 165000
shape torch.Size([128, 512, 512]) | 15300 | 98000 | 221700 | 329800
shape torch.Size([512, 512, 512]) | 55400 | 424100 | 886600 | 1325000
shape torch.Size([1024, 512, 512]) | 110000 | 856200 | 1773000 | 2691000
shape torch.Size([1, 1024, 1024]) | 10350 | 69290 | 5020 | 5327
shape torch.Size([2, 1024, 1024]) | 11200 | 74860 | 10040 | 11000
shape torch.Size([4, 1024, 1024]) | 12200 | 78030 | 20080 | 21290
shape torch.Size([8, 1024, 1024]) | 14000 | 81200 | 40160 | 42850
shape torch.Size([16, 1024, 1024]) | 17700 | 96000 | 80300 | 85500
shape torch.Size([32, 1024, 1024]) | 27740 | 150000 | 160700 | 171000
shape torch.Size([64, 1024, 1024]) | 45940 | 233400 | 321200 | 344100
shape torch.Size([1, 2048, 2048]) | 29860 | 579800 | 12920 | 13500
shape torch.Size([2, 2048, 2048]) | 34000 | 585000 | 25840 | 26840
shape torch.Size([4, 2048, 2048]) | 39770 | 593900 | 51670 | 54000
shape torch.Size([8, 2048, 2048]) | 51720 | 632100 | 103000 | 109000
shape torch.Size([16, 2048, 2048]) | 76900 | 845500 | 206600 | 218400
shape torch.Size([32, 2048, 2048]) | 130000 | 1058000 | 413900 | 437300
Times are in microseconds (us).
```
</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`.
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare
benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")
for n, batch in itertools.product(shapes, batches):
if n == 1024 and batch[0] >= 128:
continue
if n == 2048 and batch[0] >= 64:
continue
A = make_arg(batch + (n, n))
print(A.shape)
stmt = "torch.linalg.lu_factor_ex(A)"
timer = Timer(stmt,
globals=globals(),
label=benchmark_name,
description=label,
sub_label=f"shape {A.shape}",
num_threads=1)
results.append(timer.blocked_autorange())
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
with open(f"{label}.pickle", 'wb') as f:
pickle.dump(results, f)
```
</details>
See #72935 (comment) for the script to join the results.
[ghstack-poisoned]
This PR adds getrf_cublas to the functions considered in the heuristics
for `lu_factor`. It also updates the heuristics of the function.
## Benchmark
I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.
<details>
<summary>
Benchmark Results
</summary>
```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]
| lu_factor_heuristic | lu_factor_magma_batched | lu_factor_cusolver_batched
1 threads: ----------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 26 | 47 | 26
shape torch.Size([2, 1, 1]) | 17 | 38 | 17
shape torch.Size([4, 1, 1]) | 17 | 38 | 17
shape torch.Size([8, 1, 1]) | 20 | 38 | 18
shape torch.Size([16, 1, 1]) | 20 | 38 | 17
shape torch.Size([32, 1, 1]) | 18 | 38 | 17
shape torch.Size([64, 1, 1]) | 18 | 39 | 17
shape torch.Size([128, 1, 1]) | 17 | 38 | 17
shape torch.Size([512, 1, 1]) | 18 | 39 | 18
shape torch.Size([1024, 1, 1]) | 18 | 40 | 18
shape torch.Size([1, 2, 2]) | 18 | 38 | 17
shape torch.Size([2, 2, 2]) | 17 | 37 | 17
shape torch.Size([4, 2, 2]) | 17 | 38 | 17
shape torch.Size([8, 2, 2]) | 17 | 38 | 17
shape torch.Size([16, 2, 2]) | 17 | 38 | 17
shape torch.Size([32, 2, 2]) | 17 | 38 | 17
shape torch.Size([64, 2, 2]) | 17 | 38 | 17
shape torch.Size([128, 2, 2]) | 17 | 38 | 17
shape torch.Size([512, 2, 2]) | 17 | 39 | 17
shape torch.Size([1024, 2, 2]) | 17 | 40 | 17
shape torch.Size([1, 8, 8]) | 17 | 40 | 17
shape torch.Size([2, 8, 8]) | 17 | 40 | 17
shape torch.Size([4, 8, 8]) | 17 | 40 | 17
shape torch.Size([8, 8, 8]) | 17 | 40 | 17
shape torch.Size([16, 8, 8]) | 17 | 41 | 17
shape torch.Size([32, 8, 8]) | 17 | 40 | 17
shape torch.Size([64, 8, 8]) | 17 | 40 | 17
shape torch.Size([128, 8, 8]) | 17 | 40 | 17
shape torch.Size([512, 8, 8]) | 17 | 42 | 17
shape torch.Size([1024, 8, 8]) | 17 | 44 | 17
shape torch.Size([1, 16, 16]) | 24 | 44 | 18
shape torch.Size([2, 16, 16]) | 18 | 44 | 18
shape torch.Size([4, 16, 16]) | 18 | 45 | 18
shape torch.Size([8, 16, 16]) | 19 | 44 | 19
shape torch.Size([16, 16, 16]) | 20 | 44 | 20
shape torch.Size([32, 16, 16]) | 20 | 45 | 20
shape torch.Size([64, 16, 16]) | 20 | 44 | 20
shape torch.Size([128, 16, 16]) | 20 | 45 | 20
shape torch.Size([512, 16, 16]) | 28 | 50 | 28
shape torch.Size([1024, 16, 16]) | 41 | 59 | 41
shape torch.Size([1, 32, 32]) | 58 | 50 | 56
shape torch.Size([2, 32, 32]) | 56 | 50 | 56
shape torch.Size([4, 32, 32]) | 56 | 50 | 57
shape torch.Size([8, 32, 32]) | 60 | 50 | 60
shape torch.Size([16, 32, 32]) | 60 | 51 | 60
shape torch.Size([32, 32, 32]) | 247 | 51 | 61
shape torch.Size([64, 32, 32]) | 233 | 51 | 63
shape torch.Size([128, 32, 32]) | 236 | 53 | 66
shape torch.Size([512, 32, 32]) | 268 | 97 | 193
shape torch.Size([1024, 32, 32]) | 317 | 167 | 333
shape torch.Size([1, 64, 64]) | 131 | 216 | 99
shape torch.Size([2, 64, 64]) | 99 | 220 | 99
shape torch.Size([4, 64, 64]) | 99 | 225 | 101
shape torch.Size([8, 64, 64]) | 101 | 225 | 102
shape torch.Size([16, 64, 64]) | 107 | 230 | 108
shape torch.Size([32, 64, 64]) | 440 | 235 | 126
shape torch.Size([64, 64, 64]) | 447 | 240 | 155
shape torch.Size([128, 64, 64]) | 470 | 289 | 240
shape torch.Size([512, 64, 64]) | 793 | 678 | 1180
shape torch.Size([1024, 64, 64]) | 1000 | 1300 | 2112
shape torch.Size([1, 128, 128]) | 296 | 482 | 309
shape torch.Size([2, 128, 128]) | 308 | 499 | 307
shape torch.Size([4, 128, 128]) | 311 | 510 | 310
shape torch.Size([8, 128, 128]) | 314 | 522 | 314
shape torch.Size([16, 128, 128]) | 334 | 541 | 334
shape torch.Size([32, 128, 128]) | 770 | 591 | 467
shape torch.Size([64, 128, 128]) | 860 | 694 | 733
shape torch.Size([128, 128, 128]) | 1040 | 925 | 1980
shape torch.Size([512, 128, 128]) | 2883 | 2809 | 11000
shape torch.Size([1024, 128, 128]) | 5421 | 5430 | 22360
shape torch.Size([1, 256, 256]) | 1310 | 1109 | 1556
shape torch.Size([2, 256, 256]) | 1360 | 1150 | 1560
shape torch.Size([4, 256, 256]) | 1390 | 1188 | 1569
shape torch.Size([8, 256, 256]) | 1440 | 1250 | 1604
shape torch.Size([16, 256, 256]) | 1550 | 1390 | 1850
shape torch.Size([32, 256, 256]) | 1750 | 1620 | 3332
shape torch.Size([64, 256, 256]) | 2327 | 2246 | 6700
shape torch.Size([128, 256, 256]) | 3697 | 3638 | 19100
shape torch.Size([512, 256, 256]) | 12530 | 12500 | 87300
shape torch.Size([1024, 256, 256]) | 24380 | 24420 | 176000
```
</details>
<details>
<summary>
Benchmark Results all algorithms up to `n=2048`
</summary>
```
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
| lu_factor_magma_batched | lu_factor_cusolver_batched | lu_factor_cusolver_looped | lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 51 | 30 | 27 | 1390
shape torch.Size([2, 1, 1]) | 42 | 20 | 26 | 2798
shape torch.Size([4, 1, 1]) | 42 | 20 | 42 | 5589
shape torch.Size([8, 1, 1]) | 42 | 20 | 72 | 11000
shape torch.Size([16, 1, 1]) | 42 | 20 | 132 | 22400
shape torch.Size([32, 1, 1]) | 42 | 20 | 253 | 44620
shape torch.Size([64, 1, 1]) | 42 | 20 | 496 | 89200
shape torch.Size([128, 1, 1]) | 42 | 20 | 980 | 180000
shape torch.Size([512, 1, 1]) | 43 | 20 | 3868 | 714100
shape torch.Size([1024, 1, 1]) | 44 | 20 | 7800 | 1430000
shape torch.Size([1, 2, 2]) | 43 | 21 | 19 | 1400
shape torch.Size([2, 2, 2]) | 42 | 21 | 27 | 2898
shape torch.Size([4, 2, 2]) | 43 | 21 | 42 | 5800
shape torch.Size([8, 2, 2]) | 43 | 21 | 73 | 11600
shape torch.Size([16, 2, 2]) | 43 | 21 | 133 | 23170
shape torch.Size([32, 2, 2]) | 43 | 21 | 254 | 46290
shape torch.Size([64, 2, 2]) | 43 | 21 | 500 | 94000
shape torch.Size([128, 2, 2]) | 43 | 21 | 980 | 190000
shape torch.Size([512, 2, 2]) | 44 | 21 | 3860 | 741900
shape torch.Size([1024, 2, 2]) | 44 | 21 | 7640 | 1484000
shape torch.Size([1, 8, 8]) | 45 | 21 | 19 | 1450
shape torch.Size([2, 8, 8]) | 45 | 21 | 27 | 2917
shape torch.Size([4, 8, 8]) | 45 | 21 | 53 | 5800
shape torch.Size([8, 8, 8]) | 45 | 21 | 105 | 11580
shape torch.Size([16, 8, 8]) | 45 | 21 | 207 | 23160
shape torch.Size([32, 8, 8]) | 46 | 21 | 413 | 46400
shape torch.Size([64, 8, 8]) | 46 | 21 | 824 | 93000
shape torch.Size([128, 8, 8]) | 46 | 21 | 1645 | 185000
shape torch.Size([512, 8, 8]) | 47 | 21 | 6574 | 742000
shape torch.Size([1024, 8, 8]) | 49 | 21 | 13150 | 1481000
shape torch.Size([1, 16, 16]) | 49 | 21 | 24 | 1460
shape torch.Size([2, 16, 16]) | 49 | 21 | 46 | 2902
shape torch.Size([4, 16, 16]) | 49 | 21 | 90 | 5800
shape torch.Size([8, 16, 16]) | 49 | 21 | 177 | 11600
shape torch.Size([16, 16, 16]) | 49 | 21 | 352 | 23150
shape torch.Size([32, 16, 16]) | 49 | 21 | 703 | 46300
shape torch.Size([64, 16, 16]) | 49 | 21 | 1404 | 92700
shape torch.Size([128, 16, 16]) | 50 | 21 | 2807 | 185000
shape torch.Size([512, 16, 16]) | 55 | 29 | 11220 | 741700
shape torch.Size([1024, 16, 16]) | 64 | 42 | 22440 | 1480000
shape torch.Size([1, 32, 32]) | 55 | 56 | 58 | 1460
shape torch.Size([2, 32, 32]) | 55 | 57 | 114 | 2920
shape torch.Size([4, 32, 32]) | 55 | 57 | 225 | 5830
shape torch.Size([8, 32, 32]) | 55 | 61 | 449 | 11700
shape torch.Size([16, 32, 32]) | 56 | 61 | 896 | 23300
shape torch.Size([32, 32, 32]) | 56 | 62 | 1791 | 46600
shape torch.Size([64, 32, 32]) | 56 | 63 | 3581 | 93100
shape torch.Size([128, 32, 32]) | 58 | 66 | 7156 | 186000
shape torch.Size([512, 32, 32]) | 100 | 194 | 28700 | 742400
shape torch.Size([1024, 32, 32]) | 169 | 335 | 57620 | 1485000
shape torch.Size([1, 64, 64]) | 224 | 101 | 132 | 1500
shape torch.Size([2, 64, 64]) | 227 | 100 | 262 | 2951
shape torch.Size([4, 64, 64]) | 229 | 101 | 523 | 5890
shape torch.Size([8, 64, 64]) | 231 | 102 | 1040 | 12000
shape torch.Size([16, 64, 64]) | 237 | 109 | 2088 | 23530
shape torch.Size([32, 64, 64]) | 242 | 127 | 4171 | 46900
shape torch.Size([64, 64, 64]) | 247 | 156 | 8330 | 95000
shape torch.Size([128, 64, 64]) | 293 | 244 | 16710 | 189000
shape torch.Size([512, 64, 64]) | 685 | 1180 | 67000 | 750900
shape torch.Size([1024, 64, 64]) | 1300 | 2076 | 134000 | 1505000
shape torch.Size([1, 128, 128]) | 490 | 309 | 298 | 1560
shape torch.Size([2, 128, 128]) | 503 | 309 | 594 | 3120
shape torch.Size([4, 128, 128]) | 515 | 312 | 1185 | 6230
shape torch.Size([8, 128, 128]) | 523 | 317 | 2370 | 12500
shape torch.Size([16, 128, 128]) | 547 | 336 | 4734 | 24890
shape torch.Size([32, 128, 128]) | 596 | 472 | 9491 | 49800
shape torch.Size([64, 128, 128]) | 700 | 741 | 19000 | 100000
shape torch.Size([128, 128, 128]) | 930 | 1770 | 37990 | 199000
shape torch.Size([512, 128, 128]) | 2810 | 11000 | 152000 | 797100
shape torch.Size([1024, 128, 128]) | 5430 | 22430 | 303900 | 1595000
shape torch.Size([1, 256, 256]) | 1120 | 1580 | 666 | 1890
shape torch.Size([2, 256, 256]) | 1160 | 1574 | 1330 | 3784
shape torch.Size([4, 256, 256]) | 1190 | 1580 | 2658 | 7570
shape torch.Size([8, 256, 256]) | 1250 | 1613 | 5325 | 15100
shape torch.Size([16, 256, 256]) | 1394 | 1880 | 10700 | 30260
shape torch.Size([32, 256, 256]) | 1633 | 3360 | 21300 | 61000
shape torch.Size([64, 256, 256]) | 2258 | 6730 | 42600 | 120000
shape torch.Size([128, 256, 256]) | 3639 | 19200 | 85170 | 242200
shape torch.Size([512, 256, 256]) | 12600 | 87200 | 340600 | 969000
shape torch.Size([1024, 256, 256]) | 24530 | 176000 | 681300 | 1943000
shape torch.Size([1, 512, 512]) | 2557 | 9117 | 1724 | 2577
shape torch.Size([2, 512, 512]) | 2691 | 9209 | 3464 | 5200
shape torch.Size([4, 512, 512]) | 2853 | 9860 | 6940 | 10000
shape torch.Size([8, 512, 512]) | 3153 | 11000 | 13900 | 20570
shape torch.Size([16, 512, 512]) | 3765 | 13000 | 27720 | 41360
shape torch.Size([32, 512, 512]) | 5500 | 21400 | 55420 | 82000
shape torch.Size([64, 512, 512]) | 8790 | 44000 | 111000 | 165000
shape torch.Size([128, 512, 512]) | 15300 | 98000 | 221700 | 329800
shape torch.Size([512, 512, 512]) | 55400 | 424100 | 886600 | 1325000
shape torch.Size([1024, 512, 512]) | 110000 | 856200 | 1773000 | 2691000
shape torch.Size([1, 1024, 1024]) | 10350 | 69290 | 5020 | 5327
shape torch.Size([2, 1024, 1024]) | 11200 | 74860 | 10040 | 11000
shape torch.Size([4, 1024, 1024]) | 12200 | 78030 | 20080 | 21290
shape torch.Size([8, 1024, 1024]) | 14000 | 81200 | 40160 | 42850
shape torch.Size([16, 1024, 1024]) | 17700 | 96000 | 80300 | 85500
shape torch.Size([32, 1024, 1024]) | 27740 | 150000 | 160700 | 171000
shape torch.Size([64, 1024, 1024]) | 45940 | 233400 | 321200 | 344100
shape torch.Size([1, 2048, 2048]) | 29860 | 579800 | 12920 | 13500
shape torch.Size([2, 2048, 2048]) | 34000 | 585000 | 25840 | 26840
shape torch.Size([4, 2048, 2048]) | 39770 | 593900 | 51670 | 54000
shape torch.Size([8, 2048, 2048]) | 51720 | 632100 | 103000 | 109000
shape torch.Size([16, 2048, 2048]) | 76900 | 845500 | 206600 | 218400
shape torch.Size([32, 2048, 2048]) | 130000 | 1058000 | 413900 | 437300
Times are in microseconds (us).
```
</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`.
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare
benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")
for n, batch in itertools.product(shapes, batches):
if n == 1024 and batch[0] >= 128:
continue
if n == 2048 and batch[0] >= 64:
continue
A = make_arg(batch + (n, n))
print(A.shape)
stmt = "torch.linalg.lu_factor_ex(A)"
timer = Timer(stmt,
globals=globals(),
label=benchmark_name,
description=label,
sub_label=f"shape {A.shape}",
num_threads=1)
results.append(timer.blocked_autorange())
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
with open(f"{label}.pickle", 'wb') as f:
pickle.dump(results, f)
```
</details>
See #72935 (comment) for the script to join the results.
[ghstack-poisoned]
This PR adds getrf_cublas to the functions considered in the heuristics
for `lu_factor`. It also updates the heuristics of the function.
## Benchmark
I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.
<details>
<summary>
Benchmark Results
</summary>
```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]
| lu_factor_heuristic | lu_factor_magma_batched | lu_factor_cusolver_batched
1 threads: ----------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 26 | 47 | 26
shape torch.Size([2, 1, 1]) | 17 | 38 | 17
shape torch.Size([4, 1, 1]) | 17 | 38 | 17
shape torch.Size([8, 1, 1]) | 20 | 38 | 18
shape torch.Size([16, 1, 1]) | 20 | 38 | 17
shape torch.Size([32, 1, 1]) | 18 | 38 | 17
shape torch.Size([64, 1, 1]) | 18 | 39 | 17
shape torch.Size([128, 1, 1]) | 17 | 38 | 17
shape torch.Size([512, 1, 1]) | 18 | 39 | 18
shape torch.Size([1024, 1, 1]) | 18 | 40 | 18
shape torch.Size([1, 2, 2]) | 18 | 38 | 17
shape torch.Size([2, 2, 2]) | 17 | 37 | 17
shape torch.Size([4, 2, 2]) | 17 | 38 | 17
shape torch.Size([8, 2, 2]) | 17 | 38 | 17
shape torch.Size([16, 2, 2]) | 17 | 38 | 17
shape torch.Size([32, 2, 2]) | 17 | 38 | 17
shape torch.Size([64, 2, 2]) | 17 | 38 | 17
shape torch.Size([128, 2, 2]) | 17 | 38 | 17
shape torch.Size([512, 2, 2]) | 17 | 39 | 17
shape torch.Size([1024, 2, 2]) | 17 | 40 | 17
shape torch.Size([1, 8, 8]) | 17 | 40 | 17
shape torch.Size([2, 8, 8]) | 17 | 40 | 17
shape torch.Size([4, 8, 8]) | 17 | 40 | 17
shape torch.Size([8, 8, 8]) | 17 | 40 | 17
shape torch.Size([16, 8, 8]) | 17 | 41 | 17
shape torch.Size([32, 8, 8]) | 17 | 40 | 17
shape torch.Size([64, 8, 8]) | 17 | 40 | 17
shape torch.Size([128, 8, 8]) | 17 | 40 | 17
shape torch.Size([512, 8, 8]) | 17 | 42 | 17
shape torch.Size([1024, 8, 8]) | 17 | 44 | 17
shape torch.Size([1, 16, 16]) | 24 | 44 | 18
shape torch.Size([2, 16, 16]) | 18 | 44 | 18
shape torch.Size([4, 16, 16]) | 18 | 45 | 18
shape torch.Size([8, 16, 16]) | 19 | 44 | 19
shape torch.Size([16, 16, 16]) | 20 | 44 | 20
shape torch.Size([32, 16, 16]) | 20 | 45 | 20
shape torch.Size([64, 16, 16]) | 20 | 44 | 20
shape torch.Size([128, 16, 16]) | 20 | 45 | 20
shape torch.Size([512, 16, 16]) | 28 | 50 | 28
shape torch.Size([1024, 16, 16]) | 41 | 59 | 41
shape torch.Size([1, 32, 32]) | 58 | 50 | 56
shape torch.Size([2, 32, 32]) | 56 | 50 | 56
shape torch.Size([4, 32, 32]) | 56 | 50 | 57
shape torch.Size([8, 32, 32]) | 60 | 50 | 60
shape torch.Size([16, 32, 32]) | 60 | 51 | 60
shape torch.Size([32, 32, 32]) | 247 | 51 | 61
shape torch.Size([64, 32, 32]) | 233 | 51 | 63
shape torch.Size([128, 32, 32]) | 236 | 53 | 66
shape torch.Size([512, 32, 32]) | 268 | 97 | 193
shape torch.Size([1024, 32, 32]) | 317 | 167 | 333
shape torch.Size([1, 64, 64]) | 131 | 216 | 99
shape torch.Size([2, 64, 64]) | 99 | 220 | 99
shape torch.Size([4, 64, 64]) | 99 | 225 | 101
shape torch.Size([8, 64, 64]) | 101 | 225 | 102
shape torch.Size([16, 64, 64]) | 107 | 230 | 108
shape torch.Size([32, 64, 64]) | 440 | 235 | 126
shape torch.Size([64, 64, 64]) | 447 | 240 | 155
shape torch.Size([128, 64, 64]) | 470 | 289 | 240
shape torch.Size([512, 64, 64]) | 793 | 678 | 1180
shape torch.Size([1024, 64, 64]) | 1000 | 1300 | 2112
shape torch.Size([1, 128, 128]) | 296 | 482 | 309
shape torch.Size([2, 128, 128]) | 308 | 499 | 307
shape torch.Size([4, 128, 128]) | 311 | 510 | 310
shape torch.Size([8, 128, 128]) | 314 | 522 | 314
shape torch.Size([16, 128, 128]) | 334 | 541 | 334
shape torch.Size([32, 128, 128]) | 770 | 591 | 467
shape torch.Size([64, 128, 128]) | 860 | 694 | 733
shape torch.Size([128, 128, 128]) | 1040 | 925 | 1980
shape torch.Size([512, 128, 128]) | 2883 | 2809 | 11000
shape torch.Size([1024, 128, 128]) | 5421 | 5430 | 22360
shape torch.Size([1, 256, 256]) | 1310 | 1109 | 1556
shape torch.Size([2, 256, 256]) | 1360 | 1150 | 1560
shape torch.Size([4, 256, 256]) | 1390 | 1188 | 1569
shape torch.Size([8, 256, 256]) | 1440 | 1250 | 1604
shape torch.Size([16, 256, 256]) | 1550 | 1390 | 1850
shape torch.Size([32, 256, 256]) | 1750 | 1620 | 3332
shape torch.Size([64, 256, 256]) | 2327 | 2246 | 6700
shape torch.Size([128, 256, 256]) | 3697 | 3638 | 19100
shape torch.Size([512, 256, 256]) | 12530 | 12500 | 87300
shape torch.Size([1024, 256, 256]) | 24380 | 24420 | 176000
```
</details>
<details>
<summary>
Benchmark Results all algorithms up to `n=2048`
</summary>
```
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
| lu_factor_magma_batched | lu_factor_cusolver_batched | lu_factor_cusolver_looped | lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 51 | 30 | 27 | 1390
shape torch.Size([2, 1, 1]) | 42 | 20 | 26 | 2798
shape torch.Size([4, 1, 1]) | 42 | 20 | 42 | 5589
shape torch.Size([8, 1, 1]) | 42 | 20 | 72 | 11000
shape torch.Size([16, 1, 1]) | 42 | 20 | 132 | 22400
shape torch.Size([32, 1, 1]) | 42 | 20 | 253 | 44620
shape torch.Size([64, 1, 1]) | 42 | 20 | 496 | 89200
shape torch.Size([128, 1, 1]) | 42 | 20 | 980 | 180000
shape torch.Size([512, 1, 1]) | 43 | 20 | 3868 | 714100
shape torch.Size([1024, 1, 1]) | 44 | 20 | 7800 | 1430000
shape torch.Size([1, 2, 2]) | 43 | 21 | 19 | 1400
shape torch.Size([2, 2, 2]) | 42 | 21 | 27 | 2898
shape torch.Size([4, 2, 2]) | 43 | 21 | 42 | 5800
shape torch.Size([8, 2, 2]) | 43 | 21 | 73 | 11600
shape torch.Size([16, 2, 2]) | 43 | 21 | 133 | 23170
shape torch.Size([32, 2, 2]) | 43 | 21 | 254 | 46290
shape torch.Size([64, 2, 2]) | 43 | 21 | 500 | 94000
shape torch.Size([128, 2, 2]) | 43 | 21 | 980 | 190000
shape torch.Size([512, 2, 2]) | 44 | 21 | 3860 | 741900
shape torch.Size([1024, 2, 2]) | 44 | 21 | 7640 | 1484000
shape torch.Size([1, 8, 8]) | 45 | 21 | 19 | 1450
shape torch.Size([2, 8, 8]) | 45 | 21 | 27 | 2917
shape torch.Size([4, 8, 8]) | 45 | 21 | 53 | 5800
shape torch.Size([8, 8, 8]) | 45 | 21 | 105 | 11580
shape torch.Size([16, 8, 8]) | 45 | 21 | 207 | 23160
shape torch.Size([32, 8, 8]) | 46 | 21 | 413 | 46400
shape torch.Size([64, 8, 8]) | 46 | 21 | 824 | 93000
shape torch.Size([128, 8, 8]) | 46 | 21 | 1645 | 185000
shape torch.Size([512, 8, 8]) | 47 | 21 | 6574 | 742000
shape torch.Size([1024, 8, 8]) | 49 | 21 | 13150 | 1481000
shape torch.Size([1, 16, 16]) | 49 | 21 | 24 | 1460
shape torch.Size([2, 16, 16]) | 49 | 21 | 46 | 2902
shape torch.Size([4, 16, 16]) | 49 | 21 | 90 | 5800
shape torch.Size([8, 16, 16]) | 49 | 21 | 177 | 11600
shape torch.Size([16, 16, 16]) | 49 | 21 | 352 | 23150
shape torch.Size([32, 16, 16]) | 49 | 21 | 703 | 46300
shape torch.Size([64, 16, 16]) | 49 | 21 | 1404 | 92700
shape torch.Size([128, 16, 16]) | 50 | 21 | 2807 | 185000
shape torch.Size([512, 16, 16]) | 55 | 29 | 11220 | 741700
shape torch.Size([1024, 16, 16]) | 64 | 42 | 22440 | 1480000
shape torch.Size([1, 32, 32]) | 55 | 56 | 58 | 1460
shape torch.Size([2, 32, 32]) | 55 | 57 | 114 | 2920
shape torch.Size([4, 32, 32]) | 55 | 57 | 225 | 5830
shape torch.Size([8, 32, 32]) | 55 | 61 | 449 | 11700
shape torch.Size([16, 32, 32]) | 56 | 61 | 896 | 23300
shape torch.Size([32, 32, 32]) | 56 | 62 | 1791 | 46600
shape torch.Size([64, 32, 32]) | 56 | 63 | 3581 | 93100
shape torch.Size([128, 32, 32]) | 58 | 66 | 7156 | 186000
shape torch.Size([512, 32, 32]) | 100 | 194 | 28700 | 742400
shape torch.Size([1024, 32, 32]) | 169 | 335 | 57620 | 1485000
shape torch.Size([1, 64, 64]) | 224 | 101 | 132 | 1500
shape torch.Size([2, 64, 64]) | 227 | 100 | 262 | 2951
shape torch.Size([4, 64, 64]) | 229 | 101 | 523 | 5890
shape torch.Size([8, 64, 64]) | 231 | 102 | 1040 | 12000
shape torch.Size([16, 64, 64]) | 237 | 109 | 2088 | 23530
shape torch.Size([32, 64, 64]) | 242 | 127 | 4171 | 46900
shape torch.Size([64, 64, 64]) | 247 | 156 | 8330 | 95000
shape torch.Size([128, 64, 64]) | 293 | 244 | 16710 | 189000
shape torch.Size([512, 64, 64]) | 685 | 1180 | 67000 | 750900
shape torch.Size([1024, 64, 64]) | 1300 | 2076 | 134000 | 1505000
shape torch.Size([1, 128, 128]) | 490 | 309 | 298 | 1560
shape torch.Size([2, 128, 128]) | 503 | 309 | 594 | 3120
shape torch.Size([4, 128, 128]) | 515 | 312 | 1185 | 6230
shape torch.Size([8, 128, 128]) | 523 | 317 | 2370 | 12500
shape torch.Size([16, 128, 128]) | 547 | 336 | 4734 | 24890
shape torch.Size([32, 128, 128]) | 596 | 472 | 9491 | 49800
shape torch.Size([64, 128, 128]) | 700 | 741 | 19000 | 100000
shape torch.Size([128, 128, 128]) | 930 | 1770 | 37990 | 199000
shape torch.Size([512, 128, 128]) | 2810 | 11000 | 152000 | 797100
shape torch.Size([1024, 128, 128]) | 5430 | 22430 | 303900 | 1595000
shape torch.Size([1, 256, 256]) | 1120 | 1580 | 666 | 1890
shape torch.Size([2, 256, 256]) | 1160 | 1574 | 1330 | 3784
shape torch.Size([4, 256, 256]) | 1190 | 1580 | 2658 | 7570
shape torch.Size([8, 256, 256]) | 1250 | 1613 | 5325 | 15100
shape torch.Size([16, 256, 256]) | 1394 | 1880 | 10700 | 30260
shape torch.Size([32, 256, 256]) | 1633 | 3360 | 21300 | 61000
shape torch.Size([64, 256, 256]) | 2258 | 6730 | 42600 | 120000
shape torch.Size([128, 256, 256]) | 3639 | 19200 | 85170 | 242200
shape torch.Size([512, 256, 256]) | 12600 | 87200 | 340600 | 969000
shape torch.Size([1024, 256, 256]) | 24530 | 176000 | 681300 | 1943000
shape torch.Size([1, 512, 512]) | 2557 | 9117 | 1724 | 2577
shape torch.Size([2, 512, 512]) | 2691 | 9209 | 3464 | 5200
shape torch.Size([4, 512, 512]) | 2853 | 9860 | 6940 | 10000
shape torch.Size([8, 512, 512]) | 3153 | 11000 | 13900 | 20570
shape torch.Size([16, 512, 512]) | 3765 | 13000 | 27720 | 41360
shape torch.Size([32, 512, 512]) | 5500 | 21400 | 55420 | 82000
shape torch.Size([64, 512, 512]) | 8790 | 44000 | 111000 | 165000
shape torch.Size([128, 512, 512]) | 15300 | 98000 | 221700 | 329800
shape torch.Size([512, 512, 512]) | 55400 | 424100 | 886600 | 1325000
shape torch.Size([1024, 512, 512]) | 110000 | 856200 | 1773000 | 2691000
shape torch.Size([1, 1024, 1024]) | 10350 | 69290 | 5020 | 5327
shape torch.Size([2, 1024, 1024]) | 11200 | 74860 | 10040 | 11000
shape torch.Size([4, 1024, 1024]) | 12200 | 78030 | 20080 | 21290
shape torch.Size([8, 1024, 1024]) | 14000 | 81200 | 40160 | 42850
shape torch.Size([16, 1024, 1024]) | 17700 | 96000 | 80300 | 85500
shape torch.Size([32, 1024, 1024]) | 27740 | 150000 | 160700 | 171000
shape torch.Size([64, 1024, 1024]) | 45940 | 233400 | 321200 | 344100
shape torch.Size([1, 2048, 2048]) | 29860 | 579800 | 12920 | 13500
shape torch.Size([2, 2048, 2048]) | 34000 | 585000 | 25840 | 26840
shape torch.Size([4, 2048, 2048]) | 39770 | 593900 | 51670 | 54000
shape torch.Size([8, 2048, 2048]) | 51720 | 632100 | 103000 | 109000
shape torch.Size([16, 2048, 2048]) | 76900 | 845500 | 206600 | 218400
shape torch.Size([32, 2048, 2048]) | 130000 | 1058000 | 413900 | 437300
Times are in microseconds (us).
```
</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`.
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare
benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")
for n, batch in itertools.product(shapes, batches):
if n == 1024 and batch[0] >= 128:
continue
if n == 2048 and batch[0] >= 64:
continue
A = make_arg(batch + (n, n))
print(A.shape)
stmt = "torch.linalg.lu_factor_ex(A)"
timer = Timer(stmt,
globals=globals(),
label=benchmark_name,
description=label,
sub_label=f"shape {A.shape}",
num_threads=1)
results.append(timer.blocked_autorange())
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
with open(f"{label}.pickle", 'wb') as f:
pickle.dump(results, f)
```
</details>
See #72935 (comment) for the script to join the results.
@pytorchbot merge

@pytorchbot successfully started a merge job. Check the current status here.

Merge failed due to a failed command. Raised by https://github.com/pytorch/pytorch/actions/runs/2476746342
@pytorchbot merge

@pytorchbot successfully started a merge job. Check the current status here.

Hey @lezcano.
Summary: This PR adds getrf_cublas to the functions considered in the heuristics for lu_solve.

Pull Request resolved: #73878
Approved by: https://github.com/nikitaved, https://github.com/IvanYashchuk, https://github.com/mruberry

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/9fc2518a8a24e07fe3a9c049ece6dc5bfc6d8641

Reviewed By: osalpekar
Differential Revision: D37089127
Pulled By: osalpekar
fbshipit-source-id: b9346418f87bd0c41c325f992e0f5675a85828fc
This PR adds getrf_cublas to the functions considered in the heuristics for lu_solve.

Pull Request resolved: pytorch#73878
Approved by: https://github.com/nikitaved, https://github.com/IvanYashchuk, https://github.com/mruberry