Add linalg.lu_solve #72935
Conversation
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA when calling the batched MAGMA backend with trans=True. We work around it by solving the system via two triangular solves instead.

We also updated the heuristics for this function, as they were fairly outdated. We found that cuSolver is king, so luckily we do not need to rely on the buggy MAGMA backend for this function.

We added tests exercising this function left and right, as well as tests for the different backends. We also activated the tests on AMD, as those should work as well.
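For reference, here is a minimal usage sketch of the new function (arbitrary shapes, illustration only; with `adjoint=True` it solves `AᴴX = B` from the same factorization):

```python
import torch

A = torch.randn(2, 4, 4, dtype=torch.float64)   # batch of square systems
B = torch.randn(2, 4, 3, dtype=torch.float64)

LU, pivots = torch.linalg.lu_factor(A)          # factor once...
X = torch.linalg.lu_solve(LU, pivots, B)                   # ...solve A X = B
Xh = torch.linalg.lu_solve(LU, pivots, B, adjoint=True)    # ...and A^H X = B

assert torch.allclose(A @ X, B)
assert torch.allclose(A.mH @ Xh, B)
```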
### Benchmarking
<details>
<summary>
Benchmark Results (adjoint=False)
</summary>
```
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]
| lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 750 | 34 | 28 | 252 | 86 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 239 | 94 | 27
shape torch.Size([4, 1, 1]) | 2995 | 83 | 28 | 239 | 100 | 27
shape torch.Size([8, 1, 1]) | 6000 | 146 | 28 | 239 | 94 | 27
shape torch.Size([16, 1, 1]) | 11900 | 272 | 28 | 241 | 95 | 27
shape torch.Size([32, 1, 1]) | 23880 | 524 | 28 | 244 | 94 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 245 | 99 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2054 | 28 | 242 | 96 | 27
shape torch.Size([512, 1, 1]) | 381900 | 8100 | 28 | 250 | 94 | 27
shape torch.Size([1024, 1, 1]) | 763800 | 16200 | 28 | 257 | 95 | 27
shape torch.Size([1, 2, 2]) | 750 | 33 | 28 | 240 | 88 | 27
shape torch.Size([2, 2, 2]) | 1500 | 51 | 28 | 240 | 96 | 27
shape torch.Size([4, 2, 2]) | 2991 | 82 | 28 | 241 | 96 | 28
shape torch.Size([8, 2, 2]) | 6000 | 150 | 28 | 241 | 96 | 27
shape torch.Size([16, 2, 2]) | 12000 | 275 | 28 | 242 | 96 | 27
shape torch.Size([32, 2, 2]) | 23980 | 530 | 28 | 246 | 97 | 28
shape torch.Size([64, 2, 2]) | 48000 | 1000 | 28 | 244 | 96 | 27
shape torch.Size([128, 2, 2]) | 96000 | 2063 | 28 | 245 | 96 | 28
shape torch.Size([512, 2, 2]) | 382000 | 8300 | 28 | 257 | 97 | 28
shape torch.Size([1024, 2, 2]) | 764000 | 20000 | 28 | 271 | 97 | 28
shape torch.Size([1, 8, 8]) | 749 | 34 | 28 | 243 | 88 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 244 | 97 | 28
shape torch.Size([4, 8, 8]) | 2988 | 83 | 28 | 244 | 100 | 28
shape torch.Size([8, 8, 8]) | 5980 | 150 | 28 | 245 | 97 | 28
shape torch.Size([16, 8, 8]) | 12000 | 278 | 28 | 246 | 96 | 28
shape torch.Size([32, 8, 8]) | 23910 | 536 | 28 | 249 | 98 | 28
shape torch.Size([64, 8, 8]) | 47800 | 1100 | 28 | 247 | 96 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2075 | 28 | 248 | 96 | 28
shape torch.Size([512, 8, 8]) | 382100 | 8300 | 28 | 270 | 97 | 28
shape torch.Size([1024, 8, 8]) | 764100 | 16400 | 28 | 291 | 100 | 28
shape torch.Size([1, 16, 16]) | 750 | 33 | 28 | 248 | 88 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 250 | 97 | 28
shape torch.Size([4, 16, 16]) | 2996 | 83 | 28 | 250 | 100 | 28
shape torch.Size([8, 16, 16]) | 5980 | 147 | 28 | 251 | 97 | 28
shape torch.Size([16, 16, 16]) | 11900 | 274 | 28 | 251 | 97 | 28
shape torch.Size([32, 16, 16]) | 24040 | 527 | 28 | 252 | 97 | 28
shape torch.Size([64, 16, 16]) | 47800 | 1037 | 28 | 251 | 100 | 28
shape torch.Size([128, 16, 16]) | 95600 | 2044 | 28 | 252 | 98 | 28
shape torch.Size([512, 16, 16]) | 388200 | 8100 | 28 | 280 | 100 | 28
shape torch.Size([1024, 16, 16]) | 769700 | 16000 | 28 | 322 | 117 | 28
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 255 | 88 | 28
shape torch.Size([2, 32, 32]) | 1510 | 50 | 28 | 256 | 97 | 28
shape torch.Size([4, 32, 32]) | 3022 | 82 | 31 | 256 | 97 | 30
shape torch.Size([8, 32, 32]) | 6000 | 140 | 31 | 257 | 100 | 31
shape torch.Size([16, 32, 32]) | 12000 | 281 | 31 | 258 | 96 | 31
shape torch.Size([32, 32, 32]) | 24150 | 563 | 35 | 258 | 96 | 35
shape torch.Size([64, 32, 32]) | 48300 | 1119 | 36 | 258 | 97 | 36
shape torch.Size([128, 32, 32]) | 96500 | 2235 | 43 | 261 | 96 | 43
shape torch.Size([512, 32, 32]) | 383100 | 8930 | 82 | 317 | 191 | 82
shape torch.Size([1024, 32, 32]) | 766300 | 19200 | 122 | 400 | 312 | 122
shape torch.Size([1, 64, 64]) | 760 | 33 | 55 | 272 | 68 | 34
shape torch.Size([2, 64, 64]) | 1500 | 52 | 58 | 273 | 85 | 52
shape torch.Size([4, 64, 64]) | 3127 | 102 | 65 | 273 | 150 | 65
shape torch.Size([8, 64, 64]) | 6070 | 201 | 65 | 275 | 278 | 65
shape torch.Size([16, 64, 64]) | 12000 | 399 | 66 | 274 | 95 | 67
shape torch.Size([32, 64, 64]) | 23900 | 796 | 73 | 275 | 97 | 73
shape torch.Size([64, 64, 64]) | 48000 | 1594 | 75 | 283 | 123 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3177 | 96 | 292 | 176 | 96
shape torch.Size([512, 64, 64]) | 379300 | 13520 | 208 | 426 | 551 | 208
shape torch.Size([1024, 64, 64]) | 758700 | 27100 | 306 | 570 | 919 | 306
shape torch.Size([1, 128, 128]) | 750 | 42 | 115 | 306 | 90 | 42
shape torch.Size([2, 128, 128]) | 1500 | 82 | 122 | 307 | 164 | 83
shape torch.Size([4, 128, 128]) | 2966 | 162 | 136 | 307 | 301 | 136
shape torch.Size([8, 128, 128]) | 5930 | 317 | 137 | 308 | 578 | 138
shape torch.Size([16, 128, 128]) | 12000 | 635 | 143 | 316 | 199 | 143
shape torch.Size([32, 128, 128]) | 23700 | 1266 | 152 | 322 | 241 | 152
shape torch.Size([64, 128, 128]) | 48000 | 2668 | 177 | 337 | 322 | 177
shape torch.Size([128, 128, 128]) | 96000 | 5366 | 228 | 365 | 514 | 228
shape torch.Size([512, 128, 128]) | 379400 | 21490 | 502 | 620 | 1697 | 502
shape torch.Size([1024, 128, 128]) | 755700 | 43040 | 764 | 903 | 3040 | 770
shape torch.Size([1, 256, 256]) | 750 | 70 | 235 | 383 | 178 | 72
shape torch.Size([2, 256, 256]) | 2000 | 138 | 250 | 384 | 329 | 139
shape torch.Size([4, 256, 256]) | 2988 | 277 | 279 | 404 | 655 | 278
shape torch.Size([8, 256, 256]) | 6100 | 546 | 283 | 420 | 1321 | 286
shape torch.Size([16, 256, 256]) | 12100 | 1149 | 330 | 441 | 472 | 330
shape torch.Size([32, 256, 256]) | 24040 | 2303 | 359 | 453 | 634 | 360
shape torch.Size([64, 256, 256]) | 48000 | 4626 | 408 | 472 | 925 | 408
shape torch.Size([128, 256, 256]) | 94700 | 9247 | 543 | 543 | 1582 | 543
shape torch.Size([512, 256, 256]) | 372000 | 37030 | 1310 | 1185 | 5711 | 1310
shape torch.Size([1024, 256, 256]) | 747200 | 74100 | 2116 | 1910 | 10660 | 2122
Times are in microseconds (us).
```
</details>
<details>
<summary>
Benchmark Results (adjoint=True)
</summary>
```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]
| lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 749 | 34 | 28 | 33 | 98 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 50 | 110 | 27
shape torch.Size([4, 1, 1]) | 3005 | 82 | 28 | 81 | 110 | 27
shape torch.Size([8, 1, 1]) | 5999 | 145 | 28 | 140 | 110 | 27
shape torch.Size([16, 1, 1]) | 12000 | 273 | 28 | 77 | 110 | 27
shape torch.Size([32, 1, 1]) | 24000 | 522 | 28 | 78 | 110 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 77 | 100 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2029 | 28 | 78 | 110 | 27
shape torch.Size([512, 1, 1]) | 383300 | 8100 | 28 | 78 | 110 | 28
shape torch.Size([1024, 1, 1]) | 767500 | 16100 | 28 | 77 | 100 | 27
shape torch.Size([1, 2, 2]) | 753 | 33 | 28 | 33 | 99 | 28
shape torch.Size([2, 2, 2]) | 1500 | 50 | 28 | 50 | 110 | 28
shape torch.Size([4, 2, 2]) | 3002 | 82 | 28 | 80 | 100 | 27
shape torch.Size([8, 2, 2]) | 6000 | 145 | 28 | 144 | 107 | 27
shape torch.Size([16, 2, 2]) | 12000 | 271 | 28 | 78 | 110 | 27
shape torch.Size([32, 2, 2]) | 24120 | 524 | 28 | 78 | 110 | 28
shape torch.Size([64, 2, 2]) | 48300 | 1030 | 28 | 78 | 111 | 27
shape torch.Size([128, 2, 2]) | 96100 | 2041 | 28 | 78 | 107 | 28
shape torch.Size([512, 2, 2]) | 383000 | 8100 | 28 | 79 | 108 | 28
shape torch.Size([1024, 2, 2]) | 766100 | 16000 | 28 | 78 | 110 | 28
shape torch.Size([1, 8, 8]) | 750 | 34 | 28 | 34 | 99 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 50 | 107 | 28
shape torch.Size([4, 8, 8]) | 2998 | 82 | 28 | 82 | 110 | 28
shape torch.Size([8, 8, 8]) | 5990 | 146 | 28 | 150 | 107 | 28
shape torch.Size([16, 8, 8]) | 11980 | 272 | 28 | 79 | 107 | 28
shape torch.Size([32, 8, 8]) | 23970 | 530 | 28 | 79 | 110 | 28
shape torch.Size([64, 8, 8]) | 47900 | 1040 | 28 | 79 | 108 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2048 | 28 | 78 | 108 | 28
shape torch.Size([512, 8, 8]) | 383700 | 8100 | 28 | 80 | 108 | 28
shape torch.Size([1024, 8, 8]) | 766200 | 16300 | 28 | 80 | 108 | 28
shape torch.Size([1, 16, 16]) | 760 | 33 | 28 | 34 | 99 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 50 | 110 | 28
shape torch.Size([4, 16, 16]) | 3001 | 81 | 28 | 82 | 108 | 28
shape torch.Size([8, 16, 16]) | 6000 | 145 | 28 | 140 | 110 | 28
shape torch.Size([16, 16, 16]) | 12000 | 276 | 28 | 79 | 110 | 28
shape torch.Size([32, 16, 16]) | 23870 | 549 | 28 | 79 | 110 | 28
shape torch.Size([64, 16, 16]) | 47900 | 1098 | 29 | 80 | 100 | 28
shape torch.Size([128, 16, 16]) | 95800 | 2184 | 28 | 79 | 108 | 28
shape torch.Size([512, 16, 16]) | 386900 | 8769 | 28 | 80 | 108 | 28
shape torch.Size([1024, 16, 16]) | 769800 | 17460 | 37 | 80 | 107 | 37
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 34 | 99 | 28
shape torch.Size([2, 32, 32]) | 1500 | 50 | 28 | 50 | 110 | 29
shape torch.Size([4, 32, 32]) | 3021 | 86 | 31 | 84 | 110 | 32
shape torch.Size([8, 32, 32]) | 6040 | 167 | 32 | 167 | 108 | 32
shape torch.Size([16, 32, 32]) | 12100 | 330 | 33 | 78 | 107 | 33
shape torch.Size([32, 32, 32]) | 24150 | 662 | 35 | 78 | 110 | 35
shape torch.Size([64, 32, 32]) | 48200 | 1323 | 36 | 79 | 110 | 36
shape torch.Size([128, 32, 32]) | 97000 | 2637 | 44 | 79 | 110 | 43
shape torch.Size([512, 32, 32]) | 382500 | 10580 | 83 | 180 | 198 | 83
shape torch.Size([1024, 32, 32]) | 766600 | 22670 | 123 | 260 | 318 | 120
shape torch.Size([1, 64, 64]) | 760 | 33 | 58 | 33 | 88 | 34
shape torch.Size([2, 64, 64]) | 1520 | 60 | 60 | 60 | 109 | 59
shape torch.Size([4, 64, 64]) | 3016 | 115 | 67 | 119 | 132 | 66
shape torch.Size([8, 64, 64]) | 6120 | 230 | 67 | 233 | 180 | 68
shape torch.Size([16, 64, 64]) | 12100 | 457 | 69 | 86 | 107 | 69
shape torch.Size([32, 64, 64]) | 24000 | 912 | 74 | 95 | 100 | 74
shape torch.Size([64, 64, 64]) | 48000 | 1833 | 76 | 106 | 123 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3636 | 97 | 163 | 183 | 97
shape torch.Size([512, 64, 64]) | 380600 | 15600 | 210 | 464 | 549 | 210
shape torch.Size([1024, 64, 64]) | 761200 | 31140 | 308 | 741 | 918 | 308
shape torch.Size([1, 128, 128]) | 756 | 46 | 120 | 47 | 89 | 46
shape torch.Size([2, 128, 128]) | 1500 | 91 | 123 | 89 | 110 | 92
shape torch.Size([4, 128, 128]) | 2994 | 178 | 139 | 180 | 131 | 139
shape torch.Size([8, 128, 128]) | 5960 | 350 | 140 | 354 | 208 | 142
shape torch.Size([16, 128, 128]) | 12000 | 701 | 144 | 177 | 198 | 143
shape torch.Size([32, 128, 128]) | 23870 | 1401 | 155 | 225 | 246 | 155
shape torch.Size([64, 128, 128]) | 47600 | 2948 | 179 | 288 | 323 | 180
shape torch.Size([128, 128, 128]) | 96000 | 5910 | 231 | 442 | 512 | 231
shape torch.Size([512, 128, 128]) | 381200 | 23640 | 519 | 1400 | 1700 | 519
shape torch.Size([1024, 128, 128]) | 755800 | 47340 | 794 | 2436 | 3018 | 794
shape torch.Size([1, 256, 256]) | 760 | 74 | 246 | 77 | 88 | 78
shape torch.Size([2, 256, 256]) | 1510 | 150 | 256 | 150 | 117 | 150
shape torch.Size([4, 256, 256]) | 3030 | 296 | 284 | 296 | 209 | 284
shape torch.Size([8, 256, 256]) | 6100 | 588 | 286 | 592 | 394 | 288
shape torch.Size([16, 256, 256]) | 12200 | 1238 | 330 | 445 | 480 | 330
shape torch.Size([32, 256, 256]) | 24430 | 2476 | 368 | 568 | 629 | 367
shape torch.Size([64, 256, 256]) | 49000 | 4950 | 415 | 800 | 921 | 414
shape torch.Size([128, 256, 256]) | 96000 | 9900 | 552 | 1330 | 1579 | 553
shape torch.Size([512, 256, 256]) | 369400 | 39580 | 1410 | 4614 | 5616 | 1410
shape torch.Size([1024, 256, 256]) | 716200 | 79200 | 2270 | 8472 | 10500 | 2277
Times are in microseconds (us).
```
</details>
To generate the results above, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. Then I ran the following script, changing the variable `name` each time. For `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncommenting the commented one).
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)

shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]

results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")

# Reference implementation: lu_unpack followed by two triangular solves.
def f(LU, pivots, B, adjoint):
    P, L, U = torch.lu_unpack(LU, pivots)
    if adjoint:
        # A^H X = B with A = P L U  =>  X = P L^{-H} U^{-H} B
        X = torch.linalg.solve_triangular(U.mH, B, upper=False)
        return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
    else:
        # A X = B  =>  X = U^{-1} L^{-1} P^T B
        X = P.mT @ B
        X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
        return torch.linalg.solve_triangular(U, X, upper=True, out=X)

for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    # stmt = "f(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)
```
</details>
Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>
```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
    "looped_magma",
    "looped cusolver",
    "batched cublas",
    "batched magma",
    "unpack+solve_triangular",
    "heuristic",
]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>
### Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it, preferring the triangular solves over it, as they were faster. I'm leaving it here in case it's useful in the future.
<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>
```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
  if (trans == TransposeType::NoTranspose) {
    lu_solve_batched_magma(LU, pivots, B, trans);
    return;
  }
  // There's a bug in MAGMA for the other cases, so we need to properly perform mT or mH on LU.
  // The LU of the transpose is not the transpose of the LU:
  // we rewrite LU = (LD)(D^{-1}U) = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
  auto diag = LU.diagonal(0, -2, -1);
  auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
              LU.triu(1).div_(diag.unsqueeze(-1));
  LU_f.diagonal(0, -2, -1).copy_(diag);
  if (trans == TransposeType::ConjTranspose) {
    LU_f = LU_f.conj_physical();
  }
  LU_f = LU_f.transpose(-2, -1);
  // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());
  // Trivial permutation
  auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();
  lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);
  // We then need to apply the permutation P to the result, as
  // (PLU)^T X = B iff U^T L^T (P^T X) = B
  // Fill `perm` with the identity permutation (perhaps batched)
  // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
  const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
  auto iter = TensorIteratorConfig()
                  .set_check_mem_overlap(false)
                  .check_all_same_dtype(false)
                  .resize_outputs(false)
                  .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
                  .add_output(perm)
                  .add_input(pivots)
                  .build();
  unpack_pivots_stub(pivots.device().type(), iter, m);
  B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```
</details>
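The identity behind this workaround is easy to check numerically. Here is a small illustrative sketch (not part of the PR) using the public `lu_factor` / `lu_unpack` API; it verifies that repacking the factors as `L' = LD`, `U' = D^{-1}U` and transposing yields a valid LU factorization of the transposed matrix:

```python
import torch

# With A = P L U (L unit lower triangular) and D = diag(U), we have
# L U = (L D)(D^{-1} U), where L D is lower triangular with diagonal D and
# D^{-1} U is unit upper triangular. Transposing these factors gives an
# LU factorization of A^T, which is what the snippet above feeds to MAGMA.
A = torch.randn(5, 5, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.lu_unpack(LU, pivots)

D = U.diagonal()
Lp = L * D                # L' = L D: columns of L scaled by D
Up = U / D.unsqueeze(-1)  # U' = D^{-1} U: rows of U scaled by 1/D
assert torch.allclose(L @ U, Lp @ Up)

# A^T = (P L' U')^T = U'^T L'^T P^T, with U'^T unit lower triangular
# and L'^T upper triangular.
assert torch.allclose(A.mT, Up.mT @ Lp.mT @ P.mT)
```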
Fixes #61657
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 749 | 34 | 28 | 33 | 98 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 50 | 110 | 27
shape torch.Size([4, 1, 1]) | 3005 | 82 | 28 | 81 | 110 | 27
shape torch.Size([8, 1, 1]) | 5999 | 145 | 28 | 140 | 110 | 27
shape torch.Size([16, 1, 1]) | 12000 | 273 | 28 | 77 | 110 | 27
shape torch.Size([32, 1, 1]) | 24000 | 522 | 28 | 78 | 110 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 77 | 100 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2029 | 28 | 78 | 110 | 27
shape torch.Size([512, 1, 1]) | 383300 | 8100 | 28 | 78 | 110 | 28
shape torch.Size([1024, 1, 1]) | 767500 | 16100 | 28 | 77 | 100 | 27
shape torch.Size([1, 2, 2]) | 753 | 33 | 28 | 33 | 99 | 28
shape torch.Size([2, 2, 2]) | 1500 | 50 | 28 | 50 | 110 | 28
shape torch.Size([4, 2, 2]) | 3002 | 82 | 28 | 80 | 100 | 27
shape torch.Size([8, 2, 2]) | 6000 | 145 | 28 | 144 | 107 | 27
shape torch.Size([16, 2, 2]) | 12000 | 271 | 28 | 78 | 110 | 27
shape torch.Size([32, 2, 2]) | 24120 | 524 | 28 | 78 | 110 | 28
shape torch.Size([64, 2, 2]) | 48300 | 1030 | 28 | 78 | 111 | 27
shape torch.Size([128, 2, 2]) | 96100 | 2041 | 28 | 78 | 107 | 28
shape torch.Size([512, 2, 2]) | 383000 | 8100 | 28 | 79 | 108 | 28
shape torch.Size([1024, 2, 2]) | 766100 | 16000 | 28 | 78 | 110 | 28
shape torch.Size([1, 8, 8]) | 750 | 34 | 28 | 34 | 99 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 50 | 107 | 28
shape torch.Size([4, 8, 8]) | 2998 | 82 | 28 | 82 | 110 | 28
shape torch.Size([8, 8, 8]) | 5990 | 146 | 28 | 150 | 107 | 28
shape torch.Size([16, 8, 8]) | 11980 | 272 | 28 | 79 | 107 | 28
shape torch.Size([32, 8, 8]) | 23970 | 530 | 28 | 79 | 110 | 28
shape torch.Size([64, 8, 8]) | 47900 | 1040 | 28 | 79 | 108 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2048 | 28 | 78 | 108 | 28
shape torch.Size([512, 8, 8]) | 383700 | 8100 | 28 | 80 | 108 | 28
shape torch.Size([1024, 8, 8]) | 766200 | 16300 | 28 | 80 | 108 | 28
shape torch.Size([1, 16, 16]) | 760 | 33 | 28 | 34 | 99 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 50 | 110 | 28 [85/469]
shape torch.Size([4, 16, 16]) | 3001 | 81 | 28 | 82 | 108 | 28
shape torch.Size([8, 16, 16]) | 6000 | 145 | 28 | 140 | 110 | 28
shape torch.Size([16, 16, 16]) | 12000 | 276 | 28 | 79 | 110 | 28
shape torch.Size([32, 16, 16]) | 23870 | 549 | 28 | 79 | 110 | 28
shape torch.Size([64, 16, 16]) | 47900 | 1098 | 29 | 80 | 100 | 28
shape torch.Size([128, 16, 16]) | 95800 | 2184 | 28 | 79 | 108 | 28
shape torch.Size([512, 16, 16]) | 386900 | 8769 | 28 | 80 | 108 | 28
shape torch.Size([1024, 16, 16]) | 769800 | 17460 | 37 | 80 | 107 | 37
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 34 | 99 | 28
shape torch.Size([2, 32, 32]) | 1500 | 50 | 28 | 50 | 110 | 29
shape torch.Size([4, 32, 32]) | 3021 | 86 | 31 | 84 | 110 | 32
shape torch.Size([8, 32, 32]) | 6040 | 167 | 32 | 167 | 108 | 32
shape torch.Size([16, 32, 32]) | 12100 | 330 | 33 | 78 | 107 | 33
shape torch.Size([32, 32, 32]) | 24150 | 662 | 35 | 78 | 110 | 35
shape torch.Size([64, 32, 32]) | 48200 | 1323 | 36 | 79 | 110 | 36
shape torch.Size([128, 32, 32]) | 97000 | 2637 | 44 | 79 | 110 | 43
shape torch.Size([512, 32, 32]) | 382500 | 10580 | 83 | 180 | 198 | 83
shape torch.Size([1024, 32, 32]) | 766600 | 22670 | 123 | 260 | 318 | 120
shape torch.Size([1, 64, 64]) | 760 | 33 | 58 | 33 | 88 | 34
shape torch.Size([2, 64, 64]) | 1520 | 60 | 60 | 60 | 109 | 59
shape torch.Size([4, 64, 64]) | 3016 | 115 | 67 | 119 | 132 | 66
shape torch.Size([8, 64, 64]) | 6120 | 230 | 67 | 233 | 180 | 68
shape torch.Size([16, 64, 64]) | 12100 | 457 | 69 | 86 | 107 | 69
shape torch.Size([32, 64, 64]) | 24000 | 912 | 74 | 95 | 100 | 74
shape torch.Size([64, 64, 64]) | 48000 | 1833 | 76 | 106 | 123 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3636 | 97 | 163 | 183 | 97
shape torch.Size([512, 64, 64]) | 380600 | 15600 | 210 | 464 | 549 | 210
shape torch.Size([1024, 64, 64]) | 761200 | 31140 | 308 | 741 | 918 | 308
shape torch.Size([1, 128, 128]) | 756 | 46 | 120 | 47 | 89 | 46
shape torch.Size([2, 128, 128]) | 1500 | 91 | 123 | 89 | 110 | 92
shape torch.Size([4, 128, 128]) | 2994 | 178 | 139 | 180 | 131 | 139
shape torch.Size([8, 128, 128]) | 5960 | 350 | 140 | 354 | 208 | 142
shape torch.Size([16, 128, 128]) | 12000 | 701 | 144 | 177 | 198 | 143
shape torch.Size([32, 128, 128]) | 23870 | 1401 | 155 | 225 | 246 | 155
shape torch.Size([64, 128, 128]) | 47600 | 2948 | 179 | 288 | 323 | 180
shape torch.Size([128, 128, 128]) | 96000 | 5910 | 231 | 442 | 512 | 231
shape torch.Size([512, 128, 128]) | 381200 | 23640 | 519 | 1400 | 1700 | 519
shape torch.Size([1024, 128, 128]) | 755800 | 47340 | 794 | 2436 | 3018 | 794
shape torch.Size([1, 256, 256]) | 760 | 74 | 246 | 77 | 88 | 78
shape torch.Size([2, 256, 256]) | 1510 | 150 | 256 | 150 | 117 | 150
shape torch.Size([4, 256, 256]) | 3030 | 296 | 284 | 296 | 209 | 284
shape torch.Size([8, 256, 256]) | 6100 | 588 | 286 | 592 | 394 | 288
shape torch.Size([16, 256, 256]) | 12200 | 1238 | 330 | 445 | 480 | 330
shape torch.Size([32, 256, 256]) | 24430 | 2476 | 368 | 568 | 629 | 367
shape torch.Size([64, 256, 256]) | 49000 | 4950 | 415 | 800 | 921 | 414
shape torch.Size([128, 256, 256]) | 96000 | 9900 | 552 | 1330 | 1579 | 553
shape torch.Size([512, 256, 256]) | 369400 | 39580 | 1410 | 4614 | 5616 | 1410
shape torch.Size([1024, 256, 256]) | 716200 | 79200 | 2270 | 8472 | 10500 | 2277
Times are in microseconds (us).
```
</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. Then I run the following script, changing the variable `name`. For the `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncomenting the commented one)
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare
name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")
def f(LU, pivots, B, adjoint):
P, L, U = torch.lu_unpack(LU, pivots)
if adjoint:
X = torch.linalg.solve_triangular(U.mH, B, upper=False)
return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
else:
X = P.mT @ B
X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
return torch.linalg.solve_triangular(U, X, upper=True, out=X)
for n, batch in itertools.product(shapes, batches):
LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
B = make_arg(batch + (n, 1))
print(LU.shape)
stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
#stmt = "f(LU, pivots, B, adjoint=adjoint)"
for adjoint in (True, False):
timer = Timer(stmt,
globals=globals(),
label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
description=label,
sub_label=f"shape {LU.shape}",
num_threads=1)
results.append(timer.blocked_autorange())
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
with open("{}_lu_solve.pickle".format(name), 'wb') as f:
pickle.dump(results, f)
```
</details>
Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>
```python
import pickle
from torch.utils.benchmark import Timer, Compare
files = [
"looped_magma",
"looped cusolver",
"batched cublas",
"batched magma",
"unpack+solve_triangular",
"heuristic",
]
timers = []
for name in files:
with open("{}_lu_solve.pickle".format(name), 'rb') as f:
timers += pickle.load(f)
compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>
### Fix for Magma's batched lu_solve when `adjoint=True`
I also developed the following fix around MAGMA's bug, but I ended up not using it, and preferring the triangular solves over it, as they were faster. I'm leaving it here in case it's useful in the future.
<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>
```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
if (trans == TransposeType::NoTranspose) {
lu_solve_batched_magma(LU, pivots, B, trans);
return;
}
// There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
// The LU of the transpose is not the transpose of the LU
// We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
auto diag = LU.diagonal(0, -2, -1);
auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
LU.triu(1).div_(diag.unsqueeze(-1));
LU_f.diagonal(0, -2, -1).copy_(diag);
if (trans == TransposeType::ConjTranspose) {
LU_f = LU_f.conj_physical();
}
LU_f.transpose(-2, -1);
// At this point LU_f is F-contiguous, because triu / tril / conj_phisical return contiguous tensors
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());
// Trivial permutation
auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();
lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);
// We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
// Fill `perm` with the identity permutation (perhaps batched)
// This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
.add_output(perm)
.add_input(pivots)
.build();
unpack_pivots_stub(pivots.device().type(), iter, m);
B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```
</details>
Fixes #61657
[ghstack-poisoned]
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving the system solving two triangular systems.
We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.
We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.
### Benchmarking
<details>
<summary>
Benchmark Results (adjoint=False)
</summary>
```
--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]
| lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 750 | 34 | 28 | 252 | 86 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 239 | 94 | 27
shape torch.Size([4, 1, 1]) | 2995 | 83 | 28 | 239 | 100 | 27
shape torch.Size([8, 1, 1]) | 6000 | 146 | 28 | 239 | 94 | 27
shape torch.Size([16, 1, 1]) | 11900 | 272 | 28 | 241 | 95 | 27
shape torch.Size([32, 1, 1]) | 23880 | 524 | 28 | 244 | 94 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 245 | 99 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2054 | 28 | 242 | 96 | 27
shape torch.Size([512, 1, 1]) | 381900 | 8100 | 28 | 250 | 94 | 27
shape torch.Size([1024, 1, 1]) | 763800 | 16200 | 28 | 257 | 95 | 27
shape torch.Size([1, 2, 2]) | 750 | 33 | 28 | 240 | 88 | 27
shape torch.Size([2, 2, 2]) | 1500 | 51 | 28 | 240 | 96 | 27
shape torch.Size([4, 2, 2]) | 2991 | 82 | 28 | 241 | 96 | 28
shape torch.Size([8, 2, 2]) | 6000 | 150 | 28 | 241 | 96 | 27
shape torch.Size([16, 2, 2]) | 12000 | 275 | 28 | 242 | 96 | 27
shape torch.Size([32, 2, 2]) | 23980 | 530 | 28 | 246 | 97 | 28
shape torch.Size([64, 2, 2]) | 48000 | 1000 | 28 | 244 | 96 | 27
shape torch.Size([128, 2, 2]) | 96000 | 2063 | 28 | 245 | 96 | 28
shape torch.Size([512, 2, 2]) | 382000 | 8300 | 28 | 257 | 97 | 28
shape torch.Size([1024, 2, 2]) | 764000 | 20000 | 28 | 271 | 97 | 28
shape torch.Size([1, 8, 8]) | 749 | 34 | 28 | 243 | 88 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 244 | 97 | 28
shape torch.Size([4, 8, 8]) | 2988 | 83 | 28 | 244 | 100 | 28
shape torch.Size([8, 8, 8]) | 5980 | 150 | 28 | 245 | 97 | 28
shape torch.Size([16, 8, 8]) | 12000 | 278 | 28 | 246 | 96 | 28
shape torch.Size([32, 8, 8]) | 23910 | 536 | 28 | 249 | 98 | 28
shape torch.Size([64, 8, 8]) | 47800 | 1100 | 28 | 247 | 96 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2075 | 28 | 248 | 96 | 28
shape torch.Size([512, 8, 8]) | 382100 | 8300 | 28 | 270 | 97 | 28
shape torch.Size([1024, 8, 8]) | 764100 | 16400 | 28 | 291 | 100 | 28
shape torch.Size([1, 16, 16]) | 750 | 33 | 28 | 248 | 88 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 250 | 97 | 28
shape torch.Size([4, 16, 16]) | 2996 | 83 | 28 | 250 | 100 | 28
shape torch.Size([8, 16, 16]) | 5980 | 147 | 28 | 251 | 97 | 28
shape torch.Size([16, 16, 16]) | 11900 | 274 | 28 | 251 | 97 | 28
shape torch.Size([32, 16, 16]) | 24040 | 527 | 28 | 252 | 97 | 28
shape torch.Size([64, 16, 16]) | 47800 | 1037 | 28 | 251 | 100 | 28
shape torch.Size([128, 16, 16]) | 95600 | 2044 | 28 | 252 | 98 | 28
shape torch.Size([512, 16, 16]) | 388200 | 8100 | 28 | 280 | 100 | 28
shape torch.Size([1024, 16, 16]) | 769700 | 16000 | 28 | 322 | 117 | 28
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 255 | 88 | 28
shape torch.Size([2, 32, 32]) | 1510 | 50 | 28 | 256 | 97 | 28
shape torch.Size([4, 32, 32]) | 3022 | 82 | 31 | 256 | 97 | 30
shape torch.Size([8, 32, 32]) | 6000 | 140 | 31 | 257 | 100 | 31
shape torch.Size([16, 32, 32]) | 12000 | 281 | 31 | 258 | 96 | 31
shape torch.Size([32, 32, 32]) | 24150 | 563 | 35 | 258 | 96 | 35
shape torch.Size([64, 32, 32]) | 48300 | 1119 | 36 | 258 | 97 | 36
shape torch.Size([128, 32, 32]) | 96500 | 2235 | 43 | 261 | 96 | 43
shape torch.Size([512, 32, 32]) | 383100 | 8930 | 82 | 317 | 191 | 82
shape torch.Size([1024, 32, 32]) | 766300 | 19200 | 122 | 400 | 312 | 122
shape torch.Size([1, 64, 64]) | 760 | 33 | 55 | 272 | 68 | 34
shape torch.Size([2, 64, 64]) | 1500 | 52 | 58 | 273 | 85 | 52
shape torch.Size([4, 64, 64]) | 3127 | 102 | 65 | 273 | 150 | 65
shape torch.Size([8, 64, 64]) | 6070 | 201 | 65 | 275 | 278 | 65
shape torch.Size([16, 64, 64]) | 12000 | 399 | 66 | 274 | 95 | 67
shape torch.Size([32, 64, 64]) | 23900 | 796 | 73 | 275 | 97 | 73
shape torch.Size([64, 64, 64]) | 48000 | 1594 | 75 | 283 | 123 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3177 | 96 | 292 | 176 | 96
shape torch.Size([512, 64, 64]) | 379300 | 13520 | 208 | 426 | 551 | 208
shape torch.Size([1024, 64, 64]) | 758700 | 27100 | 306 | 570 | 919 | 306
shape torch.Size([1, 128, 128]) | 750 | 42 | 115 | 306 | 90 | 42
shape torch.Size([2, 128, 128]) | 1500 | 82 | 122 | 307 | 164 | 83
shape torch.Size([4, 128, 128]) | 2966 | 162 | 136 | 307 | 301 | 136
shape torch.Size([8, 128, 128]) | 5930 | 317 | 137 | 308 | 578 | 138
shape torch.Size([16, 128, 128]) | 12000 | 635 | 143 | 316 | 199 | 143
shape torch.Size([32, 128, 128]) | 23700 | 1266 | 152 | 322 | 241 | 152
shape torch.Size([64, 128, 128]) | 48000 | 2668 | 177 | 337 | 322 | 177
shape torch.Size([128, 128, 128]) | 96000 | 5366 | 228 | 365 | 514 | 228
shape torch.Size([512, 128, 128]) | 379400 | 21490 | 502 | 620 | 1697 | 502
shape torch.Size([1024, 128, 128]) | 755700 | 43040 | 764 | 903 | 3040 | 770
shape torch.Size([1, 256, 256]) | 750 | 70 | 235 | 383 | 178 | 72
shape torch.Size([2, 256, 256]) | 2000 | 138 | 250 | 384 | 329 | 139
shape torch.Size([4, 256, 256]) | 2988 | 277 | 279 | 404 | 655 | 278
shape torch.Size([8, 256, 256]) | 6100 | 546 | 283 | 420 | 1321 | 286
shape torch.Size([16, 256, 256]) | 12100 | 1149 | 330 | 441 | 472 | 330
shape torch.Size([32, 256, 256]) | 24040 | 2303 | 359 | 453 | 634 | 360
shape torch.Size([64, 256, 256]) | 48000 | 4626 | 408 | 472 | 925 | 408
shape torch.Size([128, 256, 256]) | 94700 | 9247 | 543 | 543 | 1582 | 543
shape torch.Size([512, 256, 256]) | 372000 | 37030 | 1310 | 1185 | 5711 | 1310
shape torch.Size([1024, 256, 256]) | 747200 | 74100 | 2116 | 1910 | 10660 | 2122
```
</details>
<details>
<summary>
Benchmark Results (adjoint=True)
</summary>
```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]
| lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 749 | 34 | 28 | 33 | 98 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 50 | 110 | 27
shape torch.Size([4, 1, 1]) | 3005 | 82 | 28 | 81 | 110 | 27
shape torch.Size([8, 1, 1]) | 5999 | 145 | 28 | 140 | 110 | 27
shape torch.Size([16, 1, 1]) | 12000 | 273 | 28 | 77 | 110 | 27
shape torch.Size([32, 1, 1]) | 24000 | 522 | 28 | 78 | 110 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 77 | 100 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2029 | 28 | 78 | 110 | 27
shape torch.Size([512, 1, 1]) | 383300 | 8100 | 28 | 78 | 110 | 28
shape torch.Size([1024, 1, 1]) | 767500 | 16100 | 28 | 77 | 100 | 27
shape torch.Size([1, 2, 2]) | 753 | 33 | 28 | 33 | 99 | 28
shape torch.Size([2, 2, 2]) | 1500 | 50 | 28 | 50 | 110 | 28
shape torch.Size([4, 2, 2]) | 3002 | 82 | 28 | 80 | 100 | 27
shape torch.Size([8, 2, 2]) | 6000 | 145 | 28 | 144 | 107 | 27
shape torch.Size([16, 2, 2]) | 12000 | 271 | 28 | 78 | 110 | 27
shape torch.Size([32, 2, 2]) | 24120 | 524 | 28 | 78 | 110 | 28
shape torch.Size([64, 2, 2]) | 48300 | 1030 | 28 | 78 | 111 | 27
shape torch.Size([128, 2, 2]) | 96100 | 2041 | 28 | 78 | 107 | 28
shape torch.Size([512, 2, 2]) | 383000 | 8100 | 28 | 79 | 108 | 28
shape torch.Size([1024, 2, 2]) | 766100 | 16000 | 28 | 78 | 110 | 28
shape torch.Size([1, 8, 8]) | 750 | 34 | 28 | 34 | 99 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 50 | 107 | 28
shape torch.Size([4, 8, 8]) | 2998 | 82 | 28 | 82 | 110 | 28
shape torch.Size([8, 8, 8]) | 5990 | 146 | 28 | 150 | 107 | 28
shape torch.Size([16, 8, 8]) | 11980 | 272 | 28 | 79 | 107 | 28
shape torch.Size([32, 8, 8]) | 23970 | 530 | 28 | 79 | 110 | 28
shape torch.Size([64, 8, 8]) | 47900 | 1040 | 28 | 79 | 108 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2048 | 28 | 78 | 108 | 28
shape torch.Size([512, 8, 8]) | 383700 | 8100 | 28 | 80 | 108 | 28
shape torch.Size([1024, 8, 8]) | 766200 | 16300 | 28 | 80 | 108 | 28
shape torch.Size([1, 16, 16]) | 760 | 33 | 28 | 34 | 99 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 50 | 110 | 28 [85/469]
shape torch.Size([4, 16, 16]) | 3001 | 81 | 28 | 82 | 108 | 28
shape torch.Size([8, 16, 16]) | 6000 | 145 | 28 | 140 | 110 | 28
shape torch.Size([16, 16, 16]) | 12000 | 276 | 28 | 79 | 110 | 28
shape torch.Size([32, 16, 16]) | 23870 | 549 | 28 | 79 | 110 | 28
shape torch.Size([64, 16, 16]) | 47900 | 1098 | 29 | 80 | 100 | 28
shape torch.Size([128, 16, 16]) | 95800 | 2184 | 28 | 79 | 108 | 28
shape torch.Size([512, 16, 16]) | 386900 | 8769 | 28 | 80 | 108 | 28
shape torch.Size([1024, 16, 16]) | 769800 | 17460 | 37 | 80 | 107 | 37
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 34 | 99 | 28
shape torch.Size([2, 32, 32]) | 1500 | 50 | 28 | 50 | 110 | 29
shape torch.Size([4, 32, 32]) | 3021 | 86 | 31 | 84 | 110 | 32
shape torch.Size([8, 32, 32]) | 6040 | 167 | 32 | 167 | 108 | 32
shape torch.Size([16, 32, 32]) | 12100 | 330 | 33 | 78 | 107 | 33
shape torch.Size([32, 32, 32]) | 24150 | 662 | 35 | 78 | 110 | 35
shape torch.Size([64, 32, 32]) | 48200 | 1323 | 36 | 79 | 110 | 36
shape torch.Size([128, 32, 32]) | 97000 | 2637 | 44 | 79 | 110 | 43
shape torch.Size([512, 32, 32]) | 382500 | 10580 | 83 | 180 | 198 | 83
shape torch.Size([1024, 32, 32]) | 766600 | 22670 | 123 | 260 | 318 | 120
shape torch.Size([1, 64, 64]) | 760 | 33 | 58 | 33 | 88 | 34
shape torch.Size([2, 64, 64]) | 1520 | 60 | 60 | 60 | 109 | 59
shape torch.Size([4, 64, 64]) | 3016 | 115 | 67 | 119 | 132 | 66
shape torch.Size([8, 64, 64]) | 6120 | 230 | 67 | 233 | 180 | 68
shape torch.Size([16, 64, 64]) | 12100 | 457 | 69 | 86 | 107 | 69
shape torch.Size([32, 64, 64]) | 24000 | 912 | 74 | 95 | 100 | 74
shape torch.Size([64, 64, 64]) | 48000 | 1833 | 76 | 106 | 123 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3636 | 97 | 163 | 183 | 97
shape torch.Size([512, 64, 64]) | 380600 | 15600 | 210 | 464 | 549 | 210
shape torch.Size([1024, 64, 64]) | 761200 | 31140 | 308 | 741 | 918 | 308
shape torch.Size([1, 128, 128]) | 756 | 46 | 120 | 47 | 89 | 46
shape torch.Size([2, 128, 128]) | 1500 | 91 | 123 | 89 | 110 | 92
shape torch.Size([4, 128, 128]) | 2994 | 178 | 139 | 180 | 131 | 139
shape torch.Size([8, 128, 128]) | 5960 | 350 | 140 | 354 | 208 | 142
shape torch.Size([16, 128, 128]) | 12000 | 701 | 144 | 177 | 198 | 143
shape torch.Size([32, 128, 128]) | 23870 | 1401 | 155 | 225 | 246 | 155
shape torch.Size([64, 128, 128]) | 47600 | 2948 | 179 | 288 | 323 | 180
shape torch.Size([128, 128, 128]) | 96000 | 5910 | 231 | 442 | 512 | 231
shape torch.Size([512, 128, 128]) | 381200 | 23640 | 519 | 1400 | 1700 | 519
shape torch.Size([1024, 128, 128]) | 755800 | 47340 | 794 | 2436 | 3018 | 794
shape torch.Size([1, 256, 256]) | 760 | 74 | 246 | 77 | 88 | 78
shape torch.Size([2, 256, 256]) | 1510 | 150 | 256 | 150 | 117 | 150
shape torch.Size([4, 256, 256]) | 3030 | 296 | 284 | 296 | 209 | 284
shape torch.Size([8, 256, 256]) | 6100 | 588 | 286 | 592 | 394 | 288
shape torch.Size([16, 256, 256]) | 12200 | 1238 | 330 | 445 | 480 | 330
shape torch.Size([32, 256, 256]) | 24430 | 2476 | 368 | 568 | 629 | 367
shape torch.Size([64, 256, 256]) | 49000 | 4950 | 415 | 800 | 921 | 414
shape torch.Size([128, 256, 256]) | 96000 | 9900 | 552 | 1330 | 1579 | 553
shape torch.Size([512, 256, 256]) | 369400 | 39580 | 1410 | 4614 | 5616 | 1410
shape torch.Size([1024, 256, 256]) | 716200 | 79200 | 2270 | 8472 | 10500 | 2277
Times are in microseconds (us).
```
</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. Then I run the following script, changing the variable `name`. For the `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncomenting the commented one)
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare
name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")
def f(LU, pivots, B, adjoint):
P, L, U = torch.lu_unpack(LU, pivots)
if adjoint:
X = torch.linalg.solve_triangular(U.mH, B, upper=False)
return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
else:
X = P.mT @ B
X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
return torch.linalg.solve_triangular(U, X, upper=True, out=X)
for n, batch in itertools.product(shapes, batches):
LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
B = make_arg(batch + (n, 1))
print(LU.shape)
stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
#stmt = "f(LU, pivots, B, adjoint=adjoint)"
for adjoint in (True, False):
timer = Timer(stmt,
globals=globals(),
label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
description=label,
sub_label=f"shape {LU.shape}",
num_threads=1)
results.append(timer.blocked_autorange())
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
with open("{}_lu_solve.pickle".format(name), 'wb') as f:
pickle.dump(results, f)
```
</details>
Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>
```python
import pickle
from torch.utils.benchmark import Timer, Compare
files = [
"looped_magma",
"looped cusolver",
"batched cublas",
"batched magma",
"unpack+solve_triangular",
"heuristic",
]
timers = []
for name in files:
with open("{}_lu_solve.pickle".format(name), 'rb') as f:
timers += pickle.load(f)
compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>
### Fix for Magma's batched lu_solve when `adjoint=True`
I also developed the following fix around MAGMA's bug, but I ended up not using it, and preferring the triangular solves over it, as they were faster. I'm leaving it here in case it's useful in the future.
<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>
```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
if (trans == TransposeType::NoTranspose) {
lu_solve_batched_magma(LU, pivots, B, trans);
return;
}
// There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
// The LU of the transpose is not the transpose of the LU
// We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
auto diag = LU.diagonal(0, -2, -1);
auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
LU.triu(1).div_(diag.unsqueeze(-1));
LU_f.diagonal(0, -2, -1).copy_(diag);
if (trans == TransposeType::ConjTranspose) {
LU_f = LU_f.conj_physical();
}
LU_f.transpose(-2, -1);
// At this point LU_f is F-contiguous, because triu / tril / conj_phisical return contiguous tensors
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());
// Trivial permutation
auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();
lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);
// We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
// Fill `perm` with the identity permutation (perhaps batched)
// This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
.add_output(perm)
.add_input(pivots)
.build();
unpack_pivots_stub(pivots.device().type(), iter, m);
B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```
</details>
Fixes #61657
[ghstack-poisoned]
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving the system solving two triangular systems.
We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.
We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.
### Benchmarking
<details>
<summary>
Benchmark Results (adjoint=False)
</summary>
```
--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]
| lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 750 | 34 | 28 | 252 | 86 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 239 | 94 | 27
shape torch.Size([4, 1, 1]) | 2995 | 83 | 28 | 239 | 100 | 27
shape torch.Size([8, 1, 1]) | 6000 | 146 | 28 | 239 | 94 | 27
shape torch.Size([16, 1, 1]) | 11900 | 272 | 28 | 241 | 95 | 27
shape torch.Size([32, 1, 1]) | 23880 | 524 | 28 | 244 | 94 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 245 | 99 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2054 | 28 | 242 | 96 | 27
shape torch.Size([512, 1, 1]) | 381900 | 8100 | 28 | 250 | 94 | 27
shape torch.Size([1024, 1, 1]) | 763800 | 16200 | 28 | 257 | 95 | 27
shape torch.Size([1, 2, 2]) | 750 | 33 | 28 | 240 | 88 | 27
shape torch.Size([2, 2, 2]) | 1500 | 51 | 28 | 240 | 96 | 27
shape torch.Size([4, 2, 2]) | 2991 | 82 | 28 | 241 | 96 | 28
shape torch.Size([8, 2, 2]) | 6000 | 150 | 28 | 241 | 96 | 27
shape torch.Size([16, 2, 2]) | 12000 | 275 | 28 | 242 | 96 | 27
shape torch.Size([32, 2, 2]) | 23980 | 530 | 28 | 246 | 97 | 28
shape torch.Size([64, 2, 2]) | 48000 | 1000 | 28 | 244 | 96 | 27
shape torch.Size([128, 2, 2]) | 96000 | 2063 | 28 | 245 | 96 | 28
shape torch.Size([512, 2, 2]) | 382000 | 8300 | 28 | 257 | 97 | 28
shape torch.Size([1024, 2, 2]) | 764000 | 20000 | 28 | 271 | 97 | 28
shape torch.Size([1, 8, 8]) | 749 | 34 | 28 | 243 | 88 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 244 | 97 | 28
shape torch.Size([4, 8, 8]) | 2988 | 83 | 28 | 244 | 100 | 28
shape torch.Size([8, 8, 8]) | 5980 | 150 | 28 | 245 | 97 | 28
shape torch.Size([16, 8, 8]) | 12000 | 278 | 28 | 246 | 96 | 28
shape torch.Size([32, 8, 8]) | 23910 | 536 | 28 | 249 | 98 | 28
shape torch.Size([64, 8, 8]) | 47800 | 1100 | 28 | 247 | 96 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2075 | 28 | 248 | 96 | 28
shape torch.Size([512, 8, 8]) | 382100 | 8300 | 28 | 270 | 97 | 28
shape torch.Size([1024, 8, 8]) | 764100 | 16400 | 28 | 291 | 100 | 28
shape torch.Size([1, 16, 16]) | 750 | 33 | 28 | 248 | 88 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 250 | 97 | 28
shape torch.Size([4, 16, 16]) | 2996 | 83 | 28 | 250 | 100 | 28
shape torch.Size([8, 16, 16]) | 5980 | 147 | 28 | 251 | 97 | 28
shape torch.Size([16, 16, 16]) | 11900 | 274 | 28 | 251 | 97 | 28
shape torch.Size([32, 16, 16]) | 24040 | 527 | 28 | 252 | 97 | 28
shape torch.Size([64, 16, 16]) | 47800 | 1037 | 28 | 251 | 100 | 28
shape torch.Size([128, 16, 16]) | 95600 | 2044 | 28 | 252 | 98 | 28
shape torch.Size([512, 16, 16]) | 388200 | 8100 | 28 | 280 | 100 | 28
shape torch.Size([1024, 16, 16]) | 769700 | 16000 | 28 | 322 | 117 | 28
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 255 | 88 | 28
shape torch.Size([2, 32, 32]) | 1510 | 50 | 28 | 256 | 97 | 28
shape torch.Size([4, 32, 32]) | 3022 | 82 | 31 | 256 | 97 | 30
shape torch.Size([8, 32, 32]) | 6000 | 140 | 31 | 257 | 100 | 31
shape torch.Size([16, 32, 32]) | 12000 | 281 | 31 | 258 | 96 | 31
shape torch.Size([32, 32, 32]) | 24150 | 563 | 35 | 258 | 96 | 35
shape torch.Size([64, 32, 32]) | 48300 | 1119 | 36 | 258 | 97 | 36
shape torch.Size([128, 32, 32]) | 96500 | 2235 | 43 | 261 | 96 | 43
shape torch.Size([512, 32, 32]) | 383100 | 8930 | 82 | 317 | 191 | 82
shape torch.Size([1024, 32, 32]) | 766300 | 19200 | 122 | 400 | 312 | 122
shape torch.Size([1, 64, 64]) | 760 | 33 | 55 | 272 | 68 | 34
shape torch.Size([2, 64, 64]) | 1500 | 52 | 58 | 273 | 85 | 52
shape torch.Size([4, 64, 64]) | 3127 | 102 | 65 | 273 | 150 | 65
shape torch.Size([8, 64, 64]) | 6070 | 201 | 65 | 275 | 278 | 65
shape torch.Size([16, 64, 64]) | 12000 | 399 | 66 | 274 | 95 | 67
shape torch.Size([32, 64, 64]) | 23900 | 796 | 73 | 275 | 97 | 73
shape torch.Size([64, 64, 64]) | 48000 | 1594 | 75 | 283 | 123 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3177 | 96 | 292 | 176 | 96
shape torch.Size([512, 64, 64]) | 379300 | 13520 | 208 | 426 | 551 | 208
shape torch.Size([1024, 64, 64]) | 758700 | 27100 | 306 | 570 | 919 | 306
shape torch.Size([1, 128, 128]) | 750 | 42 | 115 | 306 | 90 | 42
shape torch.Size([2, 128, 128]) | 1500 | 82 | 122 | 307 | 164 | 83
shape torch.Size([4, 128, 128]) | 2966 | 162 | 136 | 307 | 301 | 136
shape torch.Size([8, 128, 128]) | 5930 | 317 | 137 | 308 | 578 | 138
shape torch.Size([16, 128, 128]) | 12000 | 635 | 143 | 316 | 199 | 143
shape torch.Size([32, 128, 128]) | 23700 | 1266 | 152 | 322 | 241 | 152
shape torch.Size([64, 128, 128]) | 48000 | 2668 | 177 | 337 | 322 | 177
shape torch.Size([128, 128, 128]) | 96000 | 5366 | 228 | 365 | 514 | 228
shape torch.Size([512, 128, 128]) | 379400 | 21490 | 502 | 620 | 1697 | 502
shape torch.Size([1024, 128, 128]) | 755700 | 43040 | 764 | 903 | 3040 | 770
shape torch.Size([1, 256, 256]) | 750 | 70 | 235 | 383 | 178 | 72
shape torch.Size([2, 256, 256]) | 2000 | 138 | 250 | 384 | 329 | 139
shape torch.Size([4, 256, 256]) | 2988 | 277 | 279 | 404 | 655 | 278
shape torch.Size([8, 256, 256]) | 6100 | 546 | 283 | 420 | 1321 | 286
shape torch.Size([16, 256, 256]) | 12100 | 1149 | 330 | 441 | 472 | 330
shape torch.Size([32, 256, 256]) | 24040 | 2303 | 359 | 453 | 634 | 360
shape torch.Size([64, 256, 256]) | 48000 | 4626 | 408 | 472 | 925 | 408
shape torch.Size([128, 256, 256]) | 94700 | 9247 | 543 | 543 | 1582 | 543
shape torch.Size([512, 256, 256]) | 372000 | 37030 | 1310 | 1185 | 5711 | 1310
shape torch.Size([1024, 256, 256]) | 747200 | 74100 | 2116 | 1910 | 10660 | 2122
```
</details>
<details>
<summary>
Benchmark Results (adjoint=True)
</summary>
```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]
| lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 749 | 34 | 28 | 33 | 98 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 50 | 110 | 27
shape torch.Size([4, 1, 1]) | 3005 | 82 | 28 | 81 | 110 | 27
shape torch.Size([8, 1, 1]) | 5999 | 145 | 28 | 140 | 110 | 27
shape torch.Size([16, 1, 1]) | 12000 | 273 | 28 | 77 | 110 | 27
shape torch.Size([32, 1, 1]) | 24000 | 522 | 28 | 78 | 110 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 77 | 100 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2029 | 28 | 78 | 110 | 27
shape torch.Size([512, 1, 1]) | 383300 | 8100 | 28 | 78 | 110 | 28
shape torch.Size([1024, 1, 1]) | 767500 | 16100 | 28 | 77 | 100 | 27
shape torch.Size([1, 2, 2]) | 753 | 33 | 28 | 33 | 99 | 28
shape torch.Size([2, 2, 2]) | 1500 | 50 | 28 | 50 | 110 | 28
shape torch.Size([4, 2, 2]) | 3002 | 82 | 28 | 80 | 100 | 27
shape torch.Size([8, 2, 2]) | 6000 | 145 | 28 | 144 | 107 | 27
shape torch.Size([16, 2, 2]) | 12000 | 271 | 28 | 78 | 110 | 27
shape torch.Size([32, 2, 2]) | 24120 | 524 | 28 | 78 | 110 | 28
shape torch.Size([64, 2, 2]) | 48300 | 1030 | 28 | 78 | 111 | 27
shape torch.Size([128, 2, 2]) | 96100 | 2041 | 28 | 78 | 107 | 28
shape torch.Size([512, 2, 2]) | 383000 | 8100 | 28 | 79 | 108 | 28
shape torch.Size([1024, 2, 2]) | 766100 | 16000 | 28 | 78 | 110 | 28
shape torch.Size([1, 8, 8]) | 750 | 34 | 28 | 34 | 99 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 50 | 107 | 28
shape torch.Size([4, 8, 8]) | 2998 | 82 | 28 | 82 | 110 | 28
shape torch.Size([8, 8, 8]) | 5990 | 146 | 28 | 150 | 107 | 28
shape torch.Size([16, 8, 8]) | 11980 | 272 | 28 | 79 | 107 | 28
shape torch.Size([32, 8, 8]) | 23970 | 530 | 28 | 79 | 110 | 28
shape torch.Size([64, 8, 8]) | 47900 | 1040 | 28 | 79 | 108 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2048 | 28 | 78 | 108 | 28
shape torch.Size([512, 8, 8]) | 383700 | 8100 | 28 | 80 | 108 | 28
shape torch.Size([1024, 8, 8]) | 766200 | 16300 | 28 | 80 | 108 | 28
shape torch.Size([1, 16, 16]) | 760 | 33 | 28 | 34 | 99 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 50 | 110 | 28 [85/469]
shape torch.Size([4, 16, 16]) | 3001 | 81 | 28 | 82 | 108 | 28
shape torch.Size([8, 16, 16]) | 6000 | 145 | 28 | 140 | 110 | 28
shape torch.Size([16, 16, 16]) | 12000 | 276 | 28 | 79 | 110 | 28
shape torch.Size([32, 16, 16]) | 23870 | 549 | 28 | 79 | 110 | 28
shape torch.Size([64, 16, 16]) | 47900 | 1098 | 29 | 80 | 100 | 28
shape torch.Size([128, 16, 16]) | 95800 | 2184 | 28 | 79 | 108 | 28
shape torch.Size([512, 16, 16]) | 386900 | 8769 | 28 | 80 | 108 | 28
shape torch.Size([1024, 16, 16]) | 769800 | 17460 | 37 | 80 | 107 | 37
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 34 | 99 | 28
shape torch.Size([2, 32, 32]) | 1500 | 50 | 28 | 50 | 110 | 29
shape torch.Size([4, 32, 32]) | 3021 | 86 | 31 | 84 | 110 | 32
shape torch.Size([8, 32, 32]) | 6040 | 167 | 32 | 167 | 108 | 32
shape torch.Size([16, 32, 32]) | 12100 | 330 | 33 | 78 | 107 | 33
shape torch.Size([32, 32, 32]) | 24150 | 662 | 35 | 78 | 110 | 35
shape torch.Size([64, 32, 32]) | 48200 | 1323 | 36 | 79 | 110 | 36
shape torch.Size([128, 32, 32]) | 97000 | 2637 | 44 | 79 | 110 | 43
shape torch.Size([512, 32, 32]) | 382500 | 10580 | 83 | 180 | 198 | 83
shape torch.Size([1024, 32, 32]) | 766600 | 22670 | 123 | 260 | 318 | 120
shape torch.Size([1, 64, 64]) | 760 | 33 | 58 | 33 | 88 | 34
shape torch.Size([2, 64, 64]) | 1520 | 60 | 60 | 60 | 109 | 59
shape torch.Size([4, 64, 64]) | 3016 | 115 | 67 | 119 | 132 | 66
shape torch.Size([8, 64, 64]) | 6120 | 230 | 67 | 233 | 180 | 68
shape torch.Size([16, 64, 64]) | 12100 | 457 | 69 | 86 | 107 | 69
shape torch.Size([32, 64, 64]) | 24000 | 912 | 74 | 95 | 100 | 74
shape torch.Size([64, 64, 64]) | 48000 | 1833 | 76 | 106 | 123 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3636 | 97 | 163 | 183 | 97
shape torch.Size([512, 64, 64]) | 380600 | 15600 | 210 | 464 | 549 | 210
shape torch.Size([1024, 64, 64]) | 761200 | 31140 | 308 | 741 | 918 | 308
shape torch.Size([1, 128, 128]) | 756 | 46 | 120 | 47 | 89 | 46
shape torch.Size([2, 128, 128]) | 1500 | 91 | 123 | 89 | 110 | 92
shape torch.Size([4, 128, 128]) | 2994 | 178 | 139 | 180 | 131 | 139
shape torch.Size([8, 128, 128]) | 5960 | 350 | 140 | 354 | 208 | 142
shape torch.Size([16, 128, 128]) | 12000 | 701 | 144 | 177 | 198 | 143
shape torch.Size([32, 128, 128]) | 23870 | 1401 | 155 | 225 | 246 | 155
shape torch.Size([64, 128, 128]) | 47600 | 2948 | 179 | 288 | 323 | 180
shape torch.Size([128, 128, 128]) | 96000 | 5910 | 231 | 442 | 512 | 231
shape torch.Size([512, 128, 128]) | 381200 | 23640 | 519 | 1400 | 1700 | 519
shape torch.Size([1024, 128, 128]) | 755800 | 47340 | 794 | 2436 | 3018 | 794
shape torch.Size([1, 256, 256]) | 760 | 74 | 246 | 77 | 88 | 78
shape torch.Size([2, 256, 256]) | 1510 | 150 | 256 | 150 | 117 | 150
shape torch.Size([4, 256, 256]) | 3030 | 296 | 284 | 296 | 209 | 284
shape torch.Size([8, 256, 256]) | 6100 | 588 | 286 | 592 | 394 | 288
shape torch.Size([16, 256, 256]) | 12200 | 1238 | 330 | 445 | 480 | 330
shape torch.Size([32, 256, 256]) | 24430 | 2476 | 368 | 568 | 629 | 367
shape torch.Size([64, 256, 256]) | 49000 | 4950 | 415 | 800 | 921 | 414
shape torch.Size([128, 256, 256]) | 96000 | 9900 | 552 | 1330 | 1579 | 553
shape torch.Size([512, 256, 256]) | 369400 | 39580 | 1410 | 4614 | 5616 | 1410
shape torch.Size([1024, 256, 256]) | 716200 | 79200 | 2270 | 8472 | 10500 | 2277
Times are in microseconds (us).
```
</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. Then I run the following script, changing the variable `name`. For the `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncomenting the commented one)
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare
name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")
def f(LU, pivots, B, adjoint):
P, L, U = torch.lu_unpack(LU, pivots)
if adjoint:
X = torch.linalg.solve_triangular(U.mH, B, upper=False)
return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
else:
X = P.mT @ B
X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
return torch.linalg.solve_triangular(U, X, upper=True, out=X)
for n, batch in itertools.product(shapes, batches):
LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
B = make_arg(batch + (n, 1))
print(LU.shape)
stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
#stmt = "f(LU, pivots, B, adjoint=adjoint)"
for adjoint in (True, False):
timer = Timer(stmt,
globals=globals(),
label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
description=label,
sub_label=f"shape {LU.shape}",
num_threads=1)
results.append(timer.blocked_autorange())
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
with open("{}_lu_solve.pickle".format(name), 'wb') as f:
pickle.dump(results, f)
```
</details>
Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>
```python
import pickle
from torch.utils.benchmark import Timer, Compare
files = [
"looped_magma",
"looped cusolver",
"batched cublas",
"batched magma",
"unpack+solve_triangular",
"heuristic",
]
timers = []
for name in files:
with open("{}_lu_solve.pickle".format(name), 'rb') as f:
timers += pickle.load(f)
compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>
### Fix for Magma's batched lu_solve when `adjoint=True`
I also developed the following fix around MAGMA's bug, but I ended up not using it, and preferring the triangular solves over it, as they were faster. I'm leaving it here in case it's useful in the future.
<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>
```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
if (trans == TransposeType::NoTranspose) {
lu_solve_batched_magma(LU, pivots, B, trans);
return;
}
// There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
// The LU of the transpose is not the transpose of the LU
// We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
auto diag = LU.diagonal(0, -2, -1);
auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
LU.triu(1).div_(diag.unsqueeze(-1));
LU_f.diagonal(0, -2, -1).copy_(diag);
if (trans == TransposeType::ConjTranspose) {
LU_f = LU_f.conj_physical();
}
LU_f.transpose(-2, -1);
// At this point LU_f is F-contiguous, because triu / tril / conj_phisical return contiguous tensors
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());
// Trivial permutation
auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();
lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);
// We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
// Fill `perm` with the identity permutation (perhaps batched)
// This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
.add_output(perm)
.add_input(pivots)
.build();
unpack_pivots_stub(pivots.device().type(), iter, m);
B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```
</details>
Fixes #61657
[ghstack-poisoned]
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving the system solving two triangular systems.
We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.
We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.
### Benchmarking
<details>
<summary>
Benchmark Results (adjoint=False)
</summary>
```
--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]
| lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 750 | 34 | 28 | 252 | 86 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 239 | 94 | 27
shape torch.Size([4, 1, 1]) | 2995 | 83 | 28 | 239 | 100 | 27
shape torch.Size([8, 1, 1]) | 6000 | 146 | 28 | 239 | 94 | 27
shape torch.Size([16, 1, 1]) | 11900 | 272 | 28 | 241 | 95 | 27
shape torch.Size([32, 1, 1]) | 23880 | 524 | 28 | 244 | 94 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 245 | 99 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2054 | 28 | 242 | 96 | 27
shape torch.Size([512, 1, 1]) | 381900 | 8100 | 28 | 250 | 94 | 27
shape torch.Size([1024, 1, 1]) | 763800 | 16200 | 28 | 257 | 95 | 27
shape torch.Size([1, 2, 2]) | 750 | 33 | 28 | 240 | 88 | 27
shape torch.Size([2, 2, 2]) | 1500 | 51 | 28 | 240 | 96 | 27
shape torch.Size([4, 2, 2]) | 2991 | 82 | 28 | 241 | 96 | 28
shape torch.Size([8, 2, 2]) | 6000 | 150 | 28 | 241 | 96 | 27
shape torch.Size([16, 2, 2]) | 12000 | 275 | 28 | 242 | 96 | 27
shape torch.Size([32, 2, 2]) | 23980 | 530 | 28 | 246 | 97 | 28
shape torch.Size([64, 2, 2]) | 48000 | 1000 | 28 | 244 | 96 | 27
shape torch.Size([128, 2, 2]) | 96000 | 2063 | 28 | 245 | 96 | 28
shape torch.Size([512, 2, 2]) | 382000 | 8300 | 28 | 257 | 97 | 28
shape torch.Size([1024, 2, 2]) | 764000 | 20000 | 28 | 271 | 97 | 28
shape torch.Size([1, 8, 8]) | 749 | 34 | 28 | 243 | 88 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 244 | 97 | 28
shape torch.Size([4, 8, 8]) | 2988 | 83 | 28 | 244 | 100 | 28
shape torch.Size([8, 8, 8]) | 5980 | 150 | 28 | 245 | 97 | 28
shape torch.Size([16, 8, 8]) | 12000 | 278 | 28 | 246 | 96 | 28
shape torch.Size([32, 8, 8]) | 23910 | 536 | 28 | 249 | 98 | 28
shape torch.Size([64, 8, 8]) | 47800 | 1100 | 28 | 247 | 96 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2075 | 28 | 248 | 96 | 28
shape torch.Size([512, 8, 8]) | 382100 | 8300 | 28 | 270 | 97 | 28
shape torch.Size([1024, 8, 8]) | 764100 | 16400 | 28 | 291 | 100 | 28
shape torch.Size([1, 16, 16]) | 750 | 33 | 28 | 248 | 88 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 250 | 97 | 28
shape torch.Size([4, 16, 16]) | 2996 | 83 | 28 | 250 | 100 | 28
shape torch.Size([8, 16, 16]) | 5980 | 147 | 28 | 251 | 97 | 28
shape torch.Size([16, 16, 16]) | 11900 | 274 | 28 | 251 | 97 | 28
shape torch.Size([32, 16, 16]) | 24040 | 527 | 28 | 252 | 97 | 28
shape torch.Size([64, 16, 16]) | 47800 | 1037 | 28 | 251 | 100 | 28
shape torch.Size([128, 16, 16]) | 95600 | 2044 | 28 | 252 | 98 | 28
shape torch.Size([512, 16, 16]) | 388200 | 8100 | 28 | 280 | 100 | 28
shape torch.Size([1024, 16, 16]) | 769700 | 16000 | 28 | 322 | 117 | 28
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 255 | 88 | 28
shape torch.Size([2, 32, 32]) | 1510 | 50 | 28 | 256 | 97 | 28
shape torch.Size([4, 32, 32]) | 3022 | 82 | 31 | 256 | 97 | 30
shape torch.Size([8, 32, 32]) | 6000 | 140 | 31 | 257 | 100 | 31
shape torch.Size([16, 32, 32]) | 12000 | 281 | 31 | 258 | 96 | 31
shape torch.Size([32, 32, 32]) | 24150 | 563 | 35 | 258 | 96 | 35
shape torch.Size([64, 32, 32]) | 48300 | 1119 | 36 | 258 | 97 | 36
shape torch.Size([128, 32, 32]) | 96500 | 2235 | 43 | 261 | 96 | 43
shape torch.Size([512, 32, 32]) | 383100 | 8930 | 82 | 317 | 191 | 82
shape torch.Size([1024, 32, 32]) | 766300 | 19200 | 122 | 400 | 312 | 122
shape torch.Size([1, 64, 64]) | 760 | 33 | 55 | 272 | 68 | 34
shape torch.Size([2, 64, 64]) | 1500 | 52 | 58 | 273 | 85 | 52
shape torch.Size([4, 64, 64]) | 3127 | 102 | 65 | 273 | 150 | 65
shape torch.Size([8, 64, 64]) | 6070 | 201 | 65 | 275 | 278 | 65
shape torch.Size([16, 64, 64]) | 12000 | 399 | 66 | 274 | 95 | 67
shape torch.Size([32, 64, 64]) | 23900 | 796 | 73 | 275 | 97 | 73
shape torch.Size([64, 64, 64]) | 48000 | 1594 | 75 | 283 | 123 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3177 | 96 | 292 | 176 | 96
shape torch.Size([512, 64, 64]) | 379300 | 13520 | 208 | 426 | 551 | 208
shape torch.Size([1024, 64, 64]) | 758700 | 27100 | 306 | 570 | 919 | 306
shape torch.Size([1, 128, 128]) | 750 | 42 | 115 | 306 | 90 | 42
shape torch.Size([2, 128, 128]) | 1500 | 82 | 122 | 307 | 164 | 83
shape torch.Size([4, 128, 128]) | 2966 | 162 | 136 | 307 | 301 | 136
shape torch.Size([8, 128, 128]) | 5930 | 317 | 137 | 308 | 578 | 138
shape torch.Size([16, 128, 128]) | 12000 | 635 | 143 | 316 | 199 | 143
shape torch.Size([32, 128, 128]) | 23700 | 1266 | 152 | 322 | 241 | 152
shape torch.Size([64, 128, 128]) | 48000 | 2668 | 177 | 337 | 322 | 177
shape torch.Size([128, 128, 128]) | 96000 | 5366 | 228 | 365 | 514 | 228
shape torch.Size([512, 128, 128]) | 379400 | 21490 | 502 | 620 | 1697 | 502
shape torch.Size([1024, 128, 128]) | 755700 | 43040 | 764 | 903 | 3040 | 770
shape torch.Size([1, 256, 256]) | 750 | 70 | 235 | 383 | 178 | 72
shape torch.Size([2, 256, 256]) | 2000 | 138 | 250 | 384 | 329 | 139
shape torch.Size([4, 256, 256]) | 2988 | 277 | 279 | 404 | 655 | 278
shape torch.Size([8, 256, 256]) | 6100 | 546 | 283 | 420 | 1321 | 286
shape torch.Size([16, 256, 256]) | 12100 | 1149 | 330 | 441 | 472 | 330
shape torch.Size([32, 256, 256]) | 24040 | 2303 | 359 | 453 | 634 | 360
shape torch.Size([64, 256, 256]) | 48000 | 4626 | 408 | 472 | 925 | 408
shape torch.Size([128, 256, 256]) | 94700 | 9247 | 543 | 543 | 1582 | 543
shape torch.Size([512, 256, 256]) | 372000 | 37030 | 1310 | 1185 | 5711 | 1310
shape torch.Size([1024, 256, 256]) | 747200 | 74100 | 2116 | 1910 | 10660 | 2122
```
</details>
<details>
<summary>
Benchmark Results (adjoint=True)
</summary>
```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]
| lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 749 | 34 | 28 | 33 | 98 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 50 | 110 | 27
shape torch.Size([4, 1, 1]) | 3005 | 82 | 28 | 81 | 110 | 27
shape torch.Size([8, 1, 1]) | 5999 | 145 | 28 | 140 | 110 | 27
shape torch.Size([16, 1, 1]) | 12000 | 273 | 28 | 77 | 110 | 27
shape torch.Size([32, 1, 1]) | 24000 | 522 | 28 | 78 | 110 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 77 | 100 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2029 | 28 | 78 | 110 | 27
shape torch.Size([512, 1, 1]) | 383300 | 8100 | 28 | 78 | 110 | 28
shape torch.Size([1024, 1, 1]) | 767500 | 16100 | 28 | 77 | 100 | 27
shape torch.Size([1, 2, 2]) | 753 | 33 | 28 | 33 | 99 | 28
shape torch.Size([2, 2, 2]) | 1500 | 50 | 28 | 50 | 110 | 28
shape torch.Size([4, 2, 2]) | 3002 | 82 | 28 | 80 | 100 | 27
shape torch.Size([8, 2, 2]) | 6000 | 145 | 28 | 144 | 107 | 27
shape torch.Size([16, 2, 2]) | 12000 | 271 | 28 | 78 | 110 | 27
shape torch.Size([32, 2, 2]) | 24120 | 524 | 28 | 78 | 110 | 28
shape torch.Size([64, 2, 2]) | 48300 | 1030 | 28 | 78 | 111 | 27
shape torch.Size([128, 2, 2]) | 96100 | 2041 | 28 | 78 | 107 | 28
shape torch.Size([512, 2, 2]) | 383000 | 8100 | 28 | 79 | 108 | 28
shape torch.Size([1024, 2, 2]) | 766100 | 16000 | 28 | 78 | 110 | 28
shape torch.Size([1, 8, 8]) | 750 | 34 | 28 | 34 | 99 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 50 | 107 | 28
shape torch.Size([4, 8, 8]) | 2998 | 82 | 28 | 82 | 110 | 28
shape torch.Size([8, 8, 8]) | 5990 | 146 | 28 | 150 | 107 | 28
shape torch.Size([16, 8, 8]) | 11980 | 272 | 28 | 79 | 107 | 28
shape torch.Size([32, 8, 8]) | 23970 | 530 | 28 | 79 | 110 | 28
shape torch.Size([64, 8, 8]) | 47900 | 1040 | 28 | 79 | 108 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2048 | 28 | 78 | 108 | 28
shape torch.Size([512, 8, 8]) | 383700 | 8100 | 28 | 80 | 108 | 28
shape torch.Size([1024, 8, 8]) | 766200 | 16300 | 28 | 80 | 108 | 28
shape torch.Size([1, 16, 16]) | 760 | 33 | 28 | 34 | 99 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 50 | 110 | 28
shape torch.Size([4, 16, 16]) | 3001 | 81 | 28 | 82 | 108 | 28
shape torch.Size([8, 16, 16]) | 6000 | 145 | 28 | 140 | 110 | 28
shape torch.Size([16, 16, 16]) | 12000 | 276 | 28 | 79 | 110 | 28
shape torch.Size([32, 16, 16]) | 23870 | 549 | 28 | 79 | 110 | 28
shape torch.Size([64, 16, 16]) | 47900 | 1098 | 29 | 80 | 100 | 28
shape torch.Size([128, 16, 16]) | 95800 | 2184 | 28 | 79 | 108 | 28
shape torch.Size([512, 16, 16]) | 386900 | 8769 | 28 | 80 | 108 | 28
shape torch.Size([1024, 16, 16]) | 769800 | 17460 | 37 | 80 | 107 | 37
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 34 | 99 | 28
shape torch.Size([2, 32, 32]) | 1500 | 50 | 28 | 50 | 110 | 29
shape torch.Size([4, 32, 32]) | 3021 | 86 | 31 | 84 | 110 | 32
shape torch.Size([8, 32, 32]) | 6040 | 167 | 32 | 167 | 108 | 32
shape torch.Size([16, 32, 32]) | 12100 | 330 | 33 | 78 | 107 | 33
shape torch.Size([32, 32, 32]) | 24150 | 662 | 35 | 78 | 110 | 35
shape torch.Size([64, 32, 32]) | 48200 | 1323 | 36 | 79 | 110 | 36
shape torch.Size([128, 32, 32]) | 97000 | 2637 | 44 | 79 | 110 | 43
shape torch.Size([512, 32, 32]) | 382500 | 10580 | 83 | 180 | 198 | 83
shape torch.Size([1024, 32, 32]) | 766600 | 22670 | 123 | 260 | 318 | 120
shape torch.Size([1, 64, 64]) | 760 | 33 | 58 | 33 | 88 | 34
shape torch.Size([2, 64, 64]) | 1520 | 60 | 60 | 60 | 109 | 59
shape torch.Size([4, 64, 64]) | 3016 | 115 | 67 | 119 | 132 | 66
shape torch.Size([8, 64, 64]) | 6120 | 230 | 67 | 233 | 180 | 68
shape torch.Size([16, 64, 64]) | 12100 | 457 | 69 | 86 | 107 | 69
shape torch.Size([32, 64, 64]) | 24000 | 912 | 74 | 95 | 100 | 74
shape torch.Size([64, 64, 64]) | 48000 | 1833 | 76 | 106 | 123 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3636 | 97 | 163 | 183 | 97
shape torch.Size([512, 64, 64]) | 380600 | 15600 | 210 | 464 | 549 | 210
shape torch.Size([1024, 64, 64]) | 761200 | 31140 | 308 | 741 | 918 | 308
shape torch.Size([1, 128, 128]) | 756 | 46 | 120 | 47 | 89 | 46
shape torch.Size([2, 128, 128]) | 1500 | 91 | 123 | 89 | 110 | 92
shape torch.Size([4, 128, 128]) | 2994 | 178 | 139 | 180 | 131 | 139
shape torch.Size([8, 128, 128]) | 5960 | 350 | 140 | 354 | 208 | 142
shape torch.Size([16, 128, 128]) | 12000 | 701 | 144 | 177 | 198 | 143
shape torch.Size([32, 128, 128]) | 23870 | 1401 | 155 | 225 | 246 | 155
shape torch.Size([64, 128, 128]) | 47600 | 2948 | 179 | 288 | 323 | 180
shape torch.Size([128, 128, 128]) | 96000 | 5910 | 231 | 442 | 512 | 231
shape torch.Size([512, 128, 128]) | 381200 | 23640 | 519 | 1400 | 1700 | 519
shape torch.Size([1024, 128, 128]) | 755800 | 47340 | 794 | 2436 | 3018 | 794
shape torch.Size([1, 256, 256]) | 760 | 74 | 246 | 77 | 88 | 78
shape torch.Size([2, 256, 256]) | 1510 | 150 | 256 | 150 | 117 | 150
shape torch.Size([4, 256, 256]) | 3030 | 296 | 284 | 296 | 209 | 284
shape torch.Size([8, 256, 256]) | 6100 | 588 | 286 | 592 | 394 | 288
shape torch.Size([16, 256, 256]) | 12200 | 1238 | 330 | 445 | 480 | 330
shape torch.Size([32, 256, 256]) | 24430 | 2476 | 368 | 568 | 629 | 367
shape torch.Size([64, 256, 256]) | 49000 | 4950 | 415 | 800 | 921 | 414
shape torch.Size([128, 256, 256]) | 96000 | 9900 | 552 | 1330 | 1579 | 553
shape torch.Size([512, 256, 256]) | 369400 | 39580 | 1410 | 4614 | 5616 | 1410
shape torch.Size([1024, 256, 256]) | 716200 | 79200 | 2270 | 8472 | 10500 | 2277
Times are in microseconds (us).
```
</details>
To generate the results above, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. Then I ran the following script, changing the variable `name`. For the `lu_solve unpack+solve_triangular` backend, I also changed the `stmt` variable (uncommenting the commented one).
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare
name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")
# Reference implementation of the "unpack+solve_triangular" backend
def f(LU, pivots, B, adjoint):
P, L, U = torch.lu_unpack(LU, pivots)
if adjoint:
X = torch.linalg.solve_triangular(U.mH, B, upper=False)
return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
else:
X = P.mT @ B
X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
return torch.linalg.solve_triangular(U, X, upper=True, out=X)
for n, batch in itertools.product(shapes, batches):
LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
B = make_arg(batch + (n, 1))
print(LU.shape)
stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
#stmt = "f(LU, pivots, B, adjoint=adjoint)"
for adjoint in (True, False):
timer = Timer(stmt,
globals=globals(),
label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
description=label,
sub_label=f"shape {LU.shape}",
num_threads=1)
results.append(timer.blocked_autorange())
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
with open("{}_lu_solve.pickle".format(name), 'wb') as f:
pickle.dump(results, f)
```
</details>
Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>
```python
import pickle
from torch.utils.benchmark import Timer, Compare
files = [
"looped_magma",
"looped cusolver",
"batched cublas",
"batched magma",
"unpack+solve_triangular",
"heuristic",
]
timers = []
for name in files:
with open("{}_lu_solve.pickle".format(name), 'rb') as f:
timers += pickle.load(f)
compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>
### Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it, preferring the triangular solves, as they were faster. I'm leaving it here in case it's useful in the future.
<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>
```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
if (trans == TransposeType::NoTranspose) {
lu_solve_batched_magma(LU, pivots, B, trans);
return;
}
// There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
// The LU of the transpose is not the transpose of the LU
// We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
auto diag = LU.diagonal(0, -2, -1);
auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
LU.triu(1).div_(diag.unsqueeze(-1));
LU_f.diagonal(0, -2, -1).copy_(diag);
if (trans == TransposeType::ConjTranspose) {
LU_f = LU_f.conj_physical();
}
LU_f = LU_f.transpose(-2, -1);
// At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());
// Trivial permutation
auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();
lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);
// We then need to apply the permutation P to the solution: with A = PLU, A^T X = B iff (LU)^T Y = B where Y = P^T X, so X = P Y
// Fill `perm` with the identity permutation (perhaps batched)
// This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
.add_output(perm)
.add_input(pivots)
.build();
unpack_pivots_stub(pivots.device().type(), iter, m);
B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```
</details>
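To make the diagonal-rescaling trick in the snippet above concrete: the claim is that for a compact factorization LU with L unit lower triangular and U upper triangular, taking D = diag(U), L' = LD and U' = D^{-1}U gives LU = L'U' with U' unit upper triangular, so (LU)^T = U'^T L'^T is again a valid pivot-free LU factorization. A minimal PyTorch sketch checking this identity (my own illustration, not part of the PR):
```python
import torch

torch.manual_seed(0)
A = torch.randn(5, 5, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)

# Recover the unit-lower L and upper U from the compact storage
L = LU.tril(-1) + torch.eye(5, dtype=torch.float64)
U = LU.triu()
d = U.diagonal()

L_p = L * d.unsqueeze(0)  # L' = L @ D: column j scaled by d_j
U_p = U / d.unsqueeze(1)  # U' = D^{-1} @ U: row i scaled by 1/d_i

assert torch.allclose(L @ U, L_p @ U_p)                                     # same product
assert torch.allclose(U_p.diagonal(), torch.ones(5, dtype=torch.float64))   # U' is unit upper
assert torch.allclose((L @ U).mT, U_p.mT @ L_p.mT)                          # LU factorization of the transpose
```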
Fixes #61657
[ghstack-poisoned]
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
it by solving the system via two triangular solves instead.
We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy MAGMA backend for this function.
We added thorough tests for this function, as well as tests for the
different backends. We also enabled the tests on AMD, as those should
work as well.
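As a quick illustration of the semantics, here is a minimal sketch, assuming the signature `torch.linalg.lu_solve(LU, pivots, B, adjoint=False)` exercised in the benchmarks below, and mirroring the two-triangular-solves workaround (the `f` helper from the first benchmarking script):
```python
import torch

A = torch.randn(4, 4, dtype=torch.float64)
B = torch.randn(4, 2, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)

# Forward and adjoint solves through the new function
X = torch.linalg.lu_solve(LU, pivots, B)                 # solves A X = B
Xh = torch.linalg.lu_solve(LU, pivots, B, adjoint=True)  # solves A^H X = B
assert torch.allclose(A @ X, B)
assert torch.allclose(A.mH @ Xh, B)

# The two-triangular-solves workaround for the adjoint case:
# A = P L U  =>  A^H X = B  iff  U^H L^H P^H X = B
P, L, U = torch.lu_unpack(LU, pivots)
Y = torch.linalg.solve_triangular(U.mH, B, upper=False)
Y = torch.linalg.solve_triangular(L.mH, Y, upper=True, unitriangular=True)
assert torch.allclose(P @ Y, Xh)
```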
### Benchmarking
<details>
<summary>
Benchmark Results (adjoint=False)
</summary>
```
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]
| lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 750 | 34 | 28 | 252 | 78 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 239 | 85 | 27
shape torch.Size([4, 1, 1]) | 2995 | 83 | 28 | 239 | 85 | 27
shape torch.Size([8, 1, 1]) | 6000 | 146 | 28 | 239 | 85 | 27
shape torch.Size([16, 1, 1]) | 11900 | 272 | 28 | 241 | 85 | 27
shape torch.Size([32, 1, 1]) | 23880 | 524 | 28 | 244 | 85 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 245 | 85 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2054 | 28 | 242 | 86 | 27
shape torch.Size([512, 1, 1]) | 381900 | 8100 | 28 | 250 | 85 | 27
shape torch.Size([1024, 1, 1]) | 763800 | 16200 | 28 | 257 | 86 | 27
shape torch.Size([1, 2, 2]) | 750 | 33 | 28 | 240 | 82 | 27
shape torch.Size([2, 2, 2]) | 1500 | 51 | 28 | 240 | 88 | 27
shape torch.Size([4, 2, 2]) | 2991 | 82 | 28 | 241 | 88 | 28
shape torch.Size([8, 2, 2]) | 6000 | 150 | 28 | 241 | 88 | 27
shape torch.Size([16, 2, 2]) | 12000 | 275 | 28 | 242 | 88 | 27
shape torch.Size([32, 2, 2]) | 23980 | 530 | 28 | 246 | 90 | 28
shape torch.Size([64, 2, 2]) | 48000 | 1000 | 28 | 244 | 89 | 27
shape torch.Size([128, 2, 2]) | 96000 | 2063 | 28 | 245 | 89 | 28
shape torch.Size([512, 2, 2]) | 382000 | 8300 | 28 | 257 | 89 | 28
shape torch.Size([1024, 2, 2]) | 764000 | 20000 | 28 | 271 | 88 | 28
shape torch.Size([1, 8, 8]) | 749 | 34 | 28 | 243 | 82 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 244 | 89 | 28
shape torch.Size([4, 8, 8]) | 2988 | 83 | 28 | 244 | 89 | 28
shape torch.Size([8, 8, 8]) | 5980 | 150 | 28 | 245 | 88 | 28
shape torch.Size([16, 8, 8]) | 12000 | 278 | 28 | 246 | 89 | 28
shape torch.Size([32, 8, 8]) | 23910 | 536 | 28 | 249 | 89 | 28
shape torch.Size([64, 8, 8]) | 47800 | 1100 | 28 | 247 | 89 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2075 | 28 | 248 | 89 | 28
shape torch.Size([512, 8, 8]) | 382100 | 8300 | 28 | 270 | 89 | 28
shape torch.Size([1024, 8, 8]) | 764100 | 16400 | 28 | 291 | 89 | 28
shape torch.Size([1, 16, 16]) | 750 | 33 | 28 | 248 | 82 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 250 | 89 | 28
shape torch.Size([4, 16, 16]) | 2996 | 83 | 28 | 250 | 89 | 28
shape torch.Size([8, 16, 16]) | 5980 | 147 | 28 | 251 | 90 | 28
shape torch.Size([16, 16, 16]) | 11900 | 274 | 28 | 251 | 89 | 28
shape torch.Size([32, 16, 16]) | 24040 | 527 | 28 | 252 | 89 | 28
shape torch.Size([64, 16, 16]) | 47800 | 1037 | 28 | 251 | 89 | 28
shape torch.Size([128, 16, 16]) | 95600 | 2044 | 28 | 252 | 89 | 28
shape torch.Size([512, 16, 16]) | 388200 | 8100 | 28 | 280 | 88 | 28
shape torch.Size([1024, 16, 16]) | 769700 | 16000 | 28 | 322 | 89 | 28
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 255 | 82 | 28
shape torch.Size([2, 32, 32]) | 1510 | 50 | 28 | 256 | 89 | 28
shape torch.Size([4, 32, 32]) | 3022 | 82 | 31 | 256 | 89 | 30
shape torch.Size([8, 32, 32]) | 6000 | 140 | 31 | 257 | 89 | 31
shape torch.Size([16, 32, 32]) | 12000 | 281 | 31 | 258 | 89 | 31
shape torch.Size([32, 32, 32]) | 24150 | 563 | 35 | 258 | 89 | 35
shape torch.Size([64, 32, 32]) | 48300 | 1119 | 36 | 258 | 90 | 36
shape torch.Size([128, 32, 32]) | 96500 | 2235 | 43 | 261 | 88 | 43
shape torch.Size([512, 32, 32]) | 383100 | 8930 | 82 | 317 | 137 | 82
shape torch.Size([1024, 32, 32]) | 766300 | 19200 | 122 | 400 | 187 | 122
shape torch.Size([1, 64, 64]) | 760 | 33 | 55 | 272 | 71 | 34
shape torch.Size([2, 64, 64]) | 1500 | 52 | 58 | 273 | 90 | 52
shape torch.Size([4, 64, 64]) | 3127 | 102 | 65 | 273 | 110 | 65
shape torch.Size([8, 64, 64]) | 6070 | 201 | 65 | 275 | 160 | 65
shape torch.Size([16, 64, 64]) | 12000 | 399 | 66 | 274 | 88 | 67
shape torch.Size([32, 64, 64]) | 23900 | 796 | 73 | 275 | 91 | 73
shape torch.Size([64, 64, 64]) | 48000 | 1594 | 75 | 283 | 94 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3177 | 96 | 292 | 122 | 96
shape torch.Size([512, 64, 64]) | 379300 | 13520 | 208 | 426 | 313 | 208
shape torch.Size([1024, 64, 64]) | 758700 | 27100 | 306 | 570 | 437 | 306
shape torch.Size([1, 128, 128]) | 750 | 42 | 115 | 306 | 71 | 42
shape torch.Size([2, 128, 128]) | 1500 | 82 | 122 | 307 | 90 | 83
shape torch.Size([4, 128, 128]) | 2966 | 162 | 136 | 307 | 117 | 136
shape torch.Size([8, 128, 128]) | 5930 | 317 | 137 | 308 | 202 | 138
shape torch.Size([16, 128, 128]) | 12000 | 635 | 143 | 316 | 162 | 143
shape torch.Size([32, 128, 128]) | 23700 | 1266 | 152 | 322 | 168 | 152
shape torch.Size([64, 128, 128]) | 48000 | 2668 | 177 | 337 | 196 | 177
shape torch.Size([128, 128, 128]) | 96000 | 5366 | 228 | 365 | 260 | 228
shape torch.Size([512, 128, 128]) | 379400 | 21490 | 502 | 620 | 669 | 502
shape torch.Size([1024, 128, 128]) | 755700 | 43040 | 764 | 903 | 978 | 770
shape torch.Size([1, 256, 256]) | 750 | 70 | 235 | 383 | 78 | 72
shape torch.Size([2, 256, 256]) | 2000 | 138 | 250 | 384 | 120 | 139
shape torch.Size([4, 256, 256]) | 2988 | 277 | 279 | 404 | 191 | 278
shape torch.Size([8, 256, 256]) | 6100 | 546 | 283 | 420 | 332 | 286
shape torch.Size([16, 256, 256]) | 12100 | 1149 | 330 | 441 | 364 | 330
shape torch.Size([32, 256, 256]) | 24040 | 2303 | 359 | 453 | 401 | 360
shape torch.Size([64, 256, 256]) | 48000 | 4626 | 408 | 472 | 466 | 408
shape torch.Size([128, 256, 256]) | 94700 | 9247 | 543 | 543 | 658 | 543
shape torch.Size([512, 256, 256]) | 372000 | 37030 | 1310 | 1185 | 1896 | 1310
shape torch.Size([1024, 256, 256]) | 747200 | 74100 | 2116 | 1910 | 3051 | 2122
```
</details>
<details>
<summary>
Benchmark Results (adjoint=True)
</summary>
```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]
| lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 749 | 34 | 28 | 33 | 60 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 50 | 67 | 27
shape torch.Size([4, 1, 1]) | 3005 | 82 | 28 | 81 | 67 | 27
shape torch.Size([8, 1, 1]) | 5999 | 145 | 28 | 140 | 71 | 27
shape torch.Size([16, 1, 1]) | 12000 | 273 | 28 | 77 | 67 | 27
shape torch.Size([32, 1, 1]) | 24000 | 522 | 28 | 78 | 67 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 77 | 67 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2029 | 28 | 78 | 67 | 27
shape torch.Size([512, 1, 1]) | 383300 | 8100 | 28 | 78 | 67 | 28
shape torch.Size([1024, 1, 1]) | 767500 | 16100 | 28 | 77 | 68 | 27
shape torch.Size([1, 2, 2]) | 753 | 33 | 28 | 33 | 60 | 28
shape torch.Size([2, 2, 2]) | 1500 | 50 | 28 | 50 | 67 | 28
shape torch.Size([4, 2, 2]) | 3002 | 82 | 28 | 80 | 67 | 27
shape torch.Size([8, 2, 2]) | 6000 | 145 | 28 | 144 | 67 | 27
shape torch.Size([16, 2, 2]) | 12000 | 271 | 28 | 78 | 66 | 27
shape torch.Size([32, 2, 2]) | 24120 | 524 | 28 | 78 | 69 | 28
shape torch.Size([64, 2, 2]) | 48300 | 1030 | 28 | 78 | 66 | 27
shape torch.Size([128, 2, 2]) | 96100 | 2041 | 28 | 78 | 67 | 28
shape torch.Size([512, 2, 2]) | 383000 | 8100 | 28 | 79 | 67 | 28
shape torch.Size([1024, 2, 2]) | 766100 | 16000 | 28 | 78 | 67 | 28
shape torch.Size([1, 8, 8]) | 750 | 34 | 28 | 34 | 60 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 50 | 67 | 28
shape torch.Size([4, 8, 8]) | 2998 | 82 | 28 | 82 | 67 | 28
shape torch.Size([8, 8, 8]) | 5990 | 146 | 28 | 150 | 66 | 28
shape torch.Size([16, 8, 8]) | 11980 | 272 | 28 | 79 | 67 | 28
shape torch.Size([32, 8, 8]) | 23970 | 530 | 28 | 79 | 67 | 28
shape torch.Size([64, 8, 8]) | 47900 | 1040 | 28 | 79 | 67 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2048 | 28 | 78 | 67 | 28
shape torch.Size([512, 8, 8]) | 383700 | 8100 | 28 | 80 | 67 | 28
shape torch.Size([1024, 8, 8]) | 766200 | 16300 | 28 | 80 | 68 | 28
shape torch.Size([1, 16, 16]) | 760 | 33 | 28 | 34 | 60 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 50 | 67 | 28
shape torch.Size([4, 16, 16]) | 3001 | 81 | 28 | 82 | 67 | 28
shape torch.Size([8, 16, 16]) | 6000 | 145 | 28 | 140 | 67 | 28
shape torch.Size([16, 16, 16]) | 12000 | 276 | 28 | 79 | 67 | 28
shape torch.Size([32, 16, 16]) | 23870 | 549 | 28 | 79 | 67 | 28
shape torch.Size([64, 16, 16]) | 47900 | 1098 | 29 | 80 | 68 | 28
shape torch.Size([128, 16, 16]) | 95800 | 2184 | 28 | 79 | 68 | 28
shape torch.Size([512, 16, 16]) | 386900 | 8769 | 28 | 80 | 67 | 28
shape torch.Size([1024, 16, 16]) | 769800 | 17460 | 37 | 80 | 67 | 37
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 34 | 60 | 28
shape torch.Size([2, 32, 32]) | 1500 | 50 | 28 | 50 | 67 | 29
shape torch.Size([4, 32, 32]) | 3021 | 86 | 31 | 84 | 67 | 32
shape torch.Size([8, 32, 32]) | 6040 | 167 | 32 | 167 | 67 | 32
shape torch.Size([16, 32, 32]) | 12100 | 330 | 33 | 78 | 67 | 33
shape torch.Size([32, 32, 32]) | 24150 | 662 | 35 | 78 | 66 | 35
shape torch.Size([64, 32, 32]) | 48200 | 1323 | 36 | 79 | 67 | 36
shape torch.Size([128, 32, 32]) | 97000 | 2637 | 44 | 79 | 67 | 43
shape torch.Size([512, 32, 32]) | 382500 | 10580 | 83 | 180 | 121 | 83
shape torch.Size([1024, 32, 32]) | 766600 | 22670 | 123 | 260 | 165 | 120
shape torch.Size([1, 64, 64]) | 760 | 33 | 58 | 33 | 49 | 34
shape torch.Size([2, 64, 64]) | 1520 | 60 | 60 | 60 | 67 | 59
shape torch.Size([4, 64, 64]) | 3016 | 115 | 67 | 119 | 90 | 66
shape torch.Size([8, 64, 64]) | 6120 | 230 | 67 | 233 | 153 | 68
shape torch.Size([16, 64, 64]) | 12100 | 457 | 69 | 86 | 77 | 69
shape torch.Size([32, 64, 64]) | 24000 | 912 | 74 | 95 | 80 | 74
shape torch.Size([64, 64, 64]) | 48000 | 1833 | 76 | 106 | 82 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3636 | 97 | 163 | 108 | 97
shape torch.Size([512, 64, 64]) | 380600 | 15600 | 210 | 464 | 278 | 210
shape torch.Size([1024, 64, 64]) | 761200 | 31140 | 308 | 741 | 377 | 308
shape torch.Size([1, 128, 128]) | 756 | 46 | 120 | 47 | 49 | 46
shape torch.Size([2, 128, 128]) | 1500 | 91 | 123 | 89 | 67 | 92
shape torch.Size([4, 128, 128]) | 2994 | 178 | 139 | 180 | 117 | 139
shape torch.Size([8, 128, 128]) | 5960 | 350 | 140 | 354 | 214 | 142
shape torch.Size([16, 128, 128]) | 12000 | 701 | 144 | 177 | 148 | 143
shape torch.Size([32, 128, 128]) | 23870 | 1401 | 155 | 225 | 159 | 155
shape torch.Size([64, 128, 128]) | 47600 | 2948 | 179 | 288 | 184 | 180
shape torch.Size([128, 128, 128]) | 96000 | 5910 | 231 | 442 | 242 | 231
shape torch.Size([512, 128, 128]) | 381200 | 23640 | 519 | 1400 | 643 | 519
shape torch.Size([1024, 128, 128]) | 755800 | 47340 | 794 | 2436 | 930 | 794
shape torch.Size([1, 256, 256]) | 760 | 74 | 246 | 77 | 62 | 78
shape torch.Size([2, 256, 256]) | 1510 | 150 | 256 | 150 | 106 | 150
shape torch.Size([4, 256, 256]) | 3030 | 296 | 284 | 296 | 183 | 284
shape torch.Size([8, 256, 256]) | 6100 | 588 | 286 | 592 | 337 | 288
shape torch.Size([16, 256, 256]) | 12200 | 1238 | 330 | 445 | 330 | 330
shape torch.Size([32, 256, 256]) | 24430 | 2476 | 368 | 568 | 365 | 367
shape torch.Size([64, 256, 256]) | 49000 | 4950 | 415 | 800 | 408 | 414
shape torch.Size([128, 256, 256]) | 96000 | 9900 | 552 | 1330 | 561 | 553
shape torch.Size([512, 256, 256]) | 369400 | 39580 | 1410 | 4614 | 1649 | 1410
shape torch.Size([1024, 256, 256]) | 716200 | 79200 | 2270 | 8472 | 2528 | 2277
Times are in microseconds (us).
```
</details>
To generate the results above, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I ran the following script, changing the variable `name`.
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare
name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")
for n, batch in itertools.product(shapes, batches):
LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
B = make_arg(batch + (n, 1))
print(LU.shape)
stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
for adjoint in (True, False):
timer = Timer(stmt,
globals=globals(),
label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
description=label,
sub_label=f"shape {LU.shape}",
num_threads=1)
results.append(timer.blocked_autorange())
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
with open("{}_lu_solve.pickle".format(name), 'wb') as f:
pickle.dump(results, f)
```
</details>
Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>
```python
import pickle
from torch.utils.benchmark import Timer, Compare
files = [
"looped_magma",
"looped cusolver",
"batched cublas",
"batched magma",
"unpack+solve_triangular",
"heuristic",
]
timers = []
for name in files:
with open("{}_lu_solve.pickle".format(name), 'rb') as f:
timers += pickle.load(f)
compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>
### Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it, preferring the triangular solves, as they were faster. I'm leaving it here in case it's useful in the future.
<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>
```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
if (trans == TransposeType::NoTranspose) {
lu_solve_batched_magma(LU, pivots, B, trans);
return;
}
// There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
// The LU of the transpose is not the transpose of the LU
// We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
auto diag = LU.diagonal(0, -2, -1);
auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
LU.triu(1).div_(diag.unsqueeze(-1));
LU_f.diagonal(0, -2, -1).copy_(diag);
if (trans == TransposeType::ConjTranspose) {
LU_f = LU_f.conj_physical();
}
LU_f = LU_f.transpose(-2, -1);
// At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());
// Trivial permutation
auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();
lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);
// We then need to apply the permutation P to the solution: with A = PLU, A^T X = B iff (LU)^T Y = B where Y = P^T X, so X = P Y
// Fill `perm` with the identity permutation (perhaps batched)
// This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
.add_output(perm)
.add_input(pivots)
.build();
unpack_pivots_stub(pivots.device().type(), iter, m);
B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```
</details>
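The pivots-to-permutation step at the end of the snippet (the `unpack_pivots_stub` + `scatter_` combination) can be described, for a single matrix, by the following Python sketch. This is my own illustration under the LAPACK convention that pivots record "row i was interchanged with row pivots[i] - 1"; `pivots_to_perm` is a hypothetical helper, not a PyTorch API:
```python
import torch

def pivots_to_perm(pivots):
    # Replay the LAPACK row interchanges on the identity permutation:
    # row i was swapped with row pivots[i] - 1 (pivots are 1-based)
    perm = list(range(pivots.shape[-1]))
    for i, p in enumerate(pivots.tolist()):
        j = p - 1
        perm[i], perm[j] = perm[j], perm[i]
    return torch.tensor(perm)

A = torch.randn(5, 5, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.lu_unpack(LU, pivots)

perm = pivots_to_perm(pivots)
assert torch.allclose(L @ U, A[perm])  # the interchanges applied to A give the pivot-free product

# Applying P to a solution Y is exactly the scatter used above: out[perm[k]] = Y[k]
Y = torch.randn(5, 2, dtype=torch.float64)
out = torch.empty_like(Y)
out[perm] = Y
assert torch.allclose(P @ Y, out)
```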
Fixes #61657
[ghstack-poisoned]
|
The CI is not green, but the issues seem unrelated. I'll try to merge this one after a small correction of the docs. |
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving the system solving two triangular systems.
We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.
We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.
### Benchmarking
<details>
<summary>
Benchmark Results (adjoint=False)
</summary>
```
--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]
| lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 750 | 34 | 28 | 252 | 78 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 239 | 85 | 27
shape torch.Size([4, 1, 1]) | 2995 | 83 | 28 | 239 | 85 | 27
shape torch.Size([8, 1, 1]) | 6000 | 146 | 28 | 239 | 85 | 27
shape torch.Size([16, 1, 1]) | 11900 | 272 | 28 | 241 | 85 | 27
shape torch.Size([32, 1, 1]) | 23880 | 524 | 28 | 244 | 85 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 245 | 85 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2054 | 28 | 242 | 86 | 27
shape torch.Size([512, 1, 1]) | 381900 | 8100 | 28 | 250 | 85 | 27
shape torch.Size([1024, 1, 1]) | 763800 | 16200 | 28 | 257 | 86 | 27
shape torch.Size([1, 2, 2]) | 750 | 33 | 28 | 240 | 82 | 27
shape torch.Size([2, 2, 2]) | 1500 | 51 | 28 | 240 | 88 | 27
shape torch.Size([4, 2, 2]) | 2991 | 82 | 28 | 241 | 88 | 28
shape torch.Size([8, 2, 2]) | 6000 | 150 | 28 | 241 | 88 | 27
shape torch.Size([16, 2, 2]) | 12000 | 275 | 28 | 242 | 88 | 27
shape torch.Size([32, 2, 2]) | 23980 | 530 | 28 | 246 | 90 | 28
shape torch.Size([64, 2, 2]) | 48000 | 1000 | 28 | 244 | 89 | 27
shape torch.Size([128, 2, 2]) | 96000 | 2063 | 28 | 245 | 89 | 28
shape torch.Size([512, 2, 2]) | 382000 | 8300 | 28 | 257 | 89 | 28
shape torch.Size([1024, 2, 2]) | 764000 | 20000 | 28 | 271 | 88 | 28
shape torch.Size([1, 8, 8]) | 749 | 34 | 28 | 243 | 82 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 244 | 89 | 28
shape torch.Size([4, 8, 8]) | 2988 | 83 | 28 | 244 | 89 | 28
shape torch.Size([8, 8, 8]) | 5980 | 150 | 28 | 245 | 88 | 28
shape torch.Size([16, 8, 8]) | 12000 | 278 | 28 | 246 | 89 | 28
shape torch.Size([32, 8, 8]) | 23910 | 536 | 28 | 249 | 89 | 28
shape torch.Size([64, 8, 8]) | 47800 | 1100 | 28 | 247 | 89 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2075 | 28 | 248 | 89 | 28
shape torch.Size([512, 8, 8]) | 382100 | 8300 | 28 | 270 | 89 | 28
shape torch.Size([1024, 8, 8]) | 764100 | 16400 | 28 | 291 | 89 | 28
shape torch.Size([1, 16, 16]) | 750 | 33 | 28 | 248 | 82 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 250 | 89 | 28
shape torch.Size([4, 16, 16]) | 2996 | 83 | 28 | 250 | 89 | 28
shape torch.Size([8, 16, 16]) | 5980 | 147 | 28 | 251 | 90 | 28
shape torch.Size([16, 16, 16]) | 11900 | 274 | 28 | 251 | 89 | 28
shape torch.Size([32, 16, 16]) | 24040 | 527 | 28 | 252 | 89 | 28
shape torch.Size([64, 16, 16]) | 47800 | 1037 | 28 | 251 | 89 | 28
shape torch.Size([128, 16, 16]) | 95600 | 2044 | 28 | 252 | 89 | 28
shape torch.Size([512, 16, 16]) | 388200 | 8100 | 28 | 280 | 88 | 28
shape torch.Size([1024, 16, 16]) | 769700 | 16000 | 28 | 322 | 89 | 28
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 255 | 82 | 28
shape torch.Size([2, 32, 32]) | 1510 | 50 | 28 | 256 | 89 | 28
shape torch.Size([4, 32, 32]) | 3022 | 82 | 31 | 256 | 89 | 30
shape torch.Size([8, 32, 32]) | 6000 | 140 | 31 | 257 | 89 | 31
shape torch.Size([16, 32, 32]) | 12000 | 281 | 31 | 258 | 89 | 31
shape torch.Size([32, 32, 32]) | 24150 | 563 | 35 | 258 | 89 | 35
shape torch.Size([64, 32, 32]) | 48300 | 1119 | 36 | 258 | 90 | 36
shape torch.Size([128, 32, 32]) | 96500 | 2235 | 43 | 261 | 88 | 43
shape torch.Size([512, 32, 32]) | 383100 | 8930 | 82 | 317 | 137 | 82
shape torch.Size([1024, 32, 32]) | 766300 | 19200 | 122 | 400 | 187 | 122
shape torch.Size([1, 64, 64]) | 760 | 33 | 55 | 272 | 71 | 34
shape torch.Size([2, 64, 64]) | 1500 | 52 | 58 | 273 | 90 | 52
shape torch.Size([4, 64, 64]) | 3127 | 102 | 65 | 273 | 110 | 65
shape torch.Size([8, 64, 64]) | 6070 | 201 | 65 | 275 | 160 | 65
shape torch.Size([16, 64, 64]) | 12000 | 399 | 66 | 274 | 88 | 67
shape torch.Size([32, 64, 64]) | 23900 | 796 | 73 | 275 | 91 | 73
shape torch.Size([64, 64, 64]) | 48000 | 1594 | 75 | 283 | 94 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3177 | 96 | 292 | 122 | 96
shape torch.Size([512, 64, 64]) | 379300 | 13520 | 208 | 426 | 313 | 208
shape torch.Size([1024, 64, 64]) | 758700 | 27100 | 306 | 570 | 437 | 306
shape torch.Size([1, 128, 128]) | 750 | 42 | 115 | 306 | 71 | 42
shape torch.Size([2, 128, 128]) | 1500 | 82 | 122 | 307 | 90 | 83
shape torch.Size([4, 128, 128]) | 2966 | 162 | 136 | 307 | 117 | 136
shape torch.Size([8, 128, 128]) | 5930 | 317 | 137 | 308 | 202 | 138
shape torch.Size([16, 128, 128]) | 12000 | 635 | 143 | 316 | 162 | 143
shape torch.Size([32, 128, 128]) | 23700 | 1266 | 152 | 322 | 168 | 152
shape torch.Size([64, 128, 128]) | 48000 | 2668 | 177 | 337 | 196 | 177
shape torch.Size([128, 128, 128]) | 96000 | 5366 | 228 | 365 | 260 | 228
shape torch.Size([512, 128, 128]) | 379400 | 21490 | 502 | 620 | 669 | 502
shape torch.Size([1024, 128, 128]) | 755700 | 43040 | 764 | 903 | 978 | 770
shape torch.Size([1, 256, 256]) | 750 | 70 | 235 | 383 | 78 | 72
shape torch.Size([2, 256, 256]) | 2000 | 138 | 250 | 384 | 120 | 139
shape torch.Size([4, 256, 256]) | 2988 | 277 | 279 | 404 | 191 | 278
shape torch.Size([8, 256, 256]) | 6100 | 546 | 283 | 420 | 332 | 286
shape torch.Size([16, 256, 256]) | 12100 | 1149 | 330 | 441 | 364 | 330
shape torch.Size([32, 256, 256]) | 24040 | 2303 | 359 | 453 | 401 | 360
shape torch.Size([64, 256, 256]) | 48000 | 4626 | 408 | 472 | 466 | 408
shape torch.Size([128, 256, 256]) | 94700 | 9247 | 543 | 543 | 658 | 543
shape torch.Size([512, 256, 256]) | 372000 | 37030 | 1310 | 1185 | 1896 | 1310
shape torch.Size([1024, 256, 256]) | 747200 | 74100 | 2116 | 1910 | 3051 | 2122
```
</details>
<details>
<summary>
Benchmark Results (adjoint=True)
</summary>
```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]
| lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 749 | 34 | 28 | 33 | 60 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 50 | 67 | 27
shape torch.Size([4, 1, 1]) | 3005 | 82 | 28 | 81 | 67 | 27
shape torch.Size([8, 1, 1]) | 5999 | 145 | 28 | 140 | 71 | 27
shape torch.Size([16, 1, 1]) | 12000 | 273 | 28 | 77 | 67 | 27
shape torch.Size([32, 1, 1]) | 24000 | 522 | 28 | 78 | 67 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 77 | 67 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2029 | 28 | 78 | 67 | 27
shape torch.Size([512, 1, 1]) | 383300 | 8100 | 28 | 78 | 67 | 28
shape torch.Size([1024, 1, 1]) | 767500 | 16100 | 28 | 77 | 68 | 27
shape torch.Size([1, 2, 2]) | 753 | 33 | 28 | 33 | 60 | 28
shape torch.Size([2, 2, 2]) | 1500 | 50 | 28 | 50 | 67 | 28
shape torch.Size([4, 2, 2]) | 3002 | 82 | 28 | 80 | 67 | 27
shape torch.Size([8, 2, 2]) | 6000 | 145 | 28 | 144 | 67 | 27
shape torch.Size([16, 2, 2]) | 12000 | 271 | 28 | 78 | 66 | 27
shape torch.Size([32, 2, 2]) | 24120 | 524 | 28 | 78 | 69 | 28
shape torch.Size([64, 2, 2]) | 48300 | 1030 | 28 | 78 | 66 | 27
shape torch.Size([128, 2, 2]) | 96100 | 2041 | 28 | 78 | 67 | 28
shape torch.Size([512, 2, 2]) | 383000 | 8100 | 28 | 79 | 67 | 28
shape torch.Size([1024, 2, 2]) | 766100 | 16000 | 28 | 78 | 67 | 28
shape torch.Size([1, 8, 8]) | 750 | 34 | 28 | 34 | 60 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 50 | 67 | 28
shape torch.Size([4, 8, 8]) | 2998 | 82 | 28 | 82 | 67 | 28
shape torch.Size([8, 8, 8]) | 5990 | 146 | 28 | 150 | 66 | 28
shape torch.Size([16, 8, 8]) | 11980 | 272 | 28 | 79 | 67 | 28
shape torch.Size([32, 8, 8]) | 23970 | 530 | 28 | 79 | 67 | 28
shape torch.Size([64, 8, 8]) | 47900 | 1040 | 28 | 79 | 67 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2048 | 28 | 78 | 67 | 28
shape torch.Size([512, 8, 8]) | 383700 | 8100 | 28 | 80 | 67 | 28
shape torch.Size([1024, 8, 8]) | 766200 | 16300 | 28 | 80 | 68 | 28
shape torch.Size([1, 16, 16]) | 760 | 33 | 28 | 34 | 60 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 50 | 67 | 28
shape torch.Size([4, 16, 16]) | 3001 | 81 | 28 | 82 | 67 | 28
shape torch.Size([8, 16, 16]) | 6000 | 145 | 28 | 140 | 67 | 28
shape torch.Size([16, 16, 16]) | 12000 | 276 | 28 | 79 | 67 | 28
shape torch.Size([32, 16, 16]) | 23870 | 549 | 28 | 79 | 67 | 28
shape torch.Size([64, 16, 16]) | 47900 | 1098 | 29 | 80 | 68 | 28
shape torch.Size([128, 16, 16]) | 95800 | 2184 | 28 | 79 | 68 | 28
shape torch.Size([512, 16, 16]) | 386900 | 8769 | 28 | 80 | 67 | 28
shape torch.Size([1024, 16, 16]) | 769800 | 17460 | 37 | 80 | 67 | 37
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 34 | 60 | 28
shape torch.Size([2, 32, 32]) | 1500 | 50 | 28 | 50 | 67 | 29
shape torch.Size([4, 32, 32]) | 3021 | 86 | 31 | 84 | 67 | 32
shape torch.Size([8, 32, 32]) | 6040 | 167 | 32 | 167 | 67 | 32
shape torch.Size([16, 32, 32]) | 12100 | 330 | 33 | 78 | 67 | 33
shape torch.Size([32, 32, 32]) | 24150 | 662 | 35 | 78 | 66 | 35
shape torch.Size([64, 32, 32]) | 48200 | 1323 | 36 | 79 | 67 | 36
shape torch.Size([128, 32, 32]) | 97000 | 2637 | 44 | 79 | 67 | 43
shape torch.Size([512, 32, 32]) | 382500 | 10580 | 83 | 180 | 121 | 83
shape torch.Size([1024, 32, 32]) | 766600 | 22670 | 123 | 260 | 165 | 120
shape torch.Size([1, 64, 64]) | 760 | 33 | 58 | 33 | 49 | 34
shape torch.Size([2, 64, 64]) | 1520 | 60 | 60 | 60 | 67 | 59
shape torch.Size([4, 64, 64]) | 3016 | 115 | 67 | 119 | 90 | 66
shape torch.Size([8, 64, 64]) | 6120 | 230 | 67 | 233 | 153 | 68
shape torch.Size([16, 64, 64]) | 12100 | 457 | 69 | 86 | 77 | 69
shape torch.Size([32, 64, 64]) | 24000 | 912 | 74 | 95 | 80 | 74
shape torch.Size([64, 64, 64]) | 48000 | 1833 | 76 | 106 | 82 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3636 | 97 | 163 | 108 | 97
shape torch.Size([512, 64, 64]) | 380600 | 15600 | 210 | 464 | 278 | 210
shape torch.Size([1024, 64, 64]) | 761200 | 31140 | 308 | 741 | 377 | 308
shape torch.Size([1, 128, 128]) | 756 | 46 | 120 | 47 | 49 | 46
shape torch.Size([2, 128, 128]) | 1500 | 91 | 123 | 89 | 67 | 92
shape torch.Size([4, 128, 128]) | 2994 | 178 | 139 | 180 | 117 | 139
shape torch.Size([8, 128, 128]) | 5960 | 350 | 140 | 354 | 214 | 142
shape torch.Size([16, 128, 128]) | 12000 | 701 | 144 | 177 | 148 | 143
shape torch.Size([32, 128, 128]) | 23870 | 1401 | 155 | 225 | 159 | 155
shape torch.Size([64, 128, 128]) | 47600 | 2948 | 179 | 288 | 184 | 180
shape torch.Size([128, 128, 128]) | 96000 | 5910 | 231 | 442 | 242 | 231
shape torch.Size([512, 128, 128]) | 381200 | 23640 | 519 | 1400 | 643 | 519
shape torch.Size([1024, 128, 128]) | 755800 | 47340 | 794 | 2436 | 930 | 794
shape torch.Size([1, 256, 256]) | 760 | 74 | 246 | 77 | 62 | 78
shape torch.Size([2, 256, 256]) | 1510 | 150 | 256 | 150 | 106 | 150
shape torch.Size([4, 256, 256]) | 3030 | 296 | 284 | 296 | 183 | 284
shape torch.Size([8, 256, 256]) | 6100 | 588 | 286 | 592 | 337 | 288
shape torch.Size([16, 256, 256]) | 12200 | 1238 | 330 | 445 | 330 | 330
shape torch.Size([32, 256, 256]) | 24430 | 2476 | 368 | 568 | 365 | 367
shape torch.Size([64, 256, 256]) | 49000 | 4950 | 415 | 800 | 408 | 414
shape torch.Size([128, 256, 256]) | 96000 | 9900 | 552 | 1330 | 561 | 553
shape torch.Size([512, 256, 256]) | 369400 | 39580 | 1410 | 4614 | 1649 | 1410
shape torch.Size([1024, 256, 256]) | 716200 | 79200 | 2270 | 8472 | 2528 | 2277
Times are in microseconds (us).
```
</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`.
<details>
<summary>
Benchmarking script
</summary>
```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare
name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")
for n, batch in itertools.product(shapes, batches):
LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
B = make_arg(batch + (n, 1))
print(LU.shape)
stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
for adjoint in (True, False):
timer = Timer(stmt,
globals=globals(),
label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
description=label,
sub_label=f"shape {LU.shape}",
num_threads=1)
results.append(timer.blocked_autorange())
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
with open("{}_lu_solve.pickle".format(name), 'wb') as f:
pickle.dump(results, f)
```
</details>
Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>
```python
import pickle
from torch.utils.benchmark import Timer, Compare
files = [
"looped_magma",
"looped cusolver",
"batched cublas",
"batched magma",
"unpack+solve_triangular",
"heuristic",
]
timers = []
for name in files:
with open("{}_lu_solve.pickle".format(name), 'rb') as f:
timers += pickle.load(f)
compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>
### Fix for Magma's batched lu_solve when `adjoint=True`
I also developed the following fix around MAGMA's bug, but I ended up not using it, and preferring the triangular solves over it, as they were faster. I'm leaving it here in case it's useful in the future.
<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>
```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
  if (trans == TransposeType::NoTranspose) {
    lu_solve_batched_magma(LU, pivots, B, trans);
    return;
  }
  // There's a bug in MAGMA for the other cases, so we need to perform mT or mH on LU ourselves.
  // The LU of the transpose is not the transpose of the LU:
  // we repack LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
  auto diag = LU.diagonal(0, -2, -1);
  auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
              LU.triu(1).div_(diag.unsqueeze(-1));
  LU_f.diagonal(0, -2, -1).copy_(diag);
  if (trans == TransposeType::ConjTranspose) {
    LU_f = LU_f.conj_physical();
  }
  // transpose returns a view, so the result must be reassigned
  LU_f = LU_f.transpose(-2, -1);
  // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());
  // Trivial permutation
  auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();
  lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);
  // We then need to apply the permutation P to the result: with A = PLU,
  // A^T X = B iff U^T L^T (P^T X) = B, so the solve above yields Y = P^T X and we recover X = PY
  // Fill `perm` with the identity permutation (perhaps batched)
  // This is faster than torch.lu_unpack + matmul; the logic is borrowed from lu_unpack
  const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
    .add_output(perm)
    .add_input(pivots)
    .build();
  unpack_pivots_stub(pivots.device().type(), iter, m);
  B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```
</details>
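For context, the triangular-solve route that was preferred over this workaround can be illustrated in a few lines of Python. This is a sketch of the underlying algebra only (with A = PLU, solving A^H X = B reduces to two triangular solves plus a permutation), not the actual kernel:
```python
import torch

# With A = P L U, we have A^H = U^H L^H P^H, so A^H x = B is solved by
# two triangular solves followed by applying the permutation P.
A = torch.randn(5, 5, dtype=torch.complex64)
B = torch.randn(5, 2, dtype=torch.complex64)
LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.lu_unpack(LU, pivots)
y = torch.linalg.solve_triangular(U.mH, B, upper=False)  # U^H y = B
z = torch.linalg.solve_triangular(L.mH, y, upper=True, unitriangular=True)  # L^H z = y
x = P @ z
torch.testing.assert_close(A.mH @ x, B, rtol=1e-4, atol=1e-4)
# torch.linalg.lu_solve(LU, pivots, B, adjoint=True) computes the same x in one call.
```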
Fixes #61657
[ghstack-poisoned]
|
@pytorchmergebot merge this please |
|
Merge failed due to: Matched rule superuser, but it was not yet reviewed by any of: dreiss, laurencer, Adolfo-Karim, cheetah2216, mvsampath, ... |
|
@nikitaved added the case you requested in #74046 (comment) to this PR as it belongs here. |
|
@pytorchmergebot merge this please |
|
Going to revert this change and a bunch of changes depending on it, as it breaks the internal build system with the following linker error: |
|
What can be done about this, @malfet? |
|
@lezcano I'm going to revert, as I could not figure out the correct forward fix, but I assume the problem stems from the fact that |
|
@pytorchbot revert this, as it breaks internal builds. It should have affected the OSS build as well; investigating why it is not the case |
|
@malfet any news on this? It'd be nice to merge this stack (which is mostly approved already) before the branch cut, as it gives a 2.5x-10x speed-up to |
|
@lezcano let me try to import this change and see if it still has the same issue |
|
@lezcano this PR cannot be imported as |
|
Relanding in #77634 |