Add linalg.lu_solve #72935

Closed
lezcano wants to merge 33 commits into gh/Lezcano/47/base from gh/Lezcano/47/head

lezcano (Collaborator) commented Feb 16, 2022

Stack from ghstack:

This PR adds linalg.lu_solve. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
it by solving two triangular systems instead.
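
For illustration, the two-triangular-solve idea can be sketched in Python on top of the public API. This is only a sketch under assumptions, not the code added in this PR: the helper name `lu_solve_via_triangular` is hypothetical, and it assumes a PyTorch version that provides `torch.lu_unpack`, `torch.linalg.solve_triangular`, and the `.mT` / `.mH` attributes.

```python
import torch


def lu_solve_via_triangular(LU, pivots, B, adjoint=False):
    """Hypothetical helper: solve A X = B (or A^H X = B) from a packed
    LU factorization of A using two triangular solves."""
    # A = P @ L @ U, with L unit lower triangular and U upper triangular.
    P, L, U = torch.lu_unpack(LU, pivots)
    if not adjoint:
        # A X = B  <=>  L (U X) = P^T B  (P is a permutation, so P^-1 = P^T)
        Y = torch.linalg.solve_triangular(
            L, P.mT @ B, upper=False, unitriangular=True
        )
        return torch.linalg.solve_triangular(U, Y, upper=True)
    # A^H X = B  <=>  U^H (L^H (P^T X)) = B
    # U^H is lower triangular, L^H is unit upper triangular.
    Z = torch.linalg.solve_triangular(U.mH, B, upper=False)
    W = torch.linalg.solve_triangular(
        L.mH, Z, upper=True, unitriangular=True
    )
    return P @ W


# Usage on a small, well-conditioned batch (shapes chosen arbitrarily).
torch.manual_seed(0)
A = torch.randn(4, 5, 5, dtype=torch.float64) + 5 * torch.eye(5, dtype=torch.float64)
B = torch.randn(4, 5, 3, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)
X = lu_solve_via_triangular(LU, pivots, B)
X_adj = lu_solve_via_triangular(LU, pivots, B, adjoint=True)
```

The adjoint branch never forms a transposed copy of the packed factorization; it only reorders the two triangular solves and applies the permutation at the end.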

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from MAGMA for this function.

We added extensive tests for this function, as well as tests for the
different backends. We also enabled the tests on AMD, as those should
work as well.
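
As a rough idea of what such a check can look like, the sketch below verifies `torch.linalg.lu_solve` against its defining equation for every combination of `left` and `adjoint`. It is not taken from the PR's test suite: the function name `check_lu_solve`, the shapes, and the tolerance are assumptions, and it requires a PyTorch release that ships `torch.linalg.lu_solve`.

```python
import itertools

import torch


def check_lu_solve(n=6, k=3, batch=(2,), dtype=torch.float64):
    """Hypothetical check: the solution must satisfy the system it solves."""
    # Shift by n * I so the random matrices are comfortably invertible.
    A = torch.randn(*batch, n, n, dtype=dtype) + n * torch.eye(n, dtype=dtype)
    LU, pivots = torch.linalg.lu_factor(A)
    for left, adjoint in itertools.product((True, False), repeat=2):
        # left=True solves op(A) X = B, left=False solves X op(A) = B.
        b_shape = (*batch, n, k) if left else (*batch, k, n)
        B = torch.randn(*b_shape, dtype=dtype)
        X = torch.linalg.lu_solve(LU, pivots, B, left=left, adjoint=adjoint)
        op = A.mH if adjoint else A
        residual = op @ X - B if left else X @ op - B
        assert residual.abs().max() < 1e-9, (left, adjoint)
    return True


ok = check_lu_solve()
```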

Benchmarking
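
The tables below have the layout printed by `torch.utils.benchmark.Compare`. A minimal sketch of how a comparable table can be produced is given here; it is an assumption about the setup, not the script used for this PR. It only times the public entry point (roughly the "heuristic" column), since the individual backends are not selected here, and the helper name `bench_lu_solve`, the right-hand-side shape, and the default device are illustrative choices.

```python
import torch
from torch.utils import benchmark


def bench_lu_solve(shapes, adjoint=False, device="cpu", min_run_time=0.05):
    """Hypothetical helper: time torch.linalg.lu_solve for each shape."""
    results = []
    for shape in shapes:
        # Assumption: B has the same shape as A (the source only lists one shape).
        A = torch.randn(*shape, device=device)
        B = torch.randn(*shape, device=device)
        LU, pivots = torch.linalg.lu_factor(A)
        timer = benchmark.Timer(
            stmt="torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)",
            globals={
                "torch": torch,
                "LU": LU,
                "pivots": pivots,
                "B": B,
                "adjoint": adjoint,
            },
            label="linalg.lu_solve",
            sub_label=f"shape {tuple(shape)}",
            description="lu_solve",
            num_threads=1,
        )
        results.append(timer.blocked_autorange(min_run_time=min_run_time))
    return results


results = bench_lu_solve([(1, 8, 8), (16, 8, 8)])
benchmark.Compare(results).print()
```

Passing `device="cuda"` on a machine with a GPU would time the CUDA path; `benchmark.Timer` takes care of synchronization in that case.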

Benchmark Results (adjoint=False)
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  78                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  85                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                  85                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  85                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  85                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  85                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  85                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  86                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  85                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  86                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  82                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  88                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  88                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  88                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  88                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  90                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  89                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  89                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  89                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  88                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  82                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  89                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                  89                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  88                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  89                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  89                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  89                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  89                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  89                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                  89                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  82                |           28       
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  89                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                  89                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  90                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  89                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  89                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                  89                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  89                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                  88                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                  89                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  82                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  89                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  89                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                  89                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  89                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  89                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  90                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  88                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 137                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 187                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  71                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  90                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 110                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 160                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  88                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  91                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                  94                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 122                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 313                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 437                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  71                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                  90                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 117                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 202                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 162                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 168                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 196                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 260                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                 669                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                 978                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                  78                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 120                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 191                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                 332                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 364                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 401                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 466                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                 658                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                1896                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |                3051                |         2122 
Benchmark Results (adjoint=True)
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  60                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                  67                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                  67                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                  71                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                  67                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                  67                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                  67                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                  67                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                  67                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                  68                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  60                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                  67                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                  67                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                  66                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                  69                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                  66                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                  67                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                  67                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                  66                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                  67                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                  68                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                  67                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                  68                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                  68                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                  67                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                  67                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                  67                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                  67                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                  67                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                  66                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                  67                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                  67                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 121                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 165                |          120        
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  49                |           34       
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                  67                |           59       
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                  90                |           66       
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 153                |           68       
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                  77                |           69       
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                  80                |           74       
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                  82                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 108                |           97       
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 278                |          210       
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 377                |          308       
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  49                |           46       
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                  67                |           92       
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 117                |          139       
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 214                |          142       
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 148                |          143       
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 159                |          155       
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 184                |          180       
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 242                |          231       
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                 643                |          519       
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                 930                |          794       
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  62                |           78       
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 106                |          150       
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 183                |          284       
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 337                |          288       
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 330                |          330       
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 365                |          367       
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 408                |          414       
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                 561                |          553       
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                1649                |         1410       
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |                2528                |         2277      

Times are in microseconds (us).
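
As a reading aid, the short script below takes five rows copied from the table above and reports, for each shape, the fastest of the five fixed backends and how far the `heuristic` column is from it. The choice of rows and the helper name `summarize` are assumptions made for illustration; the numbers themselves are the ones in the table.

```python
# Times in microseconds, copied verbatim from five rows of the table above.
BACKENDS = [
    "looped_magma",
    "looped cusolver",
    "batched cublas",
    "batched magma",
    "unpack+solve_triangular",
]

# shape -> (times for BACKENDS in order, time for the "heuristic" column)
ROWS = {
    (1, 64, 64): ([760, 33, 58, 33, 49], 34),
    (8, 64, 64): ([6120, 230, 67, 233, 153], 68),
    (1024, 64, 64): ([761200, 31140, 308, 741, 377], 308),
    (1, 256, 256): ([760, 74, 246, 77, 62], 78),
    (1024, 256, 256): ([716200, 79200, 2270, 8472, 2528], 2277),
}


def summarize(rows):
    """Return {shape: (fastest backend, its time, heuristic / fastest)}."""
    summary = {}
    for shape, (times, heuristic) in rows.items():
        best = min(times)
        # On ties, list.index returns the first backend in BACKENDS order.
        summary[shape] = (BACKENDS[times.index(best)], best, heuristic / best)
    return summary


summary = summarize(ROWS)
for shape, (backend, best, ratio) in summary.items():
    print(f"{shape}: fastest = {backend} ({best} us), heuristic / fastest = {ratio:.2f}")
```

On these rows the heuristic stays within about 3% of the fastest backend, except for shape (1, 256, 256), where unpack+solve_triangular measured 62 us against 78 us for the heuristic.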

To generate the results above, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I ran the following script, changing the variable `name` for each backend.

Benchmarking script
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")

for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)

Finally, I joined all the results with the following script:

Script to join the results
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()

Fix for Magma's batched lu_solve when adjoint=True

I also developed the following workaround for MAGMA's bug, but I ended up not using it: the triangular solves were faster. I'm leaving it here in case it is useful in the future.

Fix for MAGMA's issue with `adjoint=True`
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    // transpose() returns a view, so the result has to be assigned back
    LU_f = LU_f.transpose(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
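
For reference, the triangular-solve workaround that was preferred over the fix above can be sketched without any library. Assuming a real matrix factored as A = P L U with L unit lower triangular, the adjoint system Aᵀx = b reads UᵀLᵀPᵀx = b, which is one forward substitution with Uᵀ, one back substitution with Lᵀ, and a final permutation. The helper names (`solve_lower`, `solve_upper`, `lu_solve_transposed`) and the 3x3 example are hypothetical; this is plain Python for illustration, not the ATen code from the PR.

```python
def transpose(M):
    return [list(row) for row in zip(*M)]


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def matvec(A, v):
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def solve_lower(T, b):
    """Forward substitution for a lower-triangular matrix T."""
    x = [0.0] * len(b)
    for i in range(len(b)):
        s = sum(T[i][j] * x[j] for j in range(i))
        x[i] = (b[i] - s) / T[i][i]
    return x


def solve_upper(T, b):
    """Back substitution for an upper-triangular matrix T."""
    n = len(b)
    x = [0.0] * n
    for i in reversed(range(n)):
        s = sum(T[i][j] * x[j] for j in range(i + 1, n))
        x[i] = (b[i] - s) / T[i][i]
    return x


def lu_solve_transposed(P, L, U, b):
    """Solve A^T x = b for real A = P L U using two triangular solves."""
    y = solve_lower(transpose(U), b)  # U^T y = b
    z = solve_upper(transpose(L), y)  # L^T z = y
    return matvec(P, z)               # P^T x = z, and P^{-1} = P^T


# Hypothetical 3x3 example: build A from chosen factors, then recover x_true.
P = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]
L = [[1, 0, 0], [0.5, 1, 0], [0.25, 0.5, 1]]
U = [[4, 2, 1], [0, 3, 2], [0, 0, 2]]
A = matmul(P, matmul(L, U))

x_true = [1.0, -2.0, 3.0]
b = matvec(transpose(A), x_true)
x = lu_solve_transposed(P, L, U, b)
print(x)
```

For complex inputs the same steps would apply with conjugate transposes in place of transposes.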

Fixes #61657

@facebook-github-bot
Contributor

facebook-github-bot commented Feb 16, 2022

❌ 6 New Failures

As of commit f927ccc (more details on the Dr. CI page):

  • 6/6 failures introduced in this PR

🕵️ 6 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / pytorch-xla-linux-bionic-py3.7-clang8 / test (xla, 1, 1, linux.2xlarge) (1/6)

Step: "Test"

2022-05-11T12:39:06.9245478Z In file included from /var/lib/jenkins/workspace/xla/torch_xla/csrc/aten_xla_bridge.cpp:11:
2022-05-11T12:39:06.9246392Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/tensor_impl.h:33:40: error: only virtual member functions can be marked 'override'
2022-05-11T12:39:06.9246967Z   at::IntArrayRef sizes_custom() const override;
2022-05-11T12:39:06.9247362Z                                        ^~~~~~~~
2022-05-11T12:39:06.9248145Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/tensor_impl.h:34:42: error: only virtual member functions can be marked 'override'
2022-05-11T12:39:06.9248693Z   at::IntArrayRef strides_custom() const override;
2022-05-11T12:39:06.9249074Z                                          ^~~~~~~~
2022-05-11T12:39:06.9249822Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/tensor_impl.h:36:30: error: only virtual member functions can be marked 'override'
2022-05-11T12:39:06.9250353Z   int64_t dim_custom() const override;
2022-05-11T12:39:06.9250678Z                              ^~~~~~~~
2022-05-11T12:39:06.9251577Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/tensor_impl.h:38:32: error: only virtual member functions can be marked 'override'
2022-05-11T12:39:06.9252092Z   int64_t numel_custom() const override;
2022-05-11T12:39:06.9252422Z                                ^~~~~~~~
2022-05-11T12:39:06.9252709Z 4 errors generated.
2022-05-11T12:39:08.0434572Z [10/179] clang++-8 -MMD -MF /var/lib/jenkins/workspace/xla/build/temp.linux-x86_64-3.7/torch_xla/csrc/convert_ops.o.d -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -I/var/lib/jenkins/workspace/xla -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-bin -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/protobuf_archive/src -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/com_google_protobuf/src -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/eigen_archive -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/com_google_absl -I/var/lib/jenkins/workspace -I/var/lib/jenkins/workspace/torch/csrc -I/var/lib/jenkins/workspace/torch/lib/tmp_install/include -I/opt/conda/lib/python3.7/site-packages/torch/include -I/opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/lib/python3.7/site-packages/torch/include/TH -I/opt/conda/lib/python3.7/site-packages/torch/include/THC -I/opt/conda/include/python3.7m -c -c /var/lib/jenkins/workspace/xla/torch_xla/csrc/convert_ops.cpp -o /var/lib/jenkins/workspace/xla/build/temp.linux-x86_64-3.7/torch_xla/csrc/convert_ops.o -std=c++14 -Wno-sign-compare -Wno-deprecated-declarations -Wno-return-type -Wno-macro-redefined -Wno-return-std-move -DNDEBUG -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_clang"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1002"' -DTORCH_EXTENSION_NAME=_XLAC -D_GLIBCXX_USE_CXX11_ABI=1
2022-05-11T12:39:32.8502959Z [11/179] clang++-8 -MMD -MF /var/lib/jenkins/workspace/xla/build/temp.linux-x86_64-3.7/torch_xla/csrc/RegisterXLA.o.d -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -I/var/lib/jenkins/workspace/xla -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-bin -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/protobuf_archive/src -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/com_google_protobuf/src -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/eigen_archive -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/com_google_absl -I/var/lib/jenkins/workspace -I/var/lib/jenkins/workspace/torch/csrc -I/var/lib/jenkins/workspace/torch/lib/tmp_install/include -I/opt/conda/lib/python3.7/site-packages/torch/include -I/opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/lib/python3.7/site-packages/torch/include/TH -I/opt/conda/lib/python3.7/site-packages/torch/include/THC -I/opt/conda/include/python3.7m -c -c /var/lib/jenkins/workspace/xla/torch_xla/csrc/RegisterXLA.cpp -o /var/lib/jenkins/workspace/xla/build/temp.linux-x86_64-3.7/torch_xla/csrc/RegisterXLA.o -std=c++14 -Wno-sign-compare -Wno-deprecated-declarations -Wno-return-type -Wno-macro-redefined -Wno-return-std-move -DNDEBUG -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_clang"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1002"' -DTORCH_EXTENSION_NAME=_XLAC -D_GLIBCXX_USE_CXX11_ABI=1
2022-05-11T12:39:32.8505461Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/RegisterXLA.cpp:63:6: warning: unused function 'resize_out' [-Wunused-function]
2022-05-11T12:39:32.8505886Z void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
2022-05-11T12:39:32.8506207Z      ^
2022-05-11T12:39:32.8506626Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/RegisterXLA.cpp:82:6: warning: unused function 'check_inplace' [-Wunused-function]
2022-05-11T12:39:32.8507075Z void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {

See GitHub Actions build pull / linux-xenial-py3.7-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge) (2/6)

Step: "Test"

2022-05-11T11:17:07.9074035Z processing existing schema:  text(__torch__.torch.classes.profiling.SourceRef _0) -> (str _0)
2022-05-11T11:17:07.9075346Z processing existing schema:  count(__torch__.torch.classes.profiling.InstructionStats _0) -> (int _0)
2022-05-11T11:17:07.9076754Z processing existing schema:  duration_ns(__torch__.torch.classes.profiling.InstructionStats _0) -> (int _0)
2022-05-11T11:17:07.9078464Z processing existing schema:  source(__torch__.torch.classes.profiling.SourceStats _0) -> (__torch__.torch.classes.profiling.SourceRef _0)
2022-05-11T11:17:07.9080223Z processing existing schema:  line_map(__torch__.torch.classes.profiling.SourceStats _0) -> (Dict(int, __torch__.torch.classes.profiling.InstructionStats) _0)
2022-05-11T11:17:07.9081305Z processing existing schema:  __init__(__torch__.torch.classes.profiling._ScriptProfile _0) -> (NoneType _0)
2022-05-11T11:17:07.9082719Z processing existing schema:  enable(__torch__.torch.classes.profiling._ScriptProfile _0) -> (NoneType _0)
2022-05-11T11:17:07.9084070Z processing existing schema:  disable(__torch__.torch.classes.profiling._ScriptProfile _0) -> (NoneType _0)
2022-05-11T11:17:07.9086073Z processing existing schema:  _dump_stats(__torch__.torch.classes.profiling._ScriptProfile _0) -> (__torch__.torch.classes.profiling.SourceStats[] _0)
2022-05-11T11:17:07.9087765Z processing existing schema:  __init__(__torch__.torch.classes.dist_rpc.WorkerInfo _0, str _1, int _2) -> (NoneType _0)
2022-05-11T11:17:07.9088304Z The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not. 
2022-05-11T11:17:07.9088338Z 
2022-05-11T11:17:07.9088488Z Broken ops: [
2022-05-11T11:17:07.9089067Z 	aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, int storage_offset, int[] size, int[] stride=[]) -> (Tensor(a!))
2022-05-11T11:17:07.9089279Z 	prim::oneDNNFusionGroup(...) -> (...)
2022-05-11T11:17:07.9089406Z 	prim::oneDNNFusionGuard(...) -> (...)
2022-05-11T11:17:07.9089579Z 	aten::linalg_vander(Tensor x, *, int? N=None) -> (Tensor)
2022-05-11T11:17:07.9089642Z ]
2022-05-11T11:17:08.0124913Z + cleanup
2022-05-11T11:17:08.0125266Z + retcode=1
2022-05-11T11:17:08.0125478Z + set +x

See GitHub Actions build pull / linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test (3/6)

Step: "Build"

2022-05-11T11:55:13.1132202Z Use --sandbox_debug to see verbose messages from the sandbox
2022-05-11T11:55:13.1132921Z torch/csrc/utils/python_dispatch.cpp:58:3: error: 'SafePyObject' does not name a type; did you mean 'createPyObject'?
2022-05-11T11:55:13.1133375Z    SafePyObject func_;
2022-05-11T11:55:13.1133659Z    ^~~~~~~~~~~~
2022-05-11T11:55:13.1133923Z    createPyObject
2022-05-11T11:55:13.1134731Z torch/csrc/utils/python_dispatch.cpp: In constructor 'torch::impl::dispatch::PythonKernelHolder::PythonKernelHolder(pybind11::object)':
2022-05-11T11:55:13.1135710Z torch/csrc/utils/python_dispatch.cpp:60:41: error: class 'torch::impl::dispatch::PythonKernelHolder' does not have any field named 'func_'
2022-05-11T11:55:13.1136343Z    PythonKernelHolder(py::object func) : func_(func.release().ptr(), getPyInterpreter()) {}
2022-05-11T11:55:13.1136785Z                                          ^~~~~
2022-05-11T11:55:13.1137662Z torch/csrc/utils/python_dispatch.cpp: In member function 'void torch::impl::dispatch::PythonKernelHolder::operator()(const c10::OperatorHandle&, c10::DispatchKeySet, torch::jit::Stack*)':
2022-05-11T11:55:13.1138481Z torch/csrc/utils/python_dispatch.cpp:66:64: error: 'func_' was not declared in this scope
2022-05-11T11:55:13.1139140Z      auto obj = py::reinterpret_steal<py::object>(PyObject_Call(func_.ptr(getPyInterpreter()), args_kwargs.first.ptr(), args_kwargs.second.ptr()));
2022-05-11T11:55:13.1139675Z                                                                 ^~~~~
2022-05-11T11:55:14.2266374Z INFO: Elapsed time: 2255.166s, Critical Path: 412.12s
2022-05-11T11:55:14.2267006Z INFO: 3419 processes: 185 internal, 3234 processwrapper-sandbox.
2022-05-11T11:55:14.2267577Z FAILED: Build did NOT complete successfully
2022-05-11T11:55:14.2289266Z FAILED: Build did NOT complete successfully
2022-05-11T11:55:14.2455252Z + cleanup
2022-05-11T11:55:14.2484343Z + retcode=1
2022-05-11T11:55:14.2484676Z + set +x
2022-05-11T11:55:14.3030067Z ##[error]Process completed with exit code 1.

See GitHub Actions build pull / win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge) (4/6)

Step: "Test"

2022-05-11T11:45:33.1343412Z + export TEST_DIR_WIN
2022-05-11T11:45:33.1343691Z + export PYTORCH_FINAL_PACKAGE_DIR=/c/2306597987/build-results/
2022-05-11T11:45:33.1343980Z + PYTORCH_FINAL_PACKAGE_DIR=/c/2306597987/build-results/
2022-05-11T11:45:33.1412554Z ++ cygpath -w /c/2306597987/build-results/
2022-05-11T11:45:33.1523662Z + PYTORCH_FINAL_PACKAGE_DIR_WIN='C:\2306597987\build-results\'
2022-05-11T11:45:33.1523968Z + export PYTORCH_FINAL_PACKAGE_DIR_WIN
2022-05-11T11:45:33.1524278Z + mkdir -p /c/actions-runner/_work/pytorch/pytorch/build/win_tmp/build/torch
2022-05-11T11:45:33.1909019Z + CI_SCRIPTS_DIR=/c/actions-runner/_work/pytorch/pytorch/build/win_tmp/ci_scripts
2022-05-11T11:45:33.1909431Z + mkdir -p /c/actions-runner/_work/pytorch/pytorch/build/win_tmp/ci_scripts
2022-05-11T11:45:33.2136673Z ++ ls '/c/actions-runner/_work/pytorch/pytorch/build/win_tmp/ci_scripts/*'
2022-05-11T11:45:33.2444546Z ls: cannot access '/c/actions-runner/_work/pytorch/pytorch/build/win_tmp/ci_scripts/*': No such file or directory
2022-05-11T11:45:33.2447455Z + '[' -n '' ']'
2022-05-11T11:45:33.2448123Z + export SCRIPT_HELPERS_DIR=/c/actions-runner/_work/pytorch/pytorch/.jenkins/pytorch/win-test-helpers
2022-05-11T11:45:33.2448566Z + SCRIPT_HELPERS_DIR=/c/actions-runner/_work/pytorch/pytorch/.jenkins/pytorch/win-test-helpers
2022-05-11T11:45:33.2448877Z + [[ win-vs2019-cpu-py3 == *cuda11* ]]
2022-05-11T11:45:33.2449115Z + [[ default = \f\o\r\c\e\_\o\n\_\c\p\u ]]
2022-05-11T11:45:33.2449663Z + run_tests
2022-05-11T11:45:33.2450009Z + for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe
2022-05-11T11:45:33.2450554Z + [[ -x /c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe ]]
2022-05-11T11:45:33.2452145Z + '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe'
2022-05-11T11:45:34.5747768Z NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running. This can also be happening if non-NVIDIA GPU is running as primary display, and NVIDIA GPU is in WDDM mode.

See GitHub Actions build pull / win-vs2019-cuda11.3-py3 / test (force_on_cpu, 1, 1, windows.4xlarge) (5/6)

Step: "Test"

2022-05-11T11:54:48.6621151Z + export TEST_DIR_WIN
2022-05-11T11:54:48.6621400Z + export PYTORCH_FINAL_PACKAGE_DIR=/c/2306597987/build-results/
2022-05-11T11:54:48.6621704Z + PYTORCH_FINAL_PACKAGE_DIR=/c/2306597987/build-results/
2022-05-11T11:54:48.6694590Z ++ cygpath -w /c/2306597987/build-results/
2022-05-11T11:54:48.6812026Z + PYTORCH_FINAL_PACKAGE_DIR_WIN='C:\2306597987\build-results\'
2022-05-11T11:54:48.6812332Z + export PYTORCH_FINAL_PACKAGE_DIR_WIN
2022-05-11T11:54:48.6812652Z + mkdir -p /c/actions-runner/_work/pytorch/pytorch/build/win_tmp/build/torch
2022-05-11T11:54:48.7233985Z + CI_SCRIPTS_DIR=/c/actions-runner/_work/pytorch/pytorch/build/win_tmp/ci_scripts
2022-05-11T11:54:48.7234373Z + mkdir -p /c/actions-runner/_work/pytorch/pytorch/build/win_tmp/ci_scripts
2022-05-11T11:54:48.7452617Z ++ ls '/c/actions-runner/_work/pytorch/pytorch/build/win_tmp/ci_scripts/*'
2022-05-11T11:54:48.7795335Z ls: cannot access '/c/actions-runner/_work/pytorch/pytorch/build/win_tmp/ci_scripts/*': No such file or directory
2022-05-11T11:54:48.7799016Z + '[' -n '' ']'
2022-05-11T11:54:48.7799392Z + export SCRIPT_HELPERS_DIR=/c/actions-runner/_work/pytorch/pytorch/.jenkins/pytorch/win-test-helpers
2022-05-11T11:54:48.7799932Z + SCRIPT_HELPERS_DIR=/c/actions-runner/_work/pytorch/pytorch/.jenkins/pytorch/win-test-helpers
2022-05-11T11:54:48.7800342Z + [[ win-vs2019-cuda11.3-py3 == *cuda11* ]]
2022-05-11T11:54:48.7800586Z + export BUILD_SPLIT_CUDA=ON
2022-05-11T11:54:48.7800781Z + BUILD_SPLIT_CUDA=ON
2022-05-11T11:54:48.7801018Z + [[ force_on_cpu = \f\o\r\c\e\_\o\n\_\c\p\u ]]
2022-05-11T11:54:48.7801859Z + export USE_CUDA=0
2022-05-11T11:54:48.7802135Z + USE_CUDA=0
2022-05-11T11:54:48.7802317Z + run_tests

See GitHub Actions build pull / win-vs2019-cuda11.3-py3 / test (default, 1, 2, windows.8xlarge.nvidia.gpu) (6/6)

Step: "Test"

2022-05-11T11:56:08.0280988Z + export TEST_DIR_WIN
2022-05-11T11:56:08.0281568Z + export PYTORCH_FINAL_PACKAGE_DIR=/c/2306597987/build-results/
2022-05-11T11:56:08.0282248Z + PYTORCH_FINAL_PACKAGE_DIR=/c/2306597987/build-results/
2022-05-11T11:56:08.0374179Z ++ cygpath -w /c/2306597987/build-results/
2022-05-11T11:56:08.0537745Z + PYTORCH_FINAL_PACKAGE_DIR_WIN='C:\2306597987\build-results\'
2022-05-11T11:56:08.0538456Z + export PYTORCH_FINAL_PACKAGE_DIR_WIN
2022-05-11T11:56:08.0539257Z + mkdir -p /c/actions-runner/_work/pytorch/pytorch/build/win_tmp/build/torch
2022-05-11T11:56:08.1022955Z + CI_SCRIPTS_DIR=/c/actions-runner/_work/pytorch/pytorch/build/win_tmp/ci_scripts
2022-05-11T11:56:08.1297745Z + mkdir -p /c/actions-runner/_work/pytorch/pytorch/build/win_tmp/ci_scripts
2022-05-11T11:56:08.1298217Z ++ ls '/c/actions-runner/_work/pytorch/pytorch/build/win_tmp/ci_scripts/*'
2022-05-11T11:56:08.1673408Z ls: cannot access '/c/actions-runner/_work/pytorch/pytorch/build/win_tmp/ci_scripts/*': No such file or directory
2022-05-11T11:56:08.1677645Z + '[' -n '' ']'
2022-05-11T11:56:08.1678349Z + export SCRIPT_HELPERS_DIR=/c/actions-runner/_work/pytorch/pytorch/.jenkins/pytorch/win-test-helpers
2022-05-11T11:56:08.1679248Z + SCRIPT_HELPERS_DIR=/c/actions-runner/_work/pytorch/pytorch/.jenkins/pytorch/win-test-helpers
2022-05-11T11:56:08.1679852Z + [[ win-vs2019-cuda11.3-py3 == *cuda11* ]]
2022-05-11T11:56:08.1680314Z + export BUILD_SPLIT_CUDA=ON
2022-05-11T11:56:08.1680741Z + BUILD_SPLIT_CUDA=ON
2022-05-11T11:56:08.1681159Z + [[ default = \f\o\r\c\e\_\o\n\_\c\p\u ]]
2022-05-11T11:56:08.1681740Z + run_tests
2022-05-11T11:56:08.1683244Z + for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe
2022-05-11T11:56:08.1683773Z + [[ -x /c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe ]]

This comment was automatically generated by Dr. CI.

lezcano added a commit that referenced this pull request Feb 16, 2022
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving the system as two triangular systems.

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from MAGMA for this function.

We added extensive tests for this function, including tests for the
different backends. We also activated the tests for AMD, as those
should work as well.

Fixes #61657

ghstack-source-id: 86a91f0
Pull Request resolved: #72935
@lezcano removed the request for review from ezyang February 16, 2022 20:19
@lezcano added the module: linear algebra label Feb 16, 2022
@lezcano marked this pull request as draft February 16, 2022 20:20
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving the system as two triangular systems.

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from MAGMA for this function.

We added extensive tests for this function, including tests for the
different backends. We also activated the tests for AMD, as those
should work as well.
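
To make the "two triangular systems" workaround concrete, here is a minimal pure-Python sketch of the idea. It is not the code in this PR: the helper names (`lu_factor`, `solve_triangular`, `lu_solve`) and the permutation convention are assumptions chosen for illustration, and it handles only a single real matrix with a single right-hand side. Given a pivoted factorization with `A[perm[i]] == (L @ U)[i]`, the adjoint system `A^T x = b` is solved by one substitution with `U^T`, one with `L^T`, and a final un-permutation.

```python
# Illustrative sketch only; names and conventions are hypothetical,
# not taken from the PyTorch implementation.

def lu_factor(a):
    """LU with partial pivoting. Returns (lu, perm) such that
    A[perm[i]] == (L @ U)[i], with L unit lower and U upper triangular,
    both packed into the single matrix `lu`."""
    n = len(a)
    lu = [row[:] for row in a]
    perm = list(range(n))
    for k in range(n):
        # Pick the row with the largest pivot candidate in column k.
        p = max(range(k, n), key=lambda i: abs(lu[i][k]))
        if lu[p][k] == 0.0:
            raise ValueError("matrix is singular")
        lu[k], lu[p] = lu[p], lu[k]
        perm[k], perm[p] = perm[p], perm[k]
        for i in range(k + 1, n):
            lu[i][k] /= lu[k][k]
            for j in range(k + 1, n):
                lu[i][j] -= lu[i][k] * lu[k][j]
    return lu, perm


def solve_triangular(t, b, lower, unit_diagonal):
    """Solve T x = b by substitution; `t(i, j)` returns the entry T[i][j]."""
    n = len(b)
    x = [0.0] * n
    rows = range(n) if lower else range(n - 1, -1, -1)
    for i in rows:
        cols = range(i) if lower else range(i + 1, n)
        s = b[i] - sum(t(i, j) * x[j] for j in cols)
        x[i] = s if unit_diagonal else s / t(i, i)
    return x


def lu_solve(lu, perm, b, adjoint=False):
    """Solve A x = b (or A^T x = b when adjoint=True) as two triangular solves."""
    n = len(b)
    if not adjoint:
        # L U x = P b: permute b, then forward and backward substitution.
        rhs = [b[perm[i]] for i in range(n)]
        y = solve_triangular(lambda i, j: lu[i][j], rhs, lower=True, unit_diagonal=True)
        return solve_triangular(lambda i, j: lu[i][j], y, lower=False, unit_diagonal=False)
    # A^T = U^T L^T P: solve with U^T (lower), then L^T (unit upper),
    # then undo the row permutation.
    w = solve_triangular(lambda i, j: lu[j][i], b, lower=True, unit_diagonal=False)
    z = solve_triangular(lambda i, j: lu[j][i], w, lower=False, unit_diagonal=True)
    x = [0.0] * n
    for i in range(n):
        x[perm[i]] = z[i]
    return x


A = [[2.0, 1.0, 1.0], [4.0, 3.0, 3.0], [8.0, 7.0, 9.0]]
lu, perm = lu_factor(A)
x = lu_solve(lu, perm, [7.0, 19.0, 49.0])                    # solves A x = b
xt = lu_solve(lu, perm, [34.0, 28.0, 34.0], adjoint=True)    # solves A^T x = b
```

In both branches the factorization is reused unchanged; only the order of the triangular solves and the side on which the permutation is applied differ.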

### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  86                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  94                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                 100                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  94                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  95                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  94                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  99                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  96                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  94                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  95                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  88                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  96                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  96                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  96                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  96                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  97                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  96                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  96                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  97                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  97                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  88                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  97                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                 100                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  97                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  96                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  98                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  96                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  96                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  97                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                 100                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  88                |           28
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  97                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                 100                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  97                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  97                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  97                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                 100                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  98                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                 100                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                 117                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  88                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  97                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  97                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                 100                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  96                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  96                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  97                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  96                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 191                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 312                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  68                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  85                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 150                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 278                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  95                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  97                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                 123                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 176                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 551                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 919                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  90                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                 164                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 301                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 578                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 199                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 241                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 322                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 514                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                1697                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                3040                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                 178                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 329                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 655                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                1321                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 472                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 634                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 925                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                1582                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                5711                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |               10660                |         2122       
```
</details>
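
One way to read the table above is to compare the `heuristic` column against the fastest individual backend for each shape. The short sketch below does that for five rows copied from the adjoint=False results. The helper names are hypothetical and the timings are in whatever unit the benchmark reported. It is only a reading aid, not the heuristic implemented in this PR.

```python
# Rows copied from the adjoint=False table; column order follows the table header.
BACKENDS = [
    "looped_magma",
    "looped cusolver",
    "batched cublas",
    "batched magma",
    "unpack+solve_triangular",
]

# shape -> (per-backend timings, timing of the heuristic column)
ROWS = {
    (1, 1, 1): ([750, 34, 28, 252, 86], 27),
    (1, 64, 64): ([760, 33, 55, 272, 68], 34),
    (1, 128, 128): ([750, 42, 115, 306, 90], 42),
    (4, 128, 128): ([2966, 162, 136, 307, 301], 136),
    (1024, 256, 256): ([747200, 74100, 2116, 1910, 10660], 2122),
}


def fastest_backend(timings):
    """Name of the backend with the smallest timing in a row."""
    best = min(range(len(timings)), key=timings.__getitem__)
    return BACKENDS[best]


def heuristic_overhead(timings, heuristic):
    """Ratio of the heuristic's timing to the best single backend (1.0 = optimal)."""
    return heuristic / min(timings)


for shape, (timings, heuristic) in ROWS.items():
    print(shape, fastest_backend(timings), round(heuristic_overhead(timings, heuristic), 2))
```

On these rows the heuristic tracks the faster of looped cuSolver and batched cuBLAS. The only shape where another backend (batched MAGMA) is quicker is the largest one.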

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  98                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                 110                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                 110                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                 110                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                 110                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                 110                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                 100                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                 110                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                 110                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                 100                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  99                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                 100                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                 107                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                 110                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                 110                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                 111                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                 107                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                 108                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                 110                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                 107                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                 110                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                 107                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                 107                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                 108                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                 108                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                 108                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                 108                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                 110                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                 110                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                 100                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                 108                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                 107                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                 110                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                 110                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                 108                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                 107                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                 110                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                 110                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                 110                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 198                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 318                |          120          
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  88                |           34          
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                 109                |           59          
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                 132                |           66          
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 180                |           68          
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                 107                |           69          
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                 100                |           74          
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                 123                |           76          
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 183                |           97          
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 549                |          210          
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 918                |          308          
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  89                |           46          
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                 110                |           92          
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 131                |          139          
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 208                |          142          
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 198                |          143          
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 246                |          155          
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 323                |          180          
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 512                |          231          
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                1700                |          519          
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                3018                |          794          
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  88                |           78          
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 117                |          150          
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 209                |          284          
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 394                |          288          
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 480                |          330          
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 629                |          367          
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 921                |          414          
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                1579                |          553          
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                5616                |         1410          
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |               10500                |         2277          

Times are in microseconds (us).
```

</details>

To generate these results, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. Then I ran the following script, changing the variable `name` each time. For `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncommenting the commented-out line).
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


def f(LU, pivots, B, adjoint):
    P, L, U = torch.lu_unpack(LU, pivots)
    if adjoint:
        X = torch.linalg.solve_triangular(U.mH, B, upper=False)
        return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
    else:
        X = P.mT @ B
        X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
        return torch.linalg.solve_triangular(U, X, upper=True, out=X)


for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    #stmt = "f(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)
```
</details>

Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>
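
For reference, the order of the triangular solves used by the `f` helper in the benchmarking script can be checked by hand on a small system. The following is a minimal pure-Python sketch, not the code from this PR: it assumes real matrices (so the adjoint is just the transpose), uses exact `Fraction` arithmetic, and all helper names (`solve_lower`, `solve_upper`, `lu_solve_reference`) are hypothetical.

```python
from fractions import Fraction


def to_frac(a):
    """Convert a nested list of ints into a matrix of exact Fractions."""
    return [[Fraction(v) for v in row] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def matmul(a, b):
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def matvec(a, x):
    return [sum(row[j] * x[j] for j in range(len(x))) for row in a]


def solve_lower(T, b, unit=False):
    """Forward substitution for a lower-triangular T."""
    x = []
    for i, row in enumerate(T):
        s = b[i] - sum(row[j] * x[j] for j in range(i))
        x.append(s if unit else s / row[i])
    return x


def solve_upper(T, b, unit=False):
    """Back substitution for an upper-triangular T."""
    n = len(T)
    x = [Fraction(0)] * n
    for i in reversed(range(n)):
        s = b[i] - sum(T[i][j] * x[j] for j in range(i + 1, n))
        x[i] = s if unit else s / T[i][i]
    return x


def lu_solve_reference(P, L, U, b, adjoint=False):
    """Solve A x = b (or A^T x = b when adjoint=True) with A = P L U."""
    if adjoint:
        # A^T = U^T L^T P^T: solve with U^T, then with L^T, then apply P.
        y = solve_lower(transpose(U), b)
        z = solve_upper(transpose(L), y, unit=True)
        return matvec(P, z)
    # A = P L U: apply P^T, solve with L, then with U.
    y = solve_lower(L, matvec(transpose(P), b), unit=True)
    return solve_upper(U, y)


# Small made-up example: a permutation, a unit lower and an upper factor.
P = to_frac([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
L = to_frac([[1, 0, 0], [2, 1, 0], [-1, 3, 1]])
U = to_frac([[2, 1, -1], [0, 3, 2], [0, 0, 4]])
A = matmul(P, matmul(L, U))

x_true = [Fraction(1), Fraction(-2), Fraction(3)]
x_plain = lu_solve_reference(P, L, U, matvec(A, x_true))
x_adj = lu_solve_reference(P, L, U, matvec(transpose(A), x_true), adjoint=True)
```

Both `x_plain` and `x_adj` should recover `x_true` exactly. The adjoint case never needs a new factorization, which is why two triangular solves are enough to sidestep the batched MAGMA backend.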

### Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it: the two triangular solves were faster. I'm leaving it here in case it's useful in the future.

<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    // Use the in-place transpose_: the out-of-place transpose would discard its result
    LU_f.transpose_(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>
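
The rescaling step in the snippet above (`L' = LD`, `U' = D^{-1}U`, followed by a transpose) can be sanity-checked on a small packed LU. This is a minimal pure-Python sketch under simplifying assumptions: real entries, no pivoting (the permutation is handled separately in the C++ code), exact `Fraction` arithmetic, and hypothetical helper names (`unpack`, `transposed_packed_lu`).

```python
from fractions import Fraction


def transpose(a):
    return [list(col) for col in zip(*a)]


def matmul(a, b):
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def unpack(packed):
    """Split a packed LU into a unit lower L and an upper U."""
    n = len(packed)
    L = [[packed[i][j] if j < i else Fraction(int(i == j)) for j in range(n)]
         for i in range(n)]
    U = [[packed[i][j] if j >= i else Fraction(0) for j in range(n)]
         for i in range(n)]
    return L, U


def transposed_packed_lu(packed):
    """Packed LU of A^T built from the packed LU of A (no pivoting)."""
    n = len(packed)
    d = [packed[i][i] for i in range(n)]
    scaled = [[Fraction(0)] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if j < i:
                scaled[i][j] = packed[i][j] * d[j]  # strict part of L D
            elif j > i:
                scaled[i][j] = packed[i][j] / d[i]  # strict part of D^{-1} U
            else:
                scaled[i][j] = d[i]                 # keep diag(U)
    # After transposing, the strict lower part holds U^T D^{-1}
    # and the upper part (with diagonal) holds D L^T.
    return transpose(scaled)


# Made-up packed LU: strict lower part is L, upper part (with diagonal) is U.
packed = [[Fraction(2), Fraction(1), Fraction(3)],
          [Fraction(1, 2), Fraction(4), Fraction(-1)],
          [Fraction(-3), Fraction(1, 4), Fraction(5)]]

L, U = unpack(packed)
A = matmul(L, U)

L_t, U_t = unpack(transposed_packed_lu(packed))
rescaled_product = matmul(L_t, U_t)            # should equal A^T
naive_product = matmul(*unpack(transpose(packed)))  # plain transpose of the packed LU
```

Here `rescaled_product` matches `transpose(A)`, while `naive_product` does not, which illustrates the comment that the LU of the transpose is not the transpose of the LU.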



Fixes #61657

[ghstack-poisoned]
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with `trans=True`. We work around
it by solving two triangular systems instead.

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.

We added extensive tests for this function, as well as tests for the
different backends. We also activated the tests for AMD, as
those should work as well.

### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  86                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  94                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                 100                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  94                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  95                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  94                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  99                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  96                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  94                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  95                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  88                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  96                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  96                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  96                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  96                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  97                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  96                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  96                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  97                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  97                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  88                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  97                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                 100                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  97                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  96                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  98                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  96                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  96                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  97                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                 100                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  88                |           28
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  97                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                 100                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  97                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  97                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  97                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                 100                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  98                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                 100                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                 117                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  88                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  97                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  97                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                 100                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  96                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  96                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  97                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  96                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 191                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 312                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  68                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  85                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 150                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 278                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  95                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  97                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                 123                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 176                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 551                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 919                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  90                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                 164                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 301                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 578                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 199                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 241                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 322                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 514                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                1697                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                3040                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                 178                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 329                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 655                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                1321                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 472                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 634                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 925                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                1582                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                5711                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |               10660                |         2122       
```
</details>

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  98                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                 110                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                 110                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                 110                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                 110                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                 110                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                 100                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                 110                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                 110                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                 100                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  99                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                 100                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                 107                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                 110                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                 110                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                 111                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                 107                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                 108                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                 110                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                 107                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                 110                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                 107                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                 107                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                 108                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                 108                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                 108                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                 108                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                 110                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                 110                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                 100                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                 108                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                 107                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                 110                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                 110                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                 108                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                 107                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                 110                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                 110                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                 110                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 198                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 318                |          120          
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  88                |           34          
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                 109                |           59          
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                 132                |           66          
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 180                |           68          
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                 107                |           69          
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                 100                |           74          
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                 123                |           76          
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 183                |           97          
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 549                |          210          
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 918                |          308          
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  89                |           46          
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                 110                |           92          
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 131                |          139          
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 208                |          142          
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 198                |          143          
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 246                |          155          
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 323                |          180          
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 512                |          231          
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                1700                |          519          
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                3018                |          794          
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  88                |           78          
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 117                |          150          
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 209                |          284          
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 394                |          288          
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 480                |          330          
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 629                |          367          
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 921                |          414          
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                1579                |          553          
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                5616                |         1410          
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |               10500                |         2277          

Times are in microseconds (us).
```

</details>

To generate these results, I placed the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. Then I ran the following script, changing the variable `name` for each backend. For `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncommenting the commented-out line).
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


def f(LU, pivots, B, adjoint):
    # Baseline: unpack A = P @ L @ U and solve with two triangular solves
    P, L, U = torch.lu_unpack(LU, pivots)
    if adjoint:
        # A^H X = B  <=>  U^H L^H (P^H X) = B
        X = torch.linalg.solve_triangular(U.mH, B, upper=False)
        return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
    else:
        # A X = B  <=>  L U X = P^T B
        X = P.mT @ B
        X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
        return torch.linalg.solve_triangular(U, X, upper=True, out=X)


for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    #stmt = "f(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

# Use a distinct handle name so the baseline function `f` above is not shadowed
with open("{}_lu_solve.pickle".format(name), 'wb') as fh:
    pickle.dump(results, fh)
```
</details>
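
For readers who want to see why the `unpack+solve_triangular` baseline `f` is a valid reference, here is a minimal pure-Python sketch of the same algebra on a hand-picked 3x3 example. It is not the PR's implementation: the helper names (`matmul`, `solve_lower`, `solve_upper`, `lu_solve_unpacked`) are hypothetical, the matrices are real (so the adjoint is just the transpose), and exact `Fraction` arithmetic is used only so the check can be an equality.

```python
from fractions import Fraction as F


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def solve_lower(L, B):
    """Forward substitution: solve L X = B for lower-triangular L."""
    X = []
    for i in range(len(L)):
        X.append([(B[i][j] - sum(L[i][k] * X[k][j] for k in range(i))) / L[i][i]
                  for j in range(len(B[0]))])
    return X


def solve_upper(U, B):
    """Back substitution: solve U X = B for upper-triangular U."""
    n = len(U)
    X = [None] * n
    for i in reversed(range(n)):
        X[i] = [(B[i][j] - sum(U[i][k] * X[k][j] for k in range(i + 1, n))) / U[i][i]
                for j in range(len(B[0]))]
    return X


def lu_solve_unpacked(P, L, U, B, adjoint=False):
    """Solve A X = B (or A^T X = B) given A = P L U, using two triangular solves."""
    if adjoint:
        # A^T X = B  <=>  U^T L^T (P^T X) = B, so X = P Z with L^T Z = Y, U^T Y = B
        Y = solve_lower(transpose(U), B)
        Z = solve_upper(transpose(L), Y)
        return matmul(P, Z)
    # A X = B  <=>  L U X = P^T B
    Y = solve_lower(L, matmul(transpose(P), B))
    return solve_upper(U, Y)


# Hand-picked factors (assumed for illustration): permutation, unit-lower L, upper U
P = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]
L = [[F(1), F(0), F(0)], [F(1, 2), F(1), F(0)], [F(-1, 3), F(1, 4), F(1)]]
U = [[F(4), F(-2), F(1)], [F(0), F(3), F(5)], [F(0), F(0), F(-2)]]
A = matmul(P, matmul(L, U))
X_true = [[F(1)], [F(-2)], [F(3)]]

X = lu_solve_unpacked(P, L, U, matmul(A, X_true))
X_adj = lu_solve_unpacked(P, L, U, matmul(transpose(A), X_true), adjoint=True)
assert X == X_true and X_adj == X_true
```

The adjoint branch never refactorizes `A`: it reuses the same `L` and `U`, only swapping the order and orientation of the triangular solves and moving the permutation to the end.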

Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>
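
Once the results are joined, one quick way to sanity-check the heuristic is to pick the fastest backend per shape and compare it with the `heuristic` column. The sketch below does this for four rows copied by hand from one revision of the `adjoint=False` table in this thread; the dictionary layout and the names `timings_us`, `fastest`, and `overhead_us` are illustrative assumptions, not part of the benchmarking scripts.

```python
# Times in microseconds, copied from four rows of an adjoint=False table in this PR
timings_us = {
    (1, 64, 64): {"looped_magma": 760, "looped cusolver": 33, "batched cublas": 55,
                  "batched magma": 272, "unpack+solve_triangular": 68, "heuristic": 34},
    (2, 64, 64): {"looped_magma": 1500, "looped cusolver": 52, "batched cublas": 58,
                  "batched magma": 273, "unpack+solve_triangular": 85, "heuristic": 52},
    (4, 64, 64): {"looped_magma": 3127, "looped cusolver": 102, "batched cublas": 65,
                  "batched magma": 273, "unpack+solve_triangular": 150, "heuristic": 65},
    (1024, 32, 32): {"looped_magma": 766300, "looped cusolver": 19200, "batched cublas": 122,
                     "batched magma": 400, "unpack+solve_triangular": 312, "heuristic": 122},
}

fastest = {}      # shape -> name of the fastest explicit backend
overhead_us = {}  # shape -> how much slower the heuristic is than that backend
for shape, row in timings_us.items():
    backends = {k: v for k, v in row.items() if k != "heuristic"}
    best = min(backends, key=backends.get)
    fastest[shape] = best
    overhead_us[shape] = row["heuristic"] - backends[best]

for shape in timings_us:
    print(shape, fastest[shape], overhead_us[shape])
```

On these rows the looped cuSOLVER path wins for batch sizes 1 and 2 at n=64, batched cuBLAS wins for the larger batches, and the heuristic stays within 1 us of the best backend in each case.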

### Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it, preferring the triangular solves instead because they were faster. I'm leaving it here in case it's useful in the future.

<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    // transpose() is out-of-place, so the result must be assigned back
    LU_f = LU_f.transpose(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>
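
The rescaling step in the snippet above relies on the identity `(LU)^T = (U^T D^{-1}) (D L^T)` with `D = diag(U)`, which turns the transposed product back into a "unit-lower times upper" pair that a non-transposed LU solve can consume. Below is a small pure-Python sketch of just that identity, under the assumption of a real 3x3 packed LU with nonzero diagonal; the helper names (`unpack`, `rescale_and_transpose`) are hypothetical and exact `Fraction` arithmetic is used so the check is an equality.

```python
from fractions import Fraction as F


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def unpack(LU):
    """Split a packed LU into unit-lower L and upper U (no pivots involved)."""
    n = len(LU)
    L = [[LU[i][j] if j < i else F(int(i == j)) for j in range(n)] for i in range(n)]
    U = [[LU[i][j] if j >= i else F(0) for j in range(n)] for i in range(n)]
    return L, U


def rescale_and_transpose(LU):
    """Return the packed LU of (L U)^T, mirroring the tril/triu rescaling above."""
    n = len(LU)
    d = [LU[i][i] for i in range(n)]

    def entry(i, j):
        if j < i:
            return LU[i][j] * d[j]  # strictly lower part of L D
        if j > i:
            return LU[i][j] / d[i]  # strictly upper part of D^{-1} U
        return d[i]                 # keep the diagonal D

    LU_f = [[entry(i, j) for j in range(n)] for i in range(n)]
    return transpose(LU_f)


# Packed factors chosen for illustration: strictly lower part is L, the rest is U
LU = [[F(4), F(-2), F(1)],
      [F(1, 2), F(3), F(5)],
      [F(-1, 3), F(1, 4), F(-2)]]

L, U = unpack(LU)
L_t, U_t = unpack(rescale_and_transpose(LU))

# The rescaled, transposed factors reproduce (L U)^T with a unit-lower left factor
assert matmul(L_t, U_t) == transpose(matmul(L, U))
```

The permutation is not handled here: as in the C++ snippet, it would be applied to the solution afterwards, since `A^T = U^T L^T P^T` when `A = P L U`.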



Fixes #61657

[ghstack-poisoned]
lezcano added a commit that referenced this pull request Mar 4, 2022
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving the system via two triangular solves.

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.

We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.

Fixes #61657

ghstack-source-id: 7076c75
Pull Request resolved: #72935
@lezcano lezcano marked this pull request as ready for review March 4, 2022 19:32
@lezcano lezcano requested review from albanD and soulitzer as code owners March 4, 2022 19:32
lezcano added 5 commits March 4, 2022 19:37
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving the system via two triangular solves.

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.

We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.

### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  86                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  94                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                 100                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  94                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  95                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  94                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  99                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  96                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  94                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  95                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  88                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  96                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  96                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  96                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  96                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  97                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  96                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  96                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  97                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  97                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  88                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  97                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                 100                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  97                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  96                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  98                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  96                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  96                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  97                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                 100                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  88                |           28
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  97                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                 100                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  97                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  97                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  97                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                 100                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  98                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                 100                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                 117                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  88                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  97                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  97                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                 100                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  96                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  96                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  97                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  96                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 191                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 312                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  68                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  85                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 150                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 278                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  95                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  97                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                 123                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 176                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 551                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 919                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  90                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                 164                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 301                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 578                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 199                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 241                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 322                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 514                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                1697                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                3040                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                 178                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 329                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 655                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                1321                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 472                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 634                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 925                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                1582                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                5711                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |               10660                |         2122       
```
</details>

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  98                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                 110                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                 110                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                 110                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                 110                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                 110                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                 100                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                 110                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                 110                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                 100                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  99                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                 100                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                 107                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                 110                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                 110                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                 111                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                 107                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                 108                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                 110                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                 107                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                 110                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                 107                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                 107                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                 108                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                 108                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                 108                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                 108                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                 110                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                 110                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                 100                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                 108                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                 107                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                 110                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                 110                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                 108                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                 107                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                 110                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                 110                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                 110                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 198                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 318                |          120          
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  88                |           34          
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                 109                |           59          
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                 132                |           66          
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 180                |           68          
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                 107                |           69          
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                 100                |           74          
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                 123                |           76          
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 183                |           97          
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 549                |          210          
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 918                |          308          
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  89                |           46          
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                 110                |           92          
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 131                |          139          
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 208                |          142          
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 198                |          143          
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 246                |          155          
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 323                |          180          
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 512                |          231          
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                1700                |          519          
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                3018                |          794          
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  88                |           78          
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 117                |          150          
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 209                |          284          
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 394                |          288          
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 480                |          330          
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 629                |          367          
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 921                |          414          
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                1579                |          553          
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                5616                |         1410          
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |               10500                |         2277          

Times are in microseconds (us).
```

</details>

To generate the results above, I placed the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. I then ran the following script, changing the variable `name` for each backend. For `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncommenting the commented-out line).
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


def f(LU, pivots, B, adjoint):
    P, L, U = torch.lu_unpack(LU, pivots)
    if adjoint:
        X = torch.linalg.solve_triangular(U.mH, B, upper=False)
        return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
    else:
        X = P.mT @ B
        X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
        return torch.linalg.solve_triangular(U, X, upper=True, out=X)


for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    #stmt = "f(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

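# Use a distinct name for the file handle so it does not shadow the function `f` above
with open("{}_lu_solve.pickle".format(name), 'wb') as out_file:
    pickle.dump(results, out_file)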
```
</details>
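
As a sanity check on the `unpack+solve_triangular` baseline, here is a minimal sketch that compares it against `torch.linalg.lu_solve` for both values of `adjoint`. It assumes a PyTorch build where `torch.linalg.lu_solve` is available; the name `lu_solve_reference`, the CPU `float64` inputs, and the small shapes are illustrative choices and not part of the benchmark. Unlike `f` in the script, it does not reuse buffers via `out=`.

```python
import torch


def lu_solve_reference(LU, pivots, B, adjoint):
    # Hypothetical helper mirroring `f` from the benchmark, without `out=` reuse.
    # lu_unpack returns factors such that A = P @ L @ U.
    P, L, U = torch.lu_unpack(LU, pivots)
    if adjoint:
        # A^H X = B  <=>  U^H (L^H (P^H X)) = B
        Y = torch.linalg.solve_triangular(U.mH, B, upper=False)
        Z = torch.linalg.solve_triangular(L.mH, Y, upper=True, unitriangular=True)
        return P @ Z
    # A X = B  <=>  L (U X) = P^T B
    Y = torch.linalg.solve_triangular(L, P.mT @ B, upper=False, unitriangular=True)
    return torch.linalg.solve_triangular(U, Y, upper=True)


torch.manual_seed(0)
A = torch.randn(3, 4, 4, dtype=torch.float64)  # batch of three 4x4 systems
B = torch.randn(3, 4, 2, dtype=torch.float64)  # two right-hand sides each
LU, pivots = torch.linalg.lu_factor(A)

for adjoint in (False, True):
    expected = torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)
    actual = lu_solve_reference(LU, pivots, B, adjoint)
    print(adjoint, (actual - expected).abs().max().item())
```

Both printed differences should be at the level of floating-point round-off.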

Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>

### Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it, preferring the triangular solves instead, as they were faster. I'm leaving it here in case it's useful in the future.

<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    // transpose() returns a new view, so assign the result back
    LU_f = LU_f.transpose(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>



Fixes #61657

[ghstack-poisoned]
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving two triangular systems instead.

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.

We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.

### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  86                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  94                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                 100                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  94                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  95                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  94                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  99                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  96                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  94                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  95                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  88                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  96                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  96                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  96                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  96                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  97                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  96                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  96                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  97                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  97                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  88                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  97                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                 100                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  97                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  96                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  98                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  96                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  96                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  97                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                 100                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  88                |           28
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  97                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                 100                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  97                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  97                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  97                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                 100                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  98                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                 100                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                 117                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  88                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  97                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  97                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                 100                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  96                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  96                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  97                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  96                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 191                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 312                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  68                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  85                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 150                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 278                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  95                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  97                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                 123                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 176                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 551                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 919                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  90                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                 164                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 301                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 578                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 199                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 241                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 322                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 514                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                1697                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                3040                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                 178                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 329                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 655                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                1321                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 472                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 634                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 925                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                1582                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                5711                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |               10660                |         2122       
```
</details>
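
As a rough way to read the table above, the sketch below checks how closely the `heuristic` column tracks the fastest individual backend. The four rows are copied from the adjoint=False table; the names `BACKENDS`, `ROWS`, and `fastest` are illustrative and not part of the benchmarking script.

```python
# Column order of the individual backends in the benchmark table.
BACKENDS = [
    "looped_magma",
    "looped cusolver",
    "batched cublas",
    "batched magma",
    "unpack+solve_triangular",
]

# shape -> (times per backend in us, time of the heuristic in us).
# Values copied from the adjoint=False table.
ROWS = {
    (1, 1, 1): ([750, 34, 28, 252, 86], 27),
    (1024, 64, 64): ([758700, 27100, 306, 570, 919], 306),
    (1, 256, 256): ([750, 70, 235, 383, 178], 72),
    (1024, 256, 256): ([747200, 74100, 2116, 1910, 10660], 2122),
}


def fastest(times):
    """Return (backend name, time) of the fastest individual backend."""
    i = min(range(len(times)), key=times.__getitem__)
    return BACKENDS[i], times[i]


summary = {}
for shape, (times, heuristic) in ROWS.items():
    name, best = fastest(times)
    ratio = heuristic / best  # 1.0 means the heuristic matches the best backend
    summary[shape] = (name, best, ratio)
    print(f"{shape}: fastest={name} ({best} us), heuristic/best={ratio:.2f}")
```

On these rows the heuristic stays within roughly 11% of the fastest backend. Small gaps such as 72 us versus 70 us may well be timer noise.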

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  98                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                 110                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                 110                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                 110                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                 110                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                 110                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                 100                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                 110                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                 110                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                 100                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  99                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                 100                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                 107                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                 110                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                 110                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                 111                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                 107                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                 108                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                 110                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                 107                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                 110                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                 107                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                 107                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                 108                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                 108                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                 108                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                 108                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                 110                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                 110                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                 100                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                 108                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                 107                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                 110                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                 110                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                 108                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                 107                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                 110                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                 110                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                 110                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 198                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 318                |          120          
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  88                |           34          
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                 109                |           59          
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                 132                |           66          
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 180                |           68          
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                 107                |           69          
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                 100                |           74          
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                 123                |           76          
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 183                |           97          
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 549                |          210          
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 918                |          308          
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  89                |           46          
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                 110                |           92          
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 131                |          139          
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 208                |          142          
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 198                |          143          
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 246                |          155          
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 323                |          180          
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 512                |          231          
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                1700                |          519          
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                3018                |          794          
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  88                |           78          
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 117                |          150          
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 209                |          284          
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 394                |          288          
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 480                |          330          
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 629                |          367          
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 921                |          414          
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                1579                |          553          
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                5616                |         1410          
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |               10500                |         2277          

Times are in microseconds (us).
```

</details>

To generate these results, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. Then I ran the following script once per backend, changing the variable `name` each time. For `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncommenting the commented-out line).
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


def f(LU, pivots, B, adjoint):
    P, L, U = torch.lu_unpack(LU, pivots)
    if adjoint:
        X = torch.linalg.solve_triangular(U.mH, B, upper=False)
        return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
    else:
        X = P.mT @ B
        X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
        return torch.linalg.solve_triangular(U, X, upper=True, out=X)


for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    #stmt = "f(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

# Use a distinct handle name so the reference function `f` above is not shadowed
with open("{}_lu_solve.pickle".format(name), 'wb') as out_file:
    pickle.dump(results, out_file)
```
</details>
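
For readers who want to see what the `unpack+solve_triangular` reference `f` in the script above computes, here is a minimal dependency-free sketch of the same idea on plain Python lists. It is an illustration only, not the code in this PR: the helper names `lu_factor`, `lu_solve` and `matvec` are hypothetical, the permutation is stored as a 0-based list rather than LAPACK-style pivots, and it assumes a real, non-singular matrix (so adjoint and transpose coincide).

```python
def lu_factor(A):
    """LU with partial pivoting; row i of L @ U equals row perm[i] of A."""
    n = len(A)
    LU = [row[:] for row in A]
    perm = list(range(n))
    for k in range(n):
        # Pick the row with the largest entry in column k as the pivot row
        p = max(range(k, n), key=lambda i: abs(LU[i][k]))
        if p != k:
            LU[k], LU[p] = LU[p], LU[k]
            perm[k], perm[p] = perm[p], perm[k]
        for i in range(k + 1, n):
            LU[i][k] /= LU[k][k]
            for j in range(k + 1, n):
                LU[i][j] -= LU[i][k] * LU[k][j]
    return LU, perm


def lu_solve(LU, perm, b, adjoint=False):
    """Solve A x = b (or A^T x = b) with two triangular solves."""
    n = len(LU)
    if not adjoint:
        x = [b[perm[i]] for i in range(n)]   # apply the row permutation to b
        for i in range(n):                   # L y = permuted b, unit diagonal
            for j in range(i):
                x[i] -= LU[i][j] * x[j]
        for i in reversed(range(n)):         # U x = y
            for j in range(i + 1, n):
                x[i] -= LU[i][j] * x[j]
            x[i] /= LU[i][i]
        return x
    w = list(b)
    for i in range(n):                       # U^T w = b, lower triangular
        for j in range(i):
            w[i] -= LU[j][i] * w[j]
        w[i] /= LU[i][i]
    for i in reversed(range(n)):             # L^T z = w, unit upper triangular
        for j in range(i + 1, n):
            w[i] -= LU[j][i] * w[j]
    x = [0.0] * n
    for i in range(n):                       # undo the permutation by scattering
        x[perm[i]] = w[i]
    return x


def matvec(A, x, transpose=False):
    n = len(A)
    if transpose:
        return [sum(A[j][i] * x[j] for j in range(n)) for i in range(n)]
    return [sum(A[i][j] * x[j] for j in range(n)) for i in range(n)]


A = [[2.0, 1.0, 1.0], [4.0, 3.0, 3.0], [8.0, 7.0, 9.0]]
b = [1.0, 2.0, 3.0]
LU, perm = lu_factor(A)
x = lu_solve(LU, perm, b)
x_adj = lu_solve(LU, perm, b, adjoint=True)

# Largest absolute residual of A x = b and A^T x_adj = b
res_plain = max(abs(r - t) for r, t in zip(matvec(A, x), b))
res_adj = max(abs(r - t) for r, t in zip(matvec(A, x_adj, transpose=True), b))
print(res_plain, res_adj)
```

In the adjoint branch the permutation is applied last, as a scatter, which mirrors the order of operations in `f` (`P @ ...` after the two triangular solves).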

Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Compare

files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>

### Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it because the triangular solves were faster. I'm leaving it here in case it is useful in the future.

<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    // transpose() returns a view, so its result has to be assigned back
    LU_f = LU_f.transpose(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>
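
The rescaling used in the snippet above can be read as follows. This is a sketch of the reasoning rather than a statement about MAGMA's internals; it assumes $A = PLU$ with $L$ unit lower triangular and $U$ upper triangular with non-zero diagonal, and writes $D = \operatorname{diag}(U)$.

```math
\begin{aligned}
LU &= (LD)(D^{-1}U) = L'U', \qquad L' = LD, \quad U' = D^{-1}U, \\
A^{\mathsf T} &= (P L' U')^{\mathsf T} = (U')^{\mathsf T} (L')^{\mathsf T} P^{\mathsf T}.
\end{aligned}
```

Here $(U')^{\mathsf T}$ is unit lower triangular and $(L')^{\mathsf T}$ is upper triangular with diagonal $D$, so the transposed packed matrix is a valid LU factorization of $(U')^{\mathsf T}(L')^{\mathsf T}$ with trivial pivots. Solving $A^{\mathsf T} X = B$ then amounts to solving $(U')^{\mathsf T}(L')^{\mathsf T} Y = B$ with the non-transposed solver and setting $X = PY$, which is what the final `scatter_` appears to implement. For `ConjTranspose` the same argument applies with every factor conjugated.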



Fixes #61657

This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with `trans=True`. We work around
it by solving two triangular systems instead.

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy MAGMA backend for this function.

We added extensive tests for this function, as well as tests for the
different backends. We also enabled the tests on AMD, as they should
work there as well.

### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  86                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  94                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                 100                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  94                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  95                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  94                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  99                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  96                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  94                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  95                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  88                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  96                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  96                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  96                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  96                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  97                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  96                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  96                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  97                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  97                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  88                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  97                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                 100                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  97                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  96                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  98                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  96                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  96                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  97                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                 100                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  88                |           28
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  97                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                 100                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  97                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  97                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  97                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                 100                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  98                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                 100                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                 117                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  88                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  97                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  97                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                 100                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  96                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  96                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  97                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  96                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 191                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 312                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  68                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  85                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 150                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 278                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  95                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  97                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                 123                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 176                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 551                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 919                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  90                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                 164                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 301                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 578                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 199                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 241                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 322                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 514                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                1697                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                3040                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                 178                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 329                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 655                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                1321                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 472                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 634                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 925                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                1582                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                5711                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |               10660                |         2122       
```
</details>

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  98                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                 110                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                 110                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                 110                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                 110                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                 110                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                 100                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                 110                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                 110                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                 100                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  99                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                 100                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                 107                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                 110                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                 110                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                 111                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                 107                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                 108                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                 110                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                 107                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                 110                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                 107                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                 107                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                 108                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                 108                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                 108                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                 108                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                 110                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                 110                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                 100                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                 108                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                 107                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                 110                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                 110                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                 108                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                 107                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                 110                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                 110                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                 110                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 198                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 318                |          120          
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  88                |           34          
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                 109                |           59          
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                 132                |           66          
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 180                |           68          
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                 107                |           69          
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                 100                |           74          
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                 123                |           76          
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 183                |           97          
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 549                |          210          
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 918                |          308          
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  89                |           46          
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                 110                |           92          
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 131                |          139          
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 208                |          142          
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 198                |          143          
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 246                |          155          
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 323                |          180          
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 512                |          231          
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                1700                |          519          
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                3018                |          794          
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  88                |           78          
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 117                |          150          
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 209                |          284          
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 394                |          288          
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 480                |          330          
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 629                |          367          
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 921                |          414          
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                1579                |          553          
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                5616                |         1410          
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |               10500                |         2277          

Times are in microseconds (us).
```

</details>

To generate the results above, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. Then I ran the following script, changing the variable `name` each time. For `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncommenting the commented-out line).
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


def f(LU, pivots, B, adjoint):
    P, L, U = torch.lu_unpack(LU, pivots)
    if adjoint:
        X = torch.linalg.solve_triangular(U.mH, B, upper=False)
        return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
    else:
        X = P.mT @ B
        X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
        return torch.linalg.solve_triangular(U, X, upper=True, out=X)


for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    #stmt = "f(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)
```
</details>
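
For reference, the order of operations used by `f` in the benchmarking script can be checked without a GPU. The following is a minimal, dependency-free sketch and not the code in this PR: `lu_factor`, `lu_solve` and `matvec` are hypothetical pure-Python helpers restricted to real square matrices (so the adjoint is just the transpose), and `perm` is a 0-based row permutation rather than LAPACK-style pivots.

```python
def lu_factor(A):
    """LU with partial pivoting. Returns the packed factors and `perm` such
    that row i of L @ U equals row perm[i] of A (that is, A = P L U)."""
    n = len(A)
    LU = [row[:] for row in A]
    perm = list(range(n))
    for k in range(n):
        # Pick the largest pivot in column k and swap it into place.
        p = max(range(k, n), key=lambda i: abs(LU[i][k]))
        LU[k], LU[p] = LU[p], LU[k]
        perm[k], perm[p] = perm[p], perm[k]
        for i in range(k + 1, n):
            LU[i][k] /= LU[k][k]
            for j in range(k + 1, n):
                LU[i][j] -= LU[i][k] * LU[k][j]
    return LU, perm


def lu_solve(LU, perm, b, adjoint=False):
    n = len(LU)
    if not adjoint:
        x = [b[perm[i]] for i in range(n)]   # P^T b
        for i in range(n):                   # L y = P^T b (unit diagonal)
            for j in range(i):
                x[i] -= LU[i][j] * x[j]
        for i in reversed(range(n)):         # U x = y
            for j in range(i + 1, n):
                x[i] -= LU[i][j] * x[j]
            x[i] /= LU[i][i]
        return x
    z = list(b)
    for i in range(n):                       # U^T y = b (U^T is lower triangular)
        for j in range(i):
            z[i] -= LU[j][i] * z[j]
        z[i] /= LU[i][i]
    for i in reversed(range(n)):             # L^T z = y (unit upper triangular)
        for j in range(i + 1, n):
            z[i] -= LU[j][i] * z[j]
    x = [0.0] * n
    for i in range(n):                       # x = P z
        x[perm[i]] = z[i]
    return x


def matvec(A, x, transpose=False):
    n = len(A)
    if transpose:
        return [sum(A[j][i] * x[j] for j in range(n)) for i in range(n)]
    return [sum(A[i][j] * x[j] for j in range(n)) for i in range(n)]


A = [[2.0, 1.0, 1.0], [4.0, -6.0, 0.0], [-2.0, 7.0, 2.0]]
b = [5.0, -2.0, 9.0]
LU, perm = lu_factor(A)
x = lu_solve(LU, perm, b)                  # solves A x = b
xt = lu_solve(LU, perm, b, adjoint=True)   # solves A^T x = b
```

Note that in the adjoint case the permutation is applied last, to the solution, whereas in the non-adjoint case it is applied first, to the right-hand side. This mirrors the two branches of `f`.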

Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>

### Fix for Magma's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it because the triangular solves were faster. I'm leaving it here in case it's useful in the future.

<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    // Transpose in place: the out-of-place `transpose` returns a view that would be discarded here
    LU_f.transpose_(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to permute the rows of the solution, as (PLU)^T X = B iff U^T L^T Y = B with X = PY
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>
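
The rescaling in the snippet above (`L' = LD`, `U' = D^{-1}U`) can be sanity-checked on a small example. This is a minimal pure-Python sketch under stated assumptions, not the PR's code: `unpack`, `matmul` and `transposed_packed_lu` are hypothetical helpers, the matrix is real, the diagonal of `U` is nonzero, and the permutation is left out because the snippet handles it separately.

```python
def unpack(LU):
    """Split packed factors into unit lower L and upper U."""
    n = len(LU)
    L = [[LU[i][j] if i > j else float(i == j) for j in range(n)] for i in range(n)]
    U = [[LU[i][j] if i <= j else 0.0 for j in range(n)] for i in range(n)]
    return L, U


def matmul(X, Y):
    n = len(X)
    return [[sum(X[i][k] * Y[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


def transposed_packed_lu(LU):
    """Packed factors of (L @ U)^T, built from the packed factors of L @ U."""
    n = len(LU)
    d = [LU[i][i] for i in range(n)]
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i > j:
                v = LU[i][j] * d[j]   # (L D)[i][j]: scale column j of L
            elif i < j:
                v = LU[i][j] / d[i]   # (D^{-1} U)[i][j]: scale row i of U
            else:
                v = d[i]              # the diagonal is kept
            out[j][i] = v             # store transposed
    return out


LU = [[2.0, 1.0, 4.0], [2.0, 4.0, 2.0], [-1.0, 3.0, -2.0]]
L, U = unpack(LU)
A = matmul(L, U)

# Rescaled and transposed factors reproduce A^T.
Lt, Ut = unpack(transposed_packed_lu(LU))
At = matmul(Lt, Ut)

# Naively transposing the packed factors does not.
naive = [[LU[j][i] for j in range(3)] for i in range(3)]
Ln, Un = unpack(naive)
At_naive = matmul(Ln, Un)
```

The reason the rescaling is needed is that `U^T` is lower triangular but does not have a unit diagonal, so the diagonal has to be moved from one factor to the other before the transposed pair can be read as packed LU factors.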



Fixes #61657

[ghstack-poisoned]
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving two triangular systems instead.

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.

We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.

### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  86                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  94                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                 100                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  94                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  95                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  94                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  99                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  96                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  94                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  95                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  88                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  96                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  96                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  96                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  96                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  97                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  96                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  96                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  97                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  97                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  88                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  97                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                 100                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  97                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  96                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  98                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  96                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  96                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  97                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                 100                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  88                |           28
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  97                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                 100                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  97                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  97                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  97                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                 100                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  98                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                 100                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                 117                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  88                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  97                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  97                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                 100                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  96                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  96                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  97                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  96                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 191                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 312                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  68                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  85                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 150                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 278                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  95                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  97                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                 123                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 176                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 551                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 919                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  90                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                 164                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 301                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 578                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 199                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 241                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 322                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 514                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                1697                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                3040                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                 178                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 329                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 655                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                1321                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 472                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 634                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 925                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                1582                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                5711                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |               10660                |         2122       
```
</details>

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  98                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                 110                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                 110                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                 110                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                 110                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                 110                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                 100                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                 110                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                 110                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                 100                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  99                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                 100                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                 107                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                 110                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                 110                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                 111                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                 107                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                 108                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                 110                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                 107                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                 110                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                 107                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                 107                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                 108                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                 108                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                 108                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                 108                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                 110                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                 110                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                 100                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                 108                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                 107                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                 110                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                 110                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                 108                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                 107                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                 110                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                 110                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                 110                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 198                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 318                |          120          
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  88                |           34          
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                 109                |           59          
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                 132                |           66          
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 180                |           68          
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                 107                |           69          
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                 100                |           74          
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                 123                |           76          
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 183                |           97          
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 549                |          210          
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 918                |          308          
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  89                |           46          
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                 110                |           92          
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 131                |          139          
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 208                |          142          
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 198                |          143          
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 246                |          155          
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 323                |          180          
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 512                |          231          
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                1700                |          519          
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                3018                |          794          
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  88                |           78          
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 117                |          150          
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 209                |          284          
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 394                |          288          
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 480                |          330          
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 629                |          367          
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 921                |          414          
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                1579                |          553          
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                5616                |         1410          
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |               10500                |         2277          

Times are in microseconds (us).
```

</details>

To generate the results above, I placed the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. I then ran the following script once per backend, changing the variable `name` each time. For `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncommenting the commented-out line).
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


def f(LU, pivots, B, adjoint):
    P, L, U = torch.lu_unpack(LU, pivots)
    if adjoint:
        X = torch.linalg.solve_triangular(U.mH, B, upper=False)
        return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
    else:
        X = P.mT @ B
        X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
        return torch.linalg.solve_triangular(U, X, upper=True, out=X)


for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    #stmt = "f(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}_lu_solve.pickle".format(name), 'wb') as fh:
    pickle.dump(results, fh)
```
</details>
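
For reference, the helper `f` in the script above can be read off from the factorization returned by `torch.lu_unpack`. This is a short derivation rather than anything taken from the implementation, and it assumes $A = PLU$ with $P$ a permutation matrix, $L$ unit lower triangular and $U$ upper triangular:

```latex
\begin{aligned}
AX = B
  &\iff LUX = P^{\mathsf T}B
  &&\Rightarrow\ \text{solve } LY = P^{\mathsf T}B,\ \text{then } UX = Y, \\
A^{\mathsf H}X = B
  &\iff U^{\mathsf H}L^{\mathsf H}P^{\mathsf T}X = B
  &&\Rightarrow\ \text{solve } U^{\mathsf H}Y = B,\ \text{then } L^{\mathsf H}Z = Y,\ \text{and set } X = PZ.
\end{aligned}
```

In the adjoint case $U^{\mathsf H}$ is lower triangular and $L^{\mathsf H}$ is unit upper triangular, which is why `f` passes `upper=False` to the first solve and `upper=True, unitriangular=True` to the second.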

Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Compare

files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>

### Fix for Magma's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it: the two triangular solves were faster. I'm leaving it here in case it is useful in the future.

<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    LU_f.transpose_(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```
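
The diagonal rescaling in the snippet above can be checked on a small example. The following is a minimal pure-Python sketch, not part of the PR: the helper names (`unpack`, `packed_lu_of_transpose`) are made up for illustration, pivoting is ignored, and the diagonal of `U` is assumed to be nonzero. It uses exact fractions to confirm that rescaling by `D = diag(U)` and transposing yields a packed LU factorization of the transposed matrix.

```python
from fractions import Fraction


def matmul(A, B):
    """Plain matrix product of two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def unpack(LU):
    """Split a packed LU matrix into unit-lower L and upper U."""
    n = len(LU)
    L = [[LU[i][j] if i > j else Fraction(int(i == j)) for j in range(n)] for i in range(n)]
    U = [[LU[i][j] if i <= j else Fraction(0) for j in range(n)] for i in range(n)]
    return L, U


def packed_lu_of_transpose(LU):
    """Packed LU factors of (L @ U)^T, obtained by rescaling with D = diag(U)."""
    n = len(LU)
    d = [LU[i][i] for i in range(n)]  # assumed nonzero
    F = [[Fraction(0)] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i > j:
                F[i][j] = LU[i][j] * d[j]  # (L D)[i][j]: scale column j
            elif i < j:
                F[i][j] = LU[i][j] / d[i]  # (D^{-1} U)[i][j]: scale row i
            else:
                F[i][j] = d[i]
    return transpose(F)


LU = [[Fraction(2), Fraction(1), Fraction(3)],
      [Fraction(1, 2), Fraction(4), Fraction(5)],
      [Fraction(3, 2), Fraction(1, 4), Fraction(6)]]

L, U = unpack(LU)
A = matmul(L, U)
L_t, U_t = unpack(packed_lu_of_transpose(LU))
print(matmul(L_t, U_t) == transpose(A))
```

After the transpose, the unit-lower factor is `(D^{-1} U)^T` and the upper factor is `(L D)^T`, so their product is `U^T L^T`, which is what a non-transposed solve with trivial pivots needs.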

</details>
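
The last step of that workaround turns `pivots` into a permutation and scatters the rows of `B`. As a rough illustration only, here is a pure-Python sketch of that bookkeeping. It assumes LAPACK-style 1-based pivots applied as sequential row swaps, and the function names are hypothetical; it is not the `unpack_pivots_stub` kernel itself.

```python
def pivots_to_perm(pivots):
    """Turn 1-based sequential row swaps into a 0-based permutation."""
    perm = list(range(len(pivots)))  # start from the identity permutation
    for i, p in enumerate(pivots):
        j = p - 1  # convert the 1-based pivot to a 0-based row index
        perm[i], perm[j] = perm[j], perm[i]
    return perm


def apply_swaps(rows, pivots):
    """Apply the same sequential swaps directly to a list of rows."""
    rows = list(rows)
    for i, p in enumerate(pivots):
        rows[i], rows[p - 1] = rows[p - 1], rows[i]
    return rows


def scatter_rows(rows, perm):
    """out[perm[i]] = rows[i], analogous to a scatter along the row dimension."""
    out = [None] * len(rows)
    for i, row in enumerate(rows):
        out[perm[i]] = row
    return out


pivots = [3, 3, 3]
rows = ["r0", "r1", "r2"]

perm = pivots_to_perm(pivots)
swapped = apply_swaps(rows, pivots)
gathered = [rows[k] for k in perm]
restored = scatter_rows(swapped, perm)

print(perm, swapped, restored)
```

Gathering by `perm` reproduces the sequential swaps, and scattering by `perm` undoes them, which is the role the `scatter_` call plays on `B`.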



Fixes #61657

[ghstack-poisoned]
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving two triangular systems instead.

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.

We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.

### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  86                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  94                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                 100                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  94                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  95                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  94                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  99                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  96                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  94                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  95                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  88                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  96                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  96                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  96                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  96                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  97                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  96                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  96                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  97                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  97                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  88                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  97                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                 100                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  97                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  96                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  98                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  96                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  96                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  97                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                 100                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  88                |           28
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  97                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                 100                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  97                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  97                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  97                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                 100                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  98                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                 100                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                 117                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  88                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  97                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  97                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                 100                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  96                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  96                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  97                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  96                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 191                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 312                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  68                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  85                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 150                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 278                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  95                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  97                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                 123                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 176                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 551                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 919                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  90                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                 164                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 301                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 578                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 199                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 241                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 322                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 514                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                1697                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                3040                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                 178                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 329                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 655                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                1321                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 472                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 634                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 925                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                1582                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                5711                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |               10660                |         2122       
```
</details>

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  98                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                 110                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                 110                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                 110                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                 110                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                 110                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                 100                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                 110                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                 110                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                 100                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  99                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                 100                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                 107                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                 110                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                 110                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                 111                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                 107                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                 108                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                 110                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                 107                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                 110                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                 107                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                 107                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                 108                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                 108                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                 108                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                 110                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                 108                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                 110                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                 110                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                 110                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                 100                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                 108                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                 108                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                 107                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  99                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                 110                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                 110                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                 108                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                 107                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                 110                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                 110                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                 110                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 198                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 318                |          120          
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  88                |           34          
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                 109                |           59          
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                 132                |           66          
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 180                |           68          
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                 107                |           69          
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                 100                |           74          
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                 123                |           76          
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 183                |           97          
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 549                |          210          
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 918                |          308          
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  89                |           46          
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                 110                |           92          
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 131                |          139          
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 208                |          142          
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 198                |          143          
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 246                |          155          
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 323                |          180          
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 512                |          231          
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                1700                |          519          
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                3018                |          794          
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  88                |           78          
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 117                |          150          
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 209                |          284          
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 394                |          288          
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 480                |          330          
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 629                |          367          
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 921                |          414          
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                1579                |          553          
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                5616                |         1410          
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |               10500                |         2277          

Times are in microseconds (us).
```

</details>

To generate the results above, I placed the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. I then ran the following script, changing the variable `name` for each backend. For `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncommenting the commented-out line).
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


# Baseline for the "unpack+solve_triangular" column: unpack LU, then solve two triangular systems.
def f(LU, pivots, B, adjoint):
    P, L, U = torch.lu_unpack(LU, pivots)
    if adjoint:
        X = torch.linalg.solve_triangular(U.mH, B, upper=False)
        return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
    else:
        X = P.mT @ B
        X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
        return torch.linalg.solve_triangular(U, X, upper=True, out=X)


for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    #stmt = "f(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)
```
</details>
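
To spell out why the `adjoint` branch of `f` solves with `U.mH` first and applies `P` last, here is a minimal pure-Python sketch of the same idea for a single real system (so the adjoint is just the transpose). It assumes the `A = P L U` convention returned by `torch.lu_unpack`. The helper names and the 3x3 factors are hypothetical and chosen only for illustration.

```python
from fractions import Fraction


def solve_lower(T, b, unit_diagonal=False):
    """Forward substitution for a lower-triangular matrix T."""
    x = [Fraction(0)] * len(b)
    for i in range(len(b)):
        s = b[i] - sum(T[i][j] * x[j] for j in range(i))
        x[i] = s if unit_diagonal else s / T[i][i]
    return x


def solve_upper(T, b, unit_diagonal=False):
    """Back substitution for an upper-triangular matrix T."""
    n = len(b)
    x = [Fraction(0)] * n
    for i in reversed(range(n)):
        s = b[i] - sum(T[i][j] * x[j] for j in range(i + 1, n))
        x[i] = s if unit_diagonal else s / T[i][i]
    return x


def transpose(M):
    return [list(row) for row in zip(*M)]


def matvec(M, v):
    return [sum(a * b for a, b in zip(row, v)) for row in M]


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def lu_solve_transposed(P, L, U, b):
    """Solve A^T x = b given A = P L U, using A^T = U^T L^T P^T."""
    y = solve_lower(transpose(U), b)                      # U^T y = b
    z = solve_upper(transpose(L), y, unit_diagonal=True)  # L^T z = y
    return matvec(P, z)                                   # P^T x = z, so x = P z


# Hypothetical factors, not taken from the benchmark.
F = Fraction
P = [[F(0), F(1), F(0)], [F(0), F(0), F(1)], [F(1), F(0), F(0)]]
L = [[F(1), F(0), F(0)], [F(1, 2), F(1), F(0)], [F(-1, 3), F(2), F(1)]]
U = [[F(2), F(1), F(-1)], [F(0), F(3), F(4)], [F(0), F(0), F(5)]]
A = matmul(P, matmul(L, U))
b = [F(1), F(2), F(3)]

x = lu_solve_transposed(P, L, U, b)
residual = [lhs - rhs for lhs, rhs in zip(matvec(transpose(A), x), b)]
print(residual)  # every entry should be exactly zero
```

Exact fractions are used so the residual check does not depend on floating-point tolerances. The batched tensor version in the script follows the same three steps.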

Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>

### Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it: the triangular solves were faster. I'm leaving it here in case it's useful in the future.

<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    LU_f = LU_f.transpose(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>
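
The refactorization in the comments above (`LU = LDU' = L'U'` with `L' = LD`, `U' = D^{-1}U`) can be checked numerically. The following is a minimal Python sketch, not the code from this PR: the helper name `packed_lu_of_transpose` is hypothetical, and it assumes an unpivoted factorization `A = L U` with unit-diagonal `L` stored in packed form, ignoring the pivots that the C++ code handles separately.

```python
import torch


def packed_lu_of_transpose(LU: torch.Tensor) -> torch.Tensor:
    """Given the packed LU of A = L U (unit-diagonal L, no pivoting),
    return the packed LU of A^T. Hypothetical helper for illustration."""
    d = LU.diagonal(0, -2, -1)
    # Strict lower part of L D: scale column j of L by d_j
    LD = LU.tril(-1) * d.unsqueeze(-2)
    # Strict upper part of D^{-1} U: scale row i of U by 1 / d_i
    DinvU = LU.triu(1) / d.unsqueeze(-1)
    # A^T = U^T L^T = (U^T D^{-1}) (D L^T): the transposed strict parts
    # become the new unit-lower and upper factors
    packed = (LD + DinvU).mT.clone()
    packed.diagonal(0, -2, -1).copy_(d)
    return packed


def unpack(packed: torch.Tensor):
    """Split a packed LU into (unit-lower L, upper U)."""
    n = packed.shape[-1]
    eye = torch.eye(n, dtype=packed.dtype)
    return packed.tril(-1) + eye, packed.triu()


torch.manual_seed(0)
n = 5
L = torch.randn(2, n, n, dtype=torch.float64).tril(-1) + torch.eye(n, dtype=torch.float64)
U = torch.randn(2, n, n, dtype=torch.float64).triu(1) + torch.diag_embed(
    1.0 + torch.rand(2, n, dtype=torch.float64)
)
A = L @ U
Lt, Ut = unpack(packed_lu_of_transpose(L.tril(-1) + U))
print(torch.allclose(Lt @ Ut, A.mT))
```

The diagonal is kept well away from zero here because the rescaling divides by `diag(U)`; a singular or near-singular `U` would make this refactorization unreliable.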



Fixes #61657

[ghstack-poisoned]
lezcano added 2 commits May 4, 2022 13:19
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving two triangular systems.
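
As a rough illustration of the two-triangular-solve workaround, here is a minimal Python sketch built on public PyTorch APIs. It is not the C++ implementation from this PR, and the function name `lu_solve_adjoint_sketch` is hypothetical. It assumes `A = P L U` as returned by `torch.linalg.lu_factor`, so that `A^H X = B` reduces to `X = P (L^H)^{-1} (U^H)^{-1} B`.

```python
import torch


def lu_solve_adjoint_sketch(LU, pivots, B):
    """Solve A^H X = B from the packed LU factorization of A.
    Hypothetical helper for illustration only."""
    P, L, U = torch.lu_unpack(LU, pivots)
    # U^H is lower triangular: solve U^H Y = B
    Y = torch.linalg.solve_triangular(U.mH, B, upper=False)
    # L^H is upper triangular with a unit diagonal: solve L^H Z = Y
    Z = torch.linalg.solve_triangular(L.mH, Y, upper=True, unitriangular=True)
    # P^T X = Z and P is a permutation matrix, so X = P Z
    return P @ Z


torch.manual_seed(0)
A = torch.randn(3, 4, 4, dtype=torch.complex128) + 4 * torch.eye(4, dtype=torch.complex128)
B = torch.randn(3, 4, 2, dtype=torch.complex128)
LU, pivots = torch.linalg.lu_factor(A)
X = lu_solve_adjoint_sketch(LU, pivots, B)
print(torch.allclose(A.mH @ X, B))
```

For real inputs `mH` is just the transpose, so the same sketch covers the plain transposed system as well.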

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.

We added extensive tests for this function, including tests for the
different backends. We also enabled the tests for AMD, as those should
work as well.

### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  78                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  85                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                  85                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  85                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  85                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  85                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  85                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  86                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  85                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  86                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  82                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  88                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  88                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  88                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  88                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  90                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  89                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  89                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  89                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  88                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  82                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  89                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                  89                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  88                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  89                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  89                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  89                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  89                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  89                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                  89                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  82                |           28       
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  89                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                  89                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  90                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  89                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  89                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                  89                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  89                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                  88                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                  89                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  82                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  89                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  89                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                  89                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  89                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  89                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  90                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  88                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 137                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 187                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  71                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  90                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 110                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 160                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  88                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  91                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                  94                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 122                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 313                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 437                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  71                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                  90                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 117                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 202                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 162                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 168                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 196                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 260                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                 669                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                 978                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                  78                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 120                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 191                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                 332                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 364                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 401                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 466                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                 658                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                1896                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |                3051                |         2122 
```
</details>

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  60                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                  67                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                  67                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                  71                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                  67                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                  67                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                  67                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                  67                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                  67                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                  68                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  60                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                  67                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                  67                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                  66                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                  69                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                  66                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                  67                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                  67                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                  66                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                  67                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                  68                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                  67                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                  68                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                  68                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                  67                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                  67                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                  67                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                  67                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                  67                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                  66                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                  67                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                  67                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 121                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 165                |          120        
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  49                |           34       
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                  67                |           59       
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                  90                |           66       
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 153                |           68       
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                  77                |           69       
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                  80                |           74       
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                  82                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 108                |           97       
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 278                |          210       
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 377                |          308       
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  49                |           46       
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                  67                |           92       
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 117                |          139       
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 214                |          142       
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 148                |          143       
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 159                |          155       
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 184                |          180       
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 242                |          231       
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                 643                |          519       
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                 930                |          794       
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  62                |           78       
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 106                |          150       
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 183                |          284       
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 337                |          288       
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 330                |          330       
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 365                |          367       
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 408                |          414       
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                 561                |          553       
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                1649                |         1410       
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |                2528                |         2277      

Times are in microseconds (us).
```

</details>

To generate the results above, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I ran the following script once per backend, changing the variable `name` each time.
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

# Backend under test; must match the backend hard-coded in `lu_solve_kernel`
name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")

for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)
```
</details>

Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Compare

# One entry per benchmarking run; each must match the `name` used in that run
files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>

### Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it because the two triangular solves were faster. I'm leaving it here in case it's useful in the future.

<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in MAGMA for the transposed cases, so we apply mT or mH to the LU factors ourselves
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    LU_f = LU_f.transpose(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```
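
The repacking described in the comments above can be checked numerically. The sketch below mirrors the C++ logic in Python on CPU, assuming a PyTorch build that already includes `torch.linalg.lu_solve`; it uses that function in place of the batched MAGMA call and a dense `P @ Z` in place of the `scatter_` step. Sizes and variable names are illustrative only.

```python
import torch

torch.manual_seed(0)
n = 5
A = torch.randn(n, n, dtype=torch.float64)
B = torch.randn(n, 3, dtype=torch.float64)

LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.lu_unpack(LU, pivots)  # A = P @ L @ U

# Repack as in the C++ snippet: L' = L D and U' = D^{-1} U, with D = diag(U)
d = LU.diagonal(0, -2, -1)
LU_f = LU.tril(-1) * d.unsqueeze(-2) + LU.triu(1) / d.unsqueeze(-1)
LU_f.diagonal(0, -2, -1).copy_(d)

# Transposing LU_f packs the factors of U^T L^T:
# the strictly lower part is U^T D^{-1} (unit diagonal), the upper part is D L^T
LU_t = LU_f.mT.contiguous()

# Trivial (identity) permutation, 1-based as LAPACK-style pivots expect
identity_pivots = torch.arange(1, n + 1, dtype=pivots.dtype).expand_as(pivots).contiguous()

# Solve U^T L^T Z = B with a plain, non-adjoint LU solve, then undo the permutation
Z = torch.linalg.lu_solve(LU_t, identity_pivots, B)
X = P @ Z  # A^T X = U^T L^T P^T P Z = B
```

For complex inputs the same idea would need a conjugation of `LU_f` before the transpose, as the `ConjTranspose` branch in the C++ code does.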

</details>
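
For comparison, the triangular-solve route that was preferred can be sketched in a few lines of Python. This is a minimal CPU illustration for real inputs, not the C++ code path in this PR; it assumes `torch.lu_unpack` returns `P, L, U` with `A = P @ L @ U`, and the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
n = 4
A = torch.randn(n, n, dtype=torch.float64)
B = torch.randn(n, 2, dtype=torch.float64)

LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.lu_unpack(LU, pivots)  # A = P @ L @ U

# Adjoint system: A^T X = B  <=>  U^T L^T P^T X = B
# Step 1: U^T is lower triangular (general diagonal)
Y = torch.linalg.solve_triangular(U.mT, B, upper=False)
# Step 2: L^T is upper triangular with unit diagonal
Z = torch.linalg.solve_triangular(L.mT, Y, upper=True, unitriangular=True)
# Step 3: P is orthogonal, so P^T X = Z gives X = P Z
X = P @ Z
```

For complex inputs, `mT` would be replaced by `mH` to obtain the conjugate transpose.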



Fixes #61657

[ghstack-poisoned]
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with `adjoint=True`. We work around
that by solving two triangular systems instead.

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy MAGMA backend for this function.

We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.

### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  78                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  85                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                  85                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  85                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  85                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  85                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  85                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  86                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  85                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  86                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  82                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  88                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  88                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  88                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  88                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  90                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  89                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  89                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  89                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  88                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  82                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  89                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                  89                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  88                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  89                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  89                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  89                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  89                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  89                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                  89                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  82                |           28       
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  89                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                  89                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  90                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  89                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  89                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                  89                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  89                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                  88                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                  89                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  82                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  89                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  89                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                  89                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  89                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  89                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  90                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  88                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 137                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 187                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  71                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  90                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 110                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 160                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  88                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  91                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                  94                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 122                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 313                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 437                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  71                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                  90                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 117                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 202                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 162                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 168                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 196                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 260                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                 669                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                 978                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                  78                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 120                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 191                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                 332                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 364                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 401                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 466                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                 658                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                1896                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |                3051                |         2122 
```
</details>

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  60                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                  67                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                  67                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                  71                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                  67                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                  67                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                  67                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                  67                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                  67                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                  68                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  60                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                  67                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                  67                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                  66                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                  69                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                  66                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                  67                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                  67                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                  66                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                  67                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                  68                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                  67                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                  68                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                  68                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                  67                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                  67                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                  67                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                  67                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                  67                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                  66                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                  67                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                  67                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 121                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 165                |          120        
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  49                |           34       
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                  67                |           59       
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                  90                |           66       
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 153                |           68       
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                  77                |           69       
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                  80                |           74       
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                  82                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 108                |           97       
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 278                |          210       
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 377                |          308       
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  49                |           46       
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                  67                |           92       
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 117                |          139       
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 214                |          142       
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 148                |          143       
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 159                |          155       
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 184                |          180       
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 242                |          231       
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                 643                |          519       
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                 930                |          794       
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  62                |           78       
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 106                |          150       
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 183                |          284       
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 337                |          288       
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 330                |          330       
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 365                |          367       
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 408                |          414       
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                 561                |          553       
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                1649                |         1410       
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |                2528                |         2277      

Times are in microseconds (us).
```

</details>

To generate these results, I placed a call to the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I ran the following script, changing the variable `name` accordingly.
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")

for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)
```
</details>

Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Compare

files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>
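
For picking a heuristic from these tables, it can help to reduce the measurements to "fastest backend per shape". The following is a minimal sketch of that reduction, not the procedure used for this PR. The helper name `fastest_per_shape` and the two CPU statements in the demo are hypothetical stand-ins for the pickled CUDA results. It assumes each entry is a `torch.utils.benchmark` measurement, whose `task_spec` carries the `sub_label` and `description` given to `Timer`.

```python
import torch
from torch.utils.benchmark import Timer


def fastest_per_shape(measurements):
    """Map each sub_label (shape) to (description, median seconds) of its fastest entry."""
    best = {}
    for m in measurements:
        key = m.task_spec.sub_label
        if key not in best or m.median < best[key][1]:
            best[key] = (m.task_spec.description, m.median)
    return best


# Tiny CPU demo; the two statements stand in for different backends (assumption).
A = torch.randn(4, 8, 8, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)
B = torch.randn(4, 8, 1, dtype=torch.float64)
env = {"torch": torch, "A": A, "LU": LU, "pivots": pivots, "B": B}
stmts = {
    "lu_solve": "torch.linalg.lu_solve(LU, pivots, B)",
    "solve": "torch.linalg.solve(A, B)",
}
demo = [
    Timer(stmt, globals=env, label="demo",
          sub_label=f"shape {LU.shape}", description=name).timeit(20)
    for name, stmt in stmts.items()
]
best = fastest_per_shape(demo)
print(best)
```

With the real data, `measurements` would be the `timers` list loaded from the pickle files above.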

### Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it and preferred the triangular solves instead, as they were faster. I'm leaving it here in case it's useful in the future.
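
For reference, the triangular-solve route can be sketched in Python as follows. This is an illustration of the idea, not the C++ code in this PR, and the helper name `lu_solve_adjoint_via_triangular` is hypothetical. With `A = P L U`, the adjoint system `A^H X = B` becomes `U^H L^H P^H X = B`, so it takes one lower-triangular solve, one unit upper-triangular solve, and a permutation.

```python
import torch


def lu_solve_adjoint_via_triangular(LU, pivots, B):
    """Solve A^H X = B given the packed LU factorization of A (A = P L U)."""
    P, L, U = torch.lu_unpack(LU, pivots)
    # U^H is lower triangular: solve U^H Y = B
    Y = torch.linalg.solve_triangular(U.mH, B, upper=False)
    # L^H is upper triangular with unit diagonal: solve L^H Z = Y
    Z = torch.linalg.solve_triangular(L.mH, Y, upper=True, unitriangular=True)
    # P^H X = Z, and P is a permutation matrix, so X = P Z
    return P @ Z


torch.manual_seed(0)
# Shift the diagonal so the random matrices are well conditioned
A = torch.randn(2, 6, 6, dtype=torch.float64) + 6 * torch.eye(6, dtype=torch.float64)
B = torch.randn(2, 6, 3, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)
X = lu_solve_adjoint_via_triangular(LU, pivots, B)
```

The result should agree with `torch.linalg.lu_solve(LU, pivots, B, adjoint=True)` up to rounding.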

<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    LU_f.transpose_(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>
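
The rescaling step in the snippet above can be checked numerically. The sketch below is a Python illustration for the real (non-conjugated) case only; the variable names are mine, and it runs on CPU rather than through MAGMA. It verifies that moving `D = diag(U)` from `U` to `L` and transposing gives a packed LU factorization of `A^T P` with trivial pivots, and that applying `P` afterwards recovers the solution of `A^T X = B`.

```python
import torch

torch.manual_seed(0)
n = 5
A = torch.randn(n, n, dtype=torch.float64) + n * torch.eye(n, dtype=torch.float64)
B = torch.randn(n, 2, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.lu_unpack(LU, pivots)

d = LU.diagonal(0, -2, -1)
# Strictly lower part scaled by columns (L D), strictly upper part scaled by rows (D^{-1} U)
LU_f = LU.tril(-1) * d.unsqueeze(-2) + LU.triu(1) / d.unsqueeze(-1)
LU_f.diagonal(0, -2, -1).copy_(d)
LU_t = LU_f.mT.contiguous()  # packed LU of A^T P

# Unpack by hand: unit lower factor and upper factor
L_t = LU_t.tril(-1) + torch.eye(n, dtype=A.dtype)
U_t = LU_t.triu()
residual = (A.mT @ P - L_t @ U_t).abs().max()

# Solve (A^T P) Y = B with the identity permutation (1-based pivots), then X = P Y
trivial_pivots = torch.arange(1, n + 1, dtype=torch.int32)
Y = torch.linalg.lu_solve(LU_t, trivial_pivots, B)
X = P @ Y
```

For complex inputs with `ConjTranspose`, the factors would additionally need to be conjugated, as the C++ snippet does with `conj_physical`.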



Fixes #61657

[ghstack-poisoned]
@lezcano
Collaborator Author

lezcano commented May 5, 2022

The CI is not green, but the issues seem unrelated. I'll try to merge this one after a small correction of the docs.

This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with `trans=True`. We work around
it by solving two triangular systems instead.

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.

We added extensive tests for this function, as well as tests
for the different backends. We also enabled the tests on AMD, as
those should work as well.
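
As a quick usage illustration (a sketch with made-up shapes, run on CPU), the factorization is computed once with `torch.linalg.lu_factor` and then reused for several systems, including the adjoint and right-hand-side variants:

```python
import torch

torch.manual_seed(0)
# Batch of well-conditioned 4x4 matrices
A = torch.randn(3, 4, 4, dtype=torch.float64) + 4 * torch.eye(4, dtype=torch.float64)
B = torch.randn(3, 4, 2, dtype=torch.float64)
C = torch.randn(3, 2, 4, dtype=torch.float64)

LU, pivots = torch.linalg.lu_factor(A)

X = torch.linalg.lu_solve(LU, pivots, B)                     # solves A X = B
X_adj = torch.linalg.lu_solve(LU, pivots, B, adjoint=True)   # solves A^H X = B
X_right = torch.linalg.lu_solve(LU, pivots, C, left=False)   # solves X A = C
```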

### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  78                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  85                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                  85                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  85                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  85                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  85                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  85                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  86                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  85                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  86                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  82                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  88                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  88                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  88                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  88                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  90                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  89                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  89                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  89                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  88                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  82                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  89                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                  89                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  88                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  89                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  89                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  89                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  89                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  89                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                  89                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  82                |           28       
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  89                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                  89                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  90                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  89                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  89                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                  89                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  89                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                  88                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                  89                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  82                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  89                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  89                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                  89                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  89                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  89                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  90                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  88                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 137                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 187                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  71                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  90                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 110                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 160                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  88                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  91                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                  94                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 122                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 313                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 437                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  71                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                  90                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 117                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 202                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 162                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 168                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 196                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 260                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                 669                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                 978                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                  78                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 120                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 191                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                 332                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 364                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 401                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 466                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                 658                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                1896                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |                3051                |         2122 
```
</details>

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  60                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                  67                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                  67                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                  71                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                  67                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                  67                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                  67                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                  67                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                  67                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                  68                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  60                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                  67                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                  67                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                  66                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                  69                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                  66                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                  67                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                  67                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                  66                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                  67                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                  68                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                  67                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                  68                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                  68                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                  67                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                  67                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                  67                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                  67                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                  67                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                  66                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                  67                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                  67                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 121                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 165                |          120        
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  49                |           34       
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                  67                |           59       
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                  90                |           66       
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 153                |           68       
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                  77                |           69       
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                  80                |           74       
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                  82                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 108                |           97       
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 278                |          210       
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 377                |          308       
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  49                |           46       
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                  67                |           92       
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 117                |          139       
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 214                |          142       
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 148                |          143       
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 159                |          155       
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 184                |          180       
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 242                |          231       
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                 643                |          519       
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                 930                |          794       
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  62                |           78       
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 106                |          150       
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 183                |          284       
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 337                |          288       
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 330                |          330       
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 365                |          367       
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 408                |          414       
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                 561                |          553       
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                1649                |         1410       
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |                2528                |         2277      

Times are in microseconds (us).
```

</details>

To generate the results above, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I ran the following script once per backend, changing the variable `name` each time.
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")

for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)
```
</details>

Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Compare

files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>
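
As a rough illustration of how the pickled measurements could be turned into a per-shape choice of backend, here is a minimal sketch. The helper name `fastest_backend` is hypothetical and not part of this PR; it only assumes that each entry is a `torch.utils.benchmark.Measurement` whose `task_spec` carries the `label`, `sub_label` and `description` set in the benchmarking script.

```python
from collections import defaultdict


def fastest_backend(measurements):
    """Map (label, sub_label) -> description of the measurement with the lowest median.

    Hypothetical helper: `measurements` is a list of
    torch.utils.benchmark.Measurement objects, e.g. the `timers` list
    built by the joining script.
    """
    by_case = defaultdict(dict)
    for m in measurements:
        spec = m.task_spec
        # label separates adjoint from non-adjoint runs, sub_label is the shape
        by_case[(spec.label, spec.sub_label)][spec.description] = m.median
    # For every case keep only the backend with the smallest median time
    return {case: min(times, key=times.get) for case, times in by_case.items()}
```

Calling `fastest_backend(timers)` after the loop of the joining script would return one winner per shape and adjoint setting, which can then be compared against the `lu_solve heuristic` column.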

### Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it: the two triangular solves were faster. I'm leaving it here in case it's useful in the future.

<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    // Transpose in place; the out-of-place `transpose` would discard its result
    LU_f.transpose_(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>
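
For comparison, the triangular-solve alternative can be sketched in Python. This is only an illustration of the idea, not the C++ code in this PR; the function name `lu_solve_adjoint_via_triangular` is hypothetical. It assumes the factorization `A = P L U` returned by `torch.lu_unpack`, so that `A^H X = B` becomes `U^H L^H P^H X = B`.

```python
import torch


def lu_solve_adjoint_via_triangular(LU, pivots, B):
    """Solve A^H X = B from the packed LU factorization of A (hypothetical helper)."""
    # A = P @ L @ U, hence A^H = U^H @ L^H @ P^H
    P, L, U = torch.lu_unpack(LU, pivots)
    # U^H is lower triangular: solve U^H Y = B
    Y = torch.linalg.solve_triangular(U.mH, B, upper=False)
    # L^H is upper triangular with a unit diagonal: solve L^H Z = Y
    Z = torch.linalg.solve_triangular(L.mH, Y, upper=True, unitriangular=True)
    # P^H X = Z and P is a permutation matrix, so X = P @ Z
    return P @ Z
```

On a well-conditioned input the result should agree with `torch.linalg.lu_solve(LU, pivots, B, adjoint=True)` up to rounding.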



Fixes #61657

[ghstack-poisoned]
@lezcano
Collaborator Author

lezcano commented May 5, 2022

@pytorchmergebot merge this please

@pytorchmergebot
Collaborator

Merge failed due to Matched rule superuser, but it was not reviewed yet by any of:dreiss,laurencer,Adolfo-Karim,cheetah2216,mvsampath, ...
Raised by https://github.com/pytorch/pytorch/actions/runs/2274857117

lezcano added 2 commits May 5, 2022 10:07
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving two triangular systems instead.

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.

We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.

### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  78                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  85                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                  85                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  85                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  85                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  85                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  85                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  86                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  85                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  86                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  82                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  88                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  88                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  88                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  88                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  90                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  89                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  89                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  89                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  88                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  82                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  89                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                  89                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  88                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  89                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  89                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  89                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  89                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  89                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                  89                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  82                |           28       
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  89                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                  89                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  90                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  89                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  89                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                  89                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  89                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                  88                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                  89                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  82                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  89                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  89                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                  89                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  89                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  89                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  90                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  88                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 137                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 187                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  71                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  90                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 110                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 160                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  88                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  91                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                  94                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 122                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 313                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 437                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  71                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                  90                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 117                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 202                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 162                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 168                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 196                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 260                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                 669                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                 978                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                  78                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 120                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 191                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                 332                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 364                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 401                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 466                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                 658                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                1896                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |                3051                |         2122 
```
</details>

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  60                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                  67                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                  67                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                  71                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                  67                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                  67                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                  67                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                  67                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                  67                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                  68                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  60                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                  67                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                  67                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                  66                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                  69                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                  66                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                  67                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                  67                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                  66                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                  67                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                  68                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                  67                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                  68                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                  68                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                  67                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                  67                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                  67                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                  67                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                  67                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                  66                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                  67                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                  67                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 121                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 165                |          120        
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  49                |           34       
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                  67                |           59       
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                  90                |           66       
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 153                |           68       
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                  77                |           69       
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                  80                |           74       
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                  82                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 108                |           97       
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 278                |          210       
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 377                |          308       
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  49                |           46       
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                  67                |           92       
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 117                |          139       
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 214                |          142       
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 148                |          143       
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 159                |          155       
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 184                |          180       
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 242                |          231       
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                 643                |          519       
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                 930                |          794       
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  62                |           78       
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 106                |          150       
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 183                |          284       
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 337                |          288       
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 330                |          330       
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 365                |          367       
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 408                |          414       
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                 561                |          553       
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                1649                |         1410       
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |                2528                |         2277      

Times are in microseconds (us).
```

</details>
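
For readers unfamiliar with the `unpack+solve_triangular` column, the idea is to solve the system with two triangular solves on the unpacked factors. The following is a minimal NumPy sketch of that idea, not the code that was benchmarked. It assumes real inputs (so the adjoint is just the transpose), builds the factors `P`, `L`, `U` by hand instead of calling `lu_factor`, and uses `np.linalg.solve` as a stand-in for a dedicated triangular solver.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5

# Hand-built factors of A = P L U (what unpacking LU and pivots would give).
P = np.eye(n)[rng.permutation(n)]                         # permutation matrix
L = np.tril(rng.standard_normal((n, n)), -1) + np.eye(n)  # unit lower triangular
U = np.triu(rng.standard_normal((n, n)), 1) + np.diag(rng.uniform(1.0, 2.0, n))
A = P @ L @ U
B = rng.standard_normal((n, 1))

# adjoint=False: A X = B  =>  L Y = P^T B, then U X = Y.
Y = np.linalg.solve(L, P.T @ B)
X = np.linalg.solve(U, Y)

# adjoint=True: A^T X = B with A^T = U^T L^T P^T
#   =>  U^T Y = B, then L^T Z = Y, then X = P Z.
Y_adj = np.linalg.solve(U.T, B)
Z_adj = np.linalg.solve(L.T, Y_adj)
X_adj = P @ Z_adj

# Both solutions satisfy their respective systems.
assert np.allclose(A @ X, B)
assert np.allclose(A.T @ X_adj, B)
```

Note that the adjoint case needs no extra factorization: only the order of the two triangular solves and the side on which the permutation is applied change.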

To generate the results above, I placed a call to the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. I then ran the following script, changing the variable `name` for each backend.
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")

for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)
```
</details>

Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Compare

files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>

### Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug. I ended up not using it because the two triangular solves were faster, but I'm leaving it here in case it's useful in the future.

<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    // transpose() returns a view rather than modifying LU_f, so the result must be assigned
    LU_f = LU_f.transpose(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>
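
The rescaling used in the fix above (`LU = L'U'` with `L' = LD`, `U' = D^{-1}U` and `D = diag(U)`) can be checked with a small CPU sketch in Python. This is only an illustration of the identity, not the code path from this PR: the matrix is made diagonally dominant to keep the check well conditioned, the permutation is ignored on purpose, and the names `LU_t`, `L_t` and `U_t` are made up for this example.

```python
import torch

torch.manual_seed(0)
n = 4
# Assumption for this sketch: a diagonally dominant matrix, so no diagonal entry of U is near zero.
A = torch.randn(n, n, dtype=torch.float64) + n * torch.eye(n, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.lu_unpack(LU, pivots)  # A = P @ L @ U, with L unit lower triangular

diag = LU.diagonal(0, -2, -1)
# L' = L D scales column j of the strict lower part by d_j;
# U' = D^{-1} U scales row i of the strict upper part by 1 / d_i.
LU_f = LU.tril(-1) * diag.unsqueeze(-2) + LU.triu(1) / diag.unsqueeze(-1)
LU_f.diagonal(0, -2, -1).copy_(diag)

# Reading the transpose of LU_f as a packed LU:
# the unit lower factor is U'^T and the upper factor is L'^T.
LU_t = LU_f.mT
L_t = LU_t.tril(-1) + torch.eye(n, dtype=LU.dtype)
U_t = LU_t.triu()

# (L U)^T = U'^T L'^T, so the packed factors reproduce the transpose of L @ U.
print(torch.allclose(L_t @ U_t, (L @ U).mT))
```

The permutation `P` is left out here, which is why the fix above still has to apply it to `B` separately after the solve.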



Fixes #61657

[ghstack-poisoned]
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving two triangular systems instead.
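
As a rough illustration of the two-triangular-solve idea, here is a minimal CPU sketch in Python. It is not the CUDA code from this PR; it assumes the factorization is unpacked with `torch.lu_unpack` so that `A = P L U`, and the intermediate names `Y` and `Z` are only for this example. Solving `A^H X = B` then amounts to `U^H Y = B`, `L^H Z = Y` and `X = P Z`.

```python
import torch

torch.manual_seed(0)
# Assumption for this sketch: small, well-conditioned batched systems on CPU.
A = torch.randn(3, 4, 4, dtype=torch.float64) + 4 * torch.eye(4, dtype=torch.float64)
B = torch.randn(3, 4, 2, dtype=torch.float64)

LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.lu_unpack(LU, pivots)  # A = P @ L @ U

# A^H X = B  <=>  U^H L^H P^H X = B
# U^H is lower triangular, L^H is unit upper triangular.
Y = torch.linalg.solve_triangular(U.mH, B, upper=False)
Z = torch.linalg.solve_triangular(L.mH, Y, upper=True, unitriangular=True)
X = P @ Z  # P^H X = Z and P is a permutation, so X = P Z

X_ref = torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
print(torch.allclose(X, X_ref))
```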

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from MAGMA for this function.

We added extensive tests for this function, as well as tests
for the different backends. We also enabled the tests on AMD, as
those should work as well.

### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  78                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  85                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                  85                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  85                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  85                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  85                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  85                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  86                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  85                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  86                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  82                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  88                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  88                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  88                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  88                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  90                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  89                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  89                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  89                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  88                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  82                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  89                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                  89                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  88                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  89                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  89                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  89                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  89                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  89                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                  89                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  82                |           28       
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  89                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                  89                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  90                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  89                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  89                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                  89                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  89                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                  88                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                  89                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  82                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  89                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  89                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                  89                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  89                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  89                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  90                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  88                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 137                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 187                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  71                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  90                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 110                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 160                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  88                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  91                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                  94                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 122                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 313                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 437                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  71                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                  90                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 117                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 202                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 162                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 168                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 196                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 260                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                 669                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                 978                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                  78                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 120                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 191                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                 332                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 364                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 401                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 466                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                 658                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                1896                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |                3051                |         2122 
```
</details>

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  60                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                  67                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                  67                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                  71                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                  67                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                  67                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                  67                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                  67                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                  67                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                  68                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  60                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                  67                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                  67                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                  66                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                  69                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                  66                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                  67                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                  67                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                  66                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                  67                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                  68                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                  67                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                  68                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                  68                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                  67                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                  67                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                  67                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                  67                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                  67                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                  66                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                  67                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                  67                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 121                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 165                |          120        
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  49                |           34       
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                  67                |           59       
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                  90                |           66       
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 153                |           68       
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                  77                |           69       
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                  80                |           74       
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                  82                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 108                |           97       
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 278                |          210       
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 377                |          308       
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  49                |           46       
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                  67                |           92       
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 117                |          139       
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 214                |          142       
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 148                |          143       
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 159                |          155       
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 184                |          180       
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 242                |          231       
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                 643                |          519       
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                 930                |          794       
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  62                |           78       
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 106                |          150       
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 183                |          284       
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 337                |          288       
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 330                |          330       
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 365                |          367       
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 408                |          414       
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                 561                |          553       
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                1649                |         1410       
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |                2528                |         2277      

Times are in microseconds (us).
```

</details>

To generate the results above, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I ran the following script, changing the variable `name` for each backend.
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")

for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)
```
</details>

Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>

### Fix for Magma's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it: the two triangular solves were faster. I'm leaving it here in case it's useful in the future.

<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    // transpose() returns a view, so the result must be assigned back
    LU_f = LU_f.transpose(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>
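
The rescaling described in the comments of the snippet above can be checked numerically. What follows is a minimal sketch, not the MAGMA code path: it assumes a CPU build of PyTorch that provides `torch.linalg.lu_solve`, uses only the public Python API, and replaces the `scatter_` step with a plain multiplication by the permutation matrix returned by `torch.lu_unpack`. The names `LU_f`, `L_t`, `U_t` and `trivial_pivots` are illustrative.

```python
import torch

torch.manual_seed(0)
n = 4
eye = torch.eye(n, dtype=torch.float64)
# Shift by n * I so the test matrix stays well-conditioned
A = torch.randn(n, n, dtype=torch.float64) + n * eye
B = torch.randn(n, 2, dtype=torch.float64)

LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.lu_unpack(LU, pivots)  # A = P @ L @ U

# Rescale as in the C++ snippet: the strictly lower part becomes L D,
# the strictly upper part becomes D^{-1} U, with D = diag(U)
d = LU.diagonal(0, -2, -1)
LU_f = LU.tril(-1) * d.unsqueeze(-2) + LU.triu(1) / d.unsqueeze(-1)
LU_f.diagonal(0, -2, -1).copy_(d)
LU_f = LU_f.mT  # now holds the packed factors of (L U)^T = U^T L^T

# Unpack by hand: unit lower factor and upper factor of (L U)^T
L_t = LU_f.tril(-1) + eye
U_t = LU_f.triu()

# Solve U^T L^T Y = B with the identity permutation (1-based int32 pivots),
# then undo the row permutation: A^T = U^T L^T P^T, so X = P Y
trivial_pivots = torch.arange(1, n + 1, dtype=torch.int32)
Y = torch.linalg.lu_solve(LU_f, trivial_pivots, B)
X = P @ Y
```

Under these assumptions, `L_t @ U_t` should match `(L @ U).mT` and `A.mT @ X` should match `B` up to floating-point error.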



Fixes #61657

@lezcano
Collaborator Author

lezcano commented May 5, 2022

@nikitaved added the case you requested in #74046 (comment) to this PR as it belongs here.

This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving two triangular systems instead.

We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.

We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.
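
The triangular-solve workaround mentioned above can be sketched with the public Python API. This is only an illustration under stated assumptions (CPU tensors, a PyTorch build that includes `torch.linalg.lu_solve`, and a hypothetical helper name `adjoint_solve_via_triangular`); it is not the kernel added in this PR. With `A = P L U`, the adjoint system `A^H X = B` becomes `U^H L^H P^H X = B`, which is one lower-triangular solve, one unit upper-triangular solve, and a permutation.

```python
import torch

torch.manual_seed(0)
n, k = 5, 2
# Shift by n * I so the test matrix stays well-conditioned
A = torch.randn(n, n, dtype=torch.float64) + n * torch.eye(n, dtype=torch.float64)
B = torch.randn(n, k, dtype=torch.float64)

LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.lu_unpack(LU, pivots)  # A = P @ L @ U


def adjoint_solve_via_triangular(P, L, U, B):
    """Solve A^H X = B given A = P L U (hypothetical helper, for illustration)."""
    # U^H is lower triangular with a general diagonal
    Y = torch.linalg.solve_triangular(U.mH, B, upper=False)
    # L^H is upper triangular with a unit diagonal
    Z = torch.linalg.solve_triangular(L.mH, Y, upper=True, unitriangular=True)
    # P^H X = Z and P is a permutation matrix, so X = P Z
    return P @ Z


X = adjoint_solve_via_triangular(P, L, U, B)
X_ref = torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
```

Here `X` and `X_ref` should agree up to floating-point error, and `A.mH @ X` should reproduce `B`.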

### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[--------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             750         |              34            |              28           |            252           |                  78                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |            239           |                  85                |           27          
      shape torch.Size([4, 1, 1])         |            2995         |              83            |              28           |            239           |                  85                |           27          
      shape torch.Size([8, 1, 1])         |            6000         |             146            |              28           |            239           |                  85                |           27          
      shape torch.Size([16, 1, 1])        |           11900         |             272            |              28           |            241           |                  85                |           27          
      shape torch.Size([32, 1, 1])        |           23880         |             524            |              28           |            244           |                  85                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |            245           |                  85                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2054            |              28           |            242           |                  86                |           27          
      shape torch.Size([512, 1, 1])       |          381900         |            8100            |              28           |            250           |                  85                |           27          
      shape torch.Size([1024, 1, 1])      |          763800         |           16200            |              28           |            257           |                  86                |           27          
      shape torch.Size([1, 2, 2])         |             750         |              33            |              28           |            240           |                  82                |           27          
      shape torch.Size([2, 2, 2])         |            1500         |              51            |              28           |            240           |                  88                |           27          
      shape torch.Size([4, 2, 2])         |            2991         |              82            |              28           |            241           |                  88                |           28          
      shape torch.Size([8, 2, 2])         |            6000         |             150            |              28           |            241           |                  88                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             275            |              28           |            242           |                  88                |           27          
      shape torch.Size([32, 2, 2])        |           23980         |             530            |              28           |            246           |                  90                |           28          
      shape torch.Size([64, 2, 2])        |           48000         |            1000            |              28           |            244           |                  89                |           27          
      shape torch.Size([128, 2, 2])       |           96000         |            2063            |              28           |            245           |                  89                |           28          
      shape torch.Size([512, 2, 2])       |          382000         |            8300            |              28           |            257           |                  89                |           28          
      shape torch.Size([1024, 2, 2])      |          764000         |           20000            |              28           |            271           |                  88                |           28          
      shape torch.Size([1, 8, 8])         |             749         |              34            |              28           |            243           |                  82                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |            244           |                  89                |           28          
      shape torch.Size([4, 8, 8])         |            2988         |              83            |              28           |            244           |                  89                |           28          
      shape torch.Size([8, 8, 8])         |            5980         |             150            |              28           |            245           |                  88                |           28          
      shape torch.Size([16, 8, 8])        |           12000         |             278            |              28           |            246           |                  89                |           28          
      shape torch.Size([32, 8, 8])        |           23910         |             536            |              28           |            249           |                  89                |           28          
      shape torch.Size([64, 8, 8])        |           47800         |            1100            |              28           |            247           |                  89                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2075            |              28           |            248           |                  89                |           28          
      shape torch.Size([512, 8, 8])       |          382100         |            8300            |              28           |            270           |                  89                |           28          
      shape torch.Size([1024, 8, 8])      |          764100         |           16400            |              28           |            291           |                  89                |           28          
      shape torch.Size([1, 16, 16])       |             750         |              33            |              28           |            248           |                  82                |           28       
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |            250           |                  89                |           28       
      shape torch.Size([4, 16, 16])       |            2996         |              83            |              28           |            250           |                  89                |           28       
      shape torch.Size([8, 16, 16])       |            5980         |             147            |              28           |            251           |                  90                |           28       
      shape torch.Size([16, 16, 16])      |           11900         |             274            |              28           |            251           |                  89                |           28       
      shape torch.Size([32, 16, 16])      |           24040         |             527            |              28           |            252           |                  89                |           28       
      shape torch.Size([64, 16, 16])      |           47800         |            1037            |              28           |            251           |                  89                |           28       
      shape torch.Size([128, 16, 16])     |           95600         |            2044            |              28           |            252           |                  89                |           28       
      shape torch.Size([512, 16, 16])     |          388200         |            8100            |              28           |            280           |                  88                |           28       
      shape torch.Size([1024, 16, 16])    |          769700         |           16000            |              28           |            322           |                  89                |           28       
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |            255           |                  82                |           28       
      shape torch.Size([2, 32, 32])       |            1510         |              50            |              28           |            256           |                  89                |           28       
      shape torch.Size([4, 32, 32])       |            3022         |              82            |              31           |            256           |                  89                |           30       
      shape torch.Size([8, 32, 32])       |            6000         |             140            |              31           |            257           |                  89                |           31       
      shape torch.Size([16, 32, 32])      |           12000         |             281            |              31           |            258           |                  89                |           31       
      shape torch.Size([32, 32, 32])      |           24150         |             563            |              35           |            258           |                  89                |           35       
      shape torch.Size([64, 32, 32])      |           48300         |            1119            |              36           |            258           |                  90                |           36       
      shape torch.Size([128, 32, 32])     |           96500         |            2235            |              43           |            261           |                  88                |           43       
      shape torch.Size([512, 32, 32])     |          383100         |            8930            |              82           |            317           |                 137                |           82       
      shape torch.Size([1024, 32, 32])    |          766300         |           19200            |             122           |            400           |                 187                |          122       
      shape torch.Size([1, 64, 64])       |             760         |              33            |              55           |            272           |                  71                |           34       
      shape torch.Size([2, 64, 64])       |            1500         |              52            |              58           |            273           |                  90                |           52       
      shape torch.Size([4, 64, 64])       |            3127         |             102            |              65           |            273           |                 110                |           65       
      shape torch.Size([8, 64, 64])       |            6070         |             201            |              65           |            275           |                 160                |           65       
      shape torch.Size([16, 64, 64])      |           12000         |             399            |              66           |            274           |                  88                |           67       
      shape torch.Size([32, 64, 64])      |           23900         |             796            |              73           |            275           |                  91                |           73       
      shape torch.Size([64, 64, 64])      |           48000         |            1594            |              75           |            283           |                  94                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3177            |              96           |            292           |                 122                |           96       
      shape torch.Size([512, 64, 64])     |          379300         |           13520            |             208           |            426           |                 313                |          208       
      shape torch.Size([1024, 64, 64])    |          758700         |           27100            |             306           |            570           |                 437                |          306       
      shape torch.Size([1, 128, 128])     |             750         |              42            |             115           |            306           |                  71                |           42       
      shape torch.Size([2, 128, 128])     |            1500         |              82            |             122           |            307           |                  90                |           83       
      shape torch.Size([4, 128, 128])     |            2966         |             162            |             136           |            307           |                 117                |          136       
      shape torch.Size([8, 128, 128])     |            5930         |             317            |             137           |            308           |                 202                |          138       
      shape torch.Size([16, 128, 128])    |           12000         |             635            |             143           |            316           |                 162                |          143       
      shape torch.Size([32, 128, 128])    |           23700         |            1266            |             152           |            322           |                 168                |          152       
      shape torch.Size([64, 128, 128])    |           48000         |            2668            |             177           |            337           |                 196                |          177       
      shape torch.Size([128, 128, 128])   |           96000         |            5366            |             228           |            365           |                 260                |          228       
      shape torch.Size([512, 128, 128])   |          379400         |           21490            |             502           |            620           |                 669                |          502       
      shape torch.Size([1024, 128, 128])  |          755700         |           43040            |             764           |            903           |                 978                |          770       
      shape torch.Size([1, 256, 256])     |             750         |              70            |             235           |            383           |                  78                |           72       
      shape torch.Size([2, 256, 256])     |            2000         |             138            |             250           |            384           |                 120                |          139       
      shape torch.Size([4, 256, 256])     |            2988         |             277            |             279           |            404           |                 191                |          278       
      shape torch.Size([8, 256, 256])     |            6100         |             546            |             283           |            420           |                 332                |          286       
      shape torch.Size([16, 256, 256])    |           12100         |            1149            |             330           |            441           |                 364                |          330       
      shape torch.Size([32, 256, 256])    |           24040         |            2303            |             359           |            453           |                 401                |          360       
      shape torch.Size([64, 256, 256])    |           48000         |            4626            |             408           |            472           |                 466                |          408       
      shape torch.Size([128, 256, 256])   |           94700         |            9247            |             543           |            543           |                 658                |          543       
      shape torch.Size([512, 256, 256])   |          372000         |           37030            |            1310           |           1185           |                1896                |         1310       
      shape torch.Size([1024, 256, 256])  |          747200         |           74100            |            2116           |           1910           |                3051                |         2122 
```
</details>

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------]   
                                          |  lu_solve looped_magma  |  lu_solve looped cusolver  |  lu_solve batched cublas  |  lu_solve batched magma  |  lu_solve unpack+solve_triangular  |  lu_solve heuristic   
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------   
      shape torch.Size([1, 1, 1])         |             749         |              34            |              28           |             33           |                  60                |           27          
      shape torch.Size([2, 1, 1])         |            1500         |              50            |              28           |             50           |                  67                |           27          
      shape torch.Size([4, 1, 1])         |            3005         |              82            |              28           |             81           |                  67                |           27          
      shape torch.Size([8, 1, 1])         |            5999         |             145            |              28           |            140           |                  71                |           27          
      shape torch.Size([16, 1, 1])        |           12000         |             273            |              28           |             77           |                  67                |           27          
      shape torch.Size([32, 1, 1])        |           24000         |             522            |              28           |             78           |                  67                |           27          
      shape torch.Size([64, 1, 1])        |           48000         |            1000            |              28           |             77           |                  67                |           27          
      shape torch.Size([128, 1, 1])       |           96000         |            2029            |              28           |             78           |                  67                |           27          
      shape torch.Size([512, 1, 1])       |          383300         |            8100            |              28           |             78           |                  67                |           28          
      shape torch.Size([1024, 1, 1])      |          767500         |           16100            |              28           |             77           |                  68                |           27          
      shape torch.Size([1, 2, 2])         |             753         |              33            |              28           |             33           |                  60                |           28          
      shape torch.Size([2, 2, 2])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 2, 2])         |            3002         |              82            |              28           |             80           |                  67                |           27          
      shape torch.Size([8, 2, 2])         |            6000         |             145            |              28           |            144           |                  67                |           27          
      shape torch.Size([16, 2, 2])        |           12000         |             271            |              28           |             78           |                  66                |           27          
      shape torch.Size([32, 2, 2])        |           24120         |             524            |              28           |             78           |                  69                |           28          
      shape torch.Size([64, 2, 2])        |           48300         |            1030            |              28           |             78           |                  66                |           27          
      shape torch.Size([128, 2, 2])       |           96100         |            2041            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 2, 2])       |          383000         |            8100            |              28           |             79           |                  67                |           28          
      shape torch.Size([1024, 2, 2])      |          766100         |           16000            |              28           |             78           |                  67                |           28          
      shape torch.Size([1, 8, 8])         |             750         |              34            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 8, 8])         |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 8, 8])         |            2998         |              82            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 8, 8])         |            5990         |             146            |              28           |            150           |                  66                |           28          
      shape torch.Size([16, 8, 8])        |           11980         |             272            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 8, 8])        |           23970         |             530            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 8, 8])        |           47900         |            1040            |              28           |             79           |                  67                |           28          
      shape torch.Size([128, 8, 8])       |           96000         |            2048            |              28           |             78           |                  67                |           28          
      shape torch.Size([512, 8, 8])       |          383700         |            8100            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 8, 8])      |          766200         |           16300            |              28           |             80           |                  68                |           28          
      shape torch.Size([1, 16, 16])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 16, 16])       |            1500         |              50            |              28           |             50           |                  67                |           28          
      shape torch.Size([4, 16, 16])       |            3001         |              81            |              28           |             82           |                  67                |           28          
      shape torch.Size([8, 16, 16])       |            6000         |             145            |              28           |            140           |                  67                |           28          
      shape torch.Size([16, 16, 16])      |           12000         |             276            |              28           |             79           |                  67                |           28          
      shape torch.Size([32, 16, 16])      |           23870         |             549            |              28           |             79           |                  67                |           28          
      shape torch.Size([64, 16, 16])      |           47900         |            1098            |              29           |             80           |                  68                |           28          
      shape torch.Size([128, 16, 16])     |           95800         |            2184            |              28           |             79           |                  68                |           28          
      shape torch.Size([512, 16, 16])     |          386900         |            8769            |              28           |             80           |                  67                |           28          
      shape torch.Size([1024, 16, 16])    |          769800         |           17460            |              37           |             80           |                  67                |           37          
      shape torch.Size([1, 32, 32])       |             760         |              33            |              28           |             34           |                  60                |           28          
      shape torch.Size([2, 32, 32])       |            1500         |              50            |              28           |             50           |                  67                |           29          
      shape torch.Size([4, 32, 32])       |            3021         |              86            |              31           |             84           |                  67                |           32          
      shape torch.Size([8, 32, 32])       |            6040         |             167            |              32           |            167           |                  67                |           32          
      shape torch.Size([16, 32, 32])      |           12100         |             330            |              33           |             78           |                  67                |           33          
      shape torch.Size([32, 32, 32])      |           24150         |             662            |              35           |             78           |                  66                |           35          
      shape torch.Size([64, 32, 32])      |           48200         |            1323            |              36           |             79           |                  67                |           36          
      shape torch.Size([128, 32, 32])     |           97000         |            2637            |              44           |             79           |                  67                |           43          
      shape torch.Size([512, 32, 32])     |          382500         |           10580            |              83           |            180           |                 121                |           83          
      shape torch.Size([1024, 32, 32])    |          766600         |           22670            |             123           |            260           |                 165                |          120        
      shape torch.Size([1, 64, 64])       |             760         |              33            |              58           |             33           |                  49                |           34       
      shape torch.Size([2, 64, 64])       |            1520         |              60            |              60           |             60           |                  67                |           59       
      shape torch.Size([4, 64, 64])       |            3016         |             115            |              67           |            119           |                  90                |           66       
      shape torch.Size([8, 64, 64])       |            6120         |             230            |              67           |            233           |                 153                |           68       
      shape torch.Size([16, 64, 64])      |           12100         |             457            |              69           |             86           |                  77                |           69       
      shape torch.Size([32, 64, 64])      |           24000         |             912            |              74           |             95           |                  80                |           74       
      shape torch.Size([64, 64, 64])      |           48000         |            1833            |              76           |            106           |                  82                |           76       
      shape torch.Size([128, 64, 64])     |           95000         |            3636            |              97           |            163           |                 108                |           97       
      shape torch.Size([512, 64, 64])     |          380600         |           15600            |             210           |            464           |                 278                |          210       
      shape torch.Size([1024, 64, 64])    |          761200         |           31140            |             308           |            741           |                 377                |          308       
      shape torch.Size([1, 128, 128])     |             756         |              46            |             120           |             47           |                  49                |           46       
      shape torch.Size([2, 128, 128])     |            1500         |              91            |             123           |             89           |                  67                |           92       
      shape torch.Size([4, 128, 128])     |            2994         |             178            |             139           |            180           |                 117                |          139       
      shape torch.Size([8, 128, 128])     |            5960         |             350            |             140           |            354           |                 214                |          142       
      shape torch.Size([16, 128, 128])    |           12000         |             701            |             144           |            177           |                 148                |          143       
      shape torch.Size([32, 128, 128])    |           23870         |            1401            |             155           |            225           |                 159                |          155       
      shape torch.Size([64, 128, 128])    |           47600         |            2948            |             179           |            288           |                 184                |          180       
      shape torch.Size([128, 128, 128])   |           96000         |            5910            |             231           |            442           |                 242                |          231       
      shape torch.Size([512, 128, 128])   |          381200         |           23640            |             519           |           1400           |                 643                |          519       
      shape torch.Size([1024, 128, 128])  |          755800         |           47340            |             794           |           2436           |                 930                |          794       
      shape torch.Size([1, 256, 256])     |             760         |              74            |             246           |             77           |                  62                |           78       
      shape torch.Size([2, 256, 256])     |            1510         |             150            |             256           |            150           |                 106                |          150       
      shape torch.Size([4, 256, 256])     |            3030         |             296            |             284           |            296           |                 183                |          284       
      shape torch.Size([8, 256, 256])     |            6100         |             588            |             286           |            592           |                 337                |          288       
      shape torch.Size([16, 256, 256])    |           12200         |            1238            |             330           |            445           |                 330                |          330       
      shape torch.Size([32, 256, 256])    |           24430         |            2476            |             368           |            568           |                 365                |          367       
      shape torch.Size([64, 256, 256])    |           49000         |            4950            |             415           |            800           |                 408                |          414       
      shape torch.Size([128, 256, 256])   |           96000         |            9900            |             552           |           1330           |                 561                |          553       
      shape torch.Size([512, 256, 256])   |          369400         |           39580            |            1410           |           4614           |                1649                |         1410       
      shape torch.Size([1024, 256, 256])  |          716200         |           79200            |            2270           |           8472           |                2528                |         2277      

Times are in microseconds (us).
```

</details>

To generate the results above, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I ran the following script once per backend, changing the variable `name` each time.
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")

for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)
```
</details>
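
For reference, the call being timed solves `A X = B` (or `A^H X = B` when `adjoint=True`) from a precomputed factorization. The following is a minimal CPU sketch of that contract, assuming a PyTorch build that already includes `torch.linalg.lu_solve`. The shapes and the diagonal shift are illustrative choices, not part of the benchmark:

```python
import torch

torch.manual_seed(0)
# Small batch of matrices; the shift by 4*I is only there to keep them well conditioned.
A = torch.randn(3, 4, 4, dtype=torch.float64) + 4 * torch.eye(4, dtype=torch.float64)
B = torch.randn(3, 4, 2, dtype=torch.float64)

# Factor once, then reuse the factorization for both solves.
LU, pivots = torch.linalg.lu_factor(A)

X = torch.linalg.lu_solve(LU, pivots, B)                 # solves A X = B
Xh = torch.linalg.lu_solve(LU, pivots, B, adjoint=True)  # solves A^H X = B

# Largest absolute residual of each system.
residual = (A @ X - B).abs().max()
residual_adjoint = (A.mH @ Xh - B).abs().max()
```

Both residuals should sit near machine precision for `float64` inputs.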

Finally, I joined all the results with the following script:
<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
         "looped_magma",
         "looped cusolver",
         "batched cublas",
         "batched magma",
         "unpack+solve_triangular",
         "heuristic",
        ]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>

### Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it and preferred the two triangular solves instead, as they were faster. I'm leaving it here in case it's useful in the future.

<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
    if (trans == TransposeType::NoTranspose) {
      lu_solve_batched_magma(LU, pivots, B, trans);
      return;
    }
    // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
    // The LU of the transpose is not the transpose of the LU
    // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
    auto diag = LU.diagonal(0, -2, -1);
    auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) +
                LU.triu(1).div_(diag.unsqueeze(-1));
    LU_f.diagonal(0, -2, -1).copy_(diag);

    if (trans == TransposeType::ConjTranspose) {
      LU_f = LU_f.conj_physical();
    }
    // transpose() is not in-place, so the result has to be assigned back
    LU_f = LU_f.transpose(-2, -1);
    // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

    // Trivial permutation
    auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();

    lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

    // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
    // Fill `perm` with the identity permutation (perhaps batched)
    // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
    const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
      .add_output(perm)
      .add_input(pivots)
      .build();
    unpack_pivots_stub(pivots.device().type(), iter, m);
    B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```
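
The pivot handling at the end of the snippet above turns LAPACK-style row swaps into an explicit permutation. As a rough illustration of that idea only (this is not the `unpack_pivots_stub` kernel, and `pivots_to_permutation` is a hypothetical helper name), the swaps can be replayed on an identity permutation in Python:

```python
import torch


def pivots_to_permutation(pivots):
    """Hypothetical helper: replay 1-based LAPACK row swaps on an identity permutation."""
    m = pivots.shape[-1]
    perm = torch.arange(m, device=pivots.device).expand(pivots.shape).clone()
    for i in range(m):
        # Row i was swapped with row pivots[i] (1-based) during factorization.
        j = (pivots[..., i : i + 1] - 1).long()
        at_i = perm[..., i : i + 1].clone()
        at_j = perm.gather(-1, j)
        perm.scatter_(-1, j, at_i)
        perm[..., i : i + 1] = at_j
    return perm


torch.manual_seed(0)
A = torch.randn(2, 5, 5, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)
perm = pivots_to_permutation(pivots)

# Reorder the rows of A: rows[..., k, :] = A[..., perm[k], :].
rows = A.gather(-2, perm.unsqueeze(-1).expand_as(A))
P, L, U = torch.lu_unpack(LU, pivots)
```

With this convention the permuted rows of `A` should equal `L @ U`, which is the same relation `torch.lu_unpack` encodes through `P`.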

</details>
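
The triangular-solve approach that was preferred over this fix can be sketched in Python as follows. This is only an illustration of the idea, assuming `A = P L U` as returned by `torch.lu_unpack`; the function name `lu_solve_adjoint_via_triangular` is hypothetical and the actual kernel in the PR is written in C++:

```python
import torch


def lu_solve_adjoint_via_triangular(LU, pivots, B):
    """Hypothetical sketch: solve A^H X = B with two triangular solves."""
    # A = P L U, hence A^H = U^H L^H P^H.
    P, L, U = torch.lu_unpack(LU, pivots)
    # Step 1: U^H W = B, where U^H is lower triangular.
    W = torch.linalg.solve_triangular(U.mH, B, upper=False)
    # Step 2: L^H V = W, where L^H is upper triangular with unit diagonal.
    V = torch.linalg.solve_triangular(L.mH, W, upper=True, unitriangular=True)
    # Step 3: V = P^H X and P is a permutation matrix, so X = P V.
    return P @ V


torch.manual_seed(0)
A = torch.randn(3, 4, 4, dtype=torch.float64) + 4 * torch.eye(4, dtype=torch.float64)
B = torch.randn(3, 4, 2, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)

X = lu_solve_adjoint_via_triangular(LU, pivots, B)
residual = (A.mH @ X - B).abs().max()
```

The sketch builds `P` explicitly for readability; a tuned implementation would more likely apply the permutation by indexing instead of a matrix product.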



Fixes #61657

[ghstack-poisoned]
Collaborator

@mruberry mruberry left a comment


Wahoo! Stamped

@lezcano
Collaborator Author

lezcano commented May 5, 2022

@pytorchmergebot merge this please

@malfet
Contributor

malfet commented May 9, 2022

Going to revert this change and a bunch of changes depending on it, as it breaks the internal build system with the following linker error:

ld.lld: error: undefined symbol: at::native::DispatchStubImpl::get_call_ptr(c10::DeviceType, void*)
>>> referenced by DispatchStub.h:135 (caffe2/aten/src/ATen/native/DispatchStub.h:135)
>>>               ../ATen-cu#compile-BatchLinearAlgebra.cpp.od30e6a7f,platform010-clang/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp.o:(at::native::lu_solve_kernel(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::native::TransposeType)) in archive /data/users/nshulga/fbsource/fbcode/buck-out/opt/gen/caffe2/aten/ATen-cu#platform010-clang,static/libATen-cu.a
clang-12: error: linker command failed with exit code 1 (use -v to see invocation)

@lezcano
Collaborator Author

lezcano commented May 9, 2022

What can be done about this @malfet ?

@malfet
Contributor

malfet commented May 9, 2022

@lezcano I'm going to revert as I could not figure out the forward fix correctly, but I assume the problem stems from the fact that `unpack_pivots_stub` cannot be called from native/cuda/linalg/BatchLinearAlgebra.cpp.
One solution is to move LinearAlgebra.cu into the linalg/ folder and call `unpack_pivots_cuda` directly.
But I'm also not sure why the same error did not manifest in the OSS build (as it should have failed the same way).

@malfet
Contributor

malfet commented May 9, 2022

@pytorchbot revert this, as it breaks internal builds with undefined symbol: at::native::DispatchStubImpl::get_call_ptr(c10::DeviceType, void*)

It should have affected OSS build as well, investigating why it is not the case

@malfet malfet mentioned this pull request May 9, 2022
@lezcano
Collaborator Author

lezcano commented May 13, 2022

@malfet any news on this? It'd be nice to merge this stack (which is mostly approved already) before the branch cut, as it gives a 2.5x to 10x speed-up to linalg.solve

@malfet
Contributor

malfet commented May 16, 2022

@lezcano let me try to import this change and see if it still has the same issue

@malfet
Contributor

malfet commented May 17, 2022

@lezcano this PR cannot be imported, as the gh/Lezcano/47/orig branch no longer exists (and as such it wouldn't be mergeable either). Do you mind re-exporting this stack?

@lezcano lezcano mentioned this pull request May 17, 2022
@lezcano
Collaborator Author

lezcano commented May 17, 2022

Relanding in #77634


Labels

cla signed, module: linear algebra, open source, release notes: linalg_frontend, Reverted, topic: new features, with-ssh
