Skip to content

Update and improve the heuristics for linalg.lu_factor#73878

Closed
lezcano wants to merge 43 commits intogh/Lezcano/52/basefrom
gh/Lezcano/52/head
Closed

Update and improve the heuristics for linalg.lu_factor#73878
lezcano wants to merge 43 commits intogh/Lezcano/52/basefrom
gh/Lezcano/52/head

Conversation

@lezcano
Copy link
Copy Markdown
Collaborator

@lezcano lezcano commented Mar 7, 2022

Stack from ghstack:

This PR adds getrf_cublas to the functions considered in the heuristics
for lu_factor. It also updates the heuristics of the function.

Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

Benchmark Results
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
Benchmark Results all algorithms up to `n=2048`
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
                                          |  lu_factor_magma_batched  |  lu_factor_cusolver_batched  |  lu_factor_cusolver_looped  |  lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
      shape torch.Size([1, 1, 1])         |               51          |                30            |                27           |            1390
      shape torch.Size([2, 1, 1])         |               42          |                20            |                26           |            2798
      shape torch.Size([4, 1, 1])         |               42          |                20            |                42           |            5589
      shape torch.Size([8, 1, 1])         |               42          |                20            |                72           |           11000
      shape torch.Size([16, 1, 1])        |               42          |                20            |               132           |           22400
      shape torch.Size([32, 1, 1])        |               42          |                20            |               253           |           44620
      shape torch.Size([64, 1, 1])        |               42          |                20            |               496           |           89200
      shape torch.Size([128, 1, 1])       |               42          |                20            |               980           |          180000
      shape torch.Size([512, 1, 1])       |               43          |                20            |              3868           |          714100
      shape torch.Size([1024, 1, 1])      |               44          |                20            |              7800           |         1430000
      shape torch.Size([1, 2, 2])         |               43          |                21            |                19           |            1400
      shape torch.Size([2, 2, 2])         |               42          |                21            |                27           |            2898
      shape torch.Size([4, 2, 2])         |               43          |                21            |                42           |            5800
      shape torch.Size([8, 2, 2])         |               43          |                21            |                73           |           11600
      shape torch.Size([16, 2, 2])        |               43          |                21            |               133           |           23170
      shape torch.Size([32, 2, 2])        |               43          |                21            |               254           |           46290
      shape torch.Size([64, 2, 2])        |               43          |                21            |               500           |           94000
      shape torch.Size([128, 2, 2])       |               43          |                21            |               980           |          190000
      shape torch.Size([512, 2, 2])       |               44          |                21            |              3860           |          741900
      shape torch.Size([1024, 2, 2])      |               44          |                21            |              7640           |         1484000
      shape torch.Size([1, 8, 8])         |               45          |                21            |                19           |            1450
      shape torch.Size([2, 8, 8])         |               45          |                21            |                27           |            2917
      shape torch.Size([4, 8, 8])         |               45          |                21            |                53           |            5800
      shape torch.Size([8, 8, 8])         |               45          |                21            |               105           |           11580
      shape torch.Size([16, 8, 8])        |               45          |                21            |               207           |           23160
      shape torch.Size([32, 8, 8])        |               46          |                21            |               413           |           46400
      shape torch.Size([64, 8, 8])        |               46          |                21            |               824           |           93000
      shape torch.Size([128, 8, 8])       |               46          |                21            |              1645           |          185000
      shape torch.Size([512, 8, 8])       |               47          |                21            |              6574           |          742000
      shape torch.Size([1024, 8, 8])      |               49          |                21            |             13150           |         1481000
      shape torch.Size([1, 16, 16])       |               49          |                21            |                24           |            1460
      shape torch.Size([2, 16, 16])       |               49          |                21            |                46           |            2902
      shape torch.Size([4, 16, 16])       |               49          |                21            |                90           |            5800
      shape torch.Size([8, 16, 16])       |               49          |                21            |               177           |           11600
      shape torch.Size([16, 16, 16])      |               49          |                21            |               352           |           23150
      shape torch.Size([32, 16, 16])      |               49          |                21            |               703           |           46300
      shape torch.Size([64, 16, 16])      |               49          |                21            |              1404           |           92700
      shape torch.Size([128, 16, 16])     |               50          |                21            |              2807           |          185000
      shape torch.Size([512, 16, 16])     |               55          |                29            |             11220           |          741700
      shape torch.Size([1024, 16, 16])    |               64          |                42            |             22440           |         1480000
      shape torch.Size([1, 32, 32])       |               55          |                56            |                58           |            1460
      shape torch.Size([2, 32, 32])       |               55          |                57            |               114           |            2920
      shape torch.Size([4, 32, 32])       |               55          |                57            |               225           |            5830
      shape torch.Size([8, 32, 32])       |               55          |                61            |               449           |           11700
      shape torch.Size([16, 32, 32])      |               56          |                61            |               896           |           23300
      shape torch.Size([32, 32, 32])      |               56          |                62            |              1791           |           46600
      shape torch.Size([64, 32, 32])      |               56          |                63            |              3581           |           93100
      shape torch.Size([128, 32, 32])     |               58          |                66            |              7156           |          186000
      shape torch.Size([512, 32, 32])     |              100          |               194            |             28700           |          742400
      shape torch.Size([1024, 32, 32])    |              169          |               335            |             57620           |         1485000
      shape torch.Size([1, 64, 64])       |              224          |               101            |               132           |            1500
      shape torch.Size([2, 64, 64])       |              227          |               100            |               262           |            2951
      shape torch.Size([4, 64, 64])       |              229          |               101            |               523           |            5890
      shape torch.Size([8, 64, 64])       |              231          |               102            |              1040           |           12000
      shape torch.Size([16, 64, 64])      |              237          |               109            |              2088           |           23530
      shape torch.Size([32, 64, 64])      |              242          |               127            |              4171           |           46900
      shape torch.Size([64, 64, 64])      |              247          |               156            |              8330           |           95000
      shape torch.Size([128, 64, 64])     |              293          |               244            |             16710           |          189000
      shape torch.Size([512, 64, 64])     |              685          |              1180            |             67000           |          750900
      shape torch.Size([1024, 64, 64])    |             1300          |              2076            |            134000           |         1505000
      shape torch.Size([1, 128, 128])     |              490          |               309            |               298           |            1560
      shape torch.Size([2, 128, 128])     |              503          |               309            |               594           |            3120
      shape torch.Size([4, 128, 128])     |              515          |               312            |              1185           |            6230
      shape torch.Size([8, 128, 128])     |              523          |               317            |              2370           |           12500
      shape torch.Size([16, 128, 128])    |              547          |               336            |              4734           |           24890
      shape torch.Size([32, 128, 128])    |              596          |               472            |              9491           |           49800
      shape torch.Size([64, 128, 128])    |              700          |               741            |             19000           |          100000
      shape torch.Size([128, 128, 128])   |              930          |              1770            |             37990           |          199000
      shape torch.Size([512, 128, 128])   |             2810          |             11000            |            152000           |          797100
      shape torch.Size([1024, 128, 128])  |             5430          |             22430            |            303900           |         1595000
      shape torch.Size([1, 256, 256])     |             1120          |              1580            |               666           |            1890
      shape torch.Size([2, 256, 256])     |             1160          |              1574            |              1330           |            3784
      shape torch.Size([4, 256, 256])     |             1190          |              1580            |              2658           |            7570
      shape torch.Size([8, 256, 256])     |             1250          |              1613            |              5325           |           15100
      shape torch.Size([16, 256, 256])    |             1394          |              1880            |             10700           |           30260
      shape torch.Size([32, 256, 256])    |             1633          |              3360            |             21300           |           61000
      shape torch.Size([64, 256, 256])    |             2258          |              6730            |             42600           |          120000
      shape torch.Size([128, 256, 256])   |             3639          |             19200            |             85170           |          242200
      shape torch.Size([512, 256, 256])   |            12600          |             87200            |            340600           |          969000
      shape torch.Size([1024, 256, 256])  |            24530          |            176000            |            681300           |         1943000
      shape torch.Size([1, 512, 512])     |             2557          |              9117            |              1724           |            2577
      shape torch.Size([2, 512, 512])     |             2691          |              9209            |              3464           |            5200
      shape torch.Size([4, 512, 512])     |             2853          |              9860            |              6940           |           10000
      shape torch.Size([8, 512, 512])     |             3153          |             11000            |             13900           |           20570
      shape torch.Size([16, 512, 512])    |             3765          |             13000            |             27720           |           41360
      shape torch.Size([32, 512, 512])    |             5500          |             21400            |             55420           |           82000
      shape torch.Size([64, 512, 512])    |             8790          |             44000            |            111000           |          165000
      shape torch.Size([128, 512, 512])   |            15300          |             98000            |            221700           |          329800
      shape torch.Size([512, 512, 512])   |            55400          |            424100            |            886600           |         1325000
      shape torch.Size([1024, 512, 512])  |           110000          |            856200            |           1773000           |         2691000
      shape torch.Size([1, 1024, 1024])   |            10350          |             69290            |              5020           |            5327
      shape torch.Size([2, 1024, 1024])   |            11200          |             74860            |             10040           |           11000
      shape torch.Size([4, 1024, 1024])   |            12200          |             78030            |             20080           |           21290
      shape torch.Size([8, 1024, 1024])   |            14000          |             81200            |             40160           |           42850
      shape torch.Size([16, 1024, 1024])  |            17700          |             96000            |             80300           |           85500
      shape torch.Size([32, 1024, 1024])  |            27740          |            150000            |            160700           |          171000
      shape torch.Size([64, 1024, 1024])  |            45940          |            233400            |            321200           |          344100
      shape torch.Size([1, 2048, 2048])   |            29860          |            579800            |             12920           |           13500
      shape torch.Size([2, 2048, 2048])   |            34000          |            585000            |             25840           |           26840
      shape torch.Size([4, 2048, 2048])   |            39770          |            593900            |             51670           |           54000
      shape torch.Size([8, 2048, 2048])   |            51720          |            632100            |            103000           |          109000
      shape torch.Size([16, 2048, 2048])  |            76900          |            845500            |            206600           |          218400
      shape torch.Size([32, 2048, 2048])  |           130000          |           1058000            |            413900           |          437300

Times are in microseconds (us).


To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`.
Benchmarking script
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    if n == 1024 and batch[0] >= 128:
        continue
    if n == 2048 and batch[0] >= 64:
        continue
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)

See #72935 (comment) for the script to join the results.

This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve.

[ghstack-poisoned]
@lezcano lezcano mentioned this pull request Mar 7, 2022
@lezcano lezcano mentioned this pull request Mar 7, 2022
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented Mar 7, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/fbd052b492b450a7dffad805ecca7d27c9fdcd69/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default
Add ciflow labels to this PR to trigger more builds:

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
linux-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-manywheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk ✅ triggered
linux-bionic-rocm4.5-py3.7 ciflow/all, ciflow/default, ciflow/linux, ciflow/rocm, ciflow/trunk ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build ciflow/all, ciflow/cpu, ciflow/default, ciflow/libtorch, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
macos-arm64-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-arm64-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
macos-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
windows-binary-libtorch-debug ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
windows-binary-libtorch-release ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
windows-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
docker-builds ciflow/all, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.3-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
pytorch-xla-linux-bionic-py3.7-clang8 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk, ciflow/xla 🚫 skipped

@facebook-github-bot
Copy link
Copy Markdown
Contributor

facebook-github-bot commented Mar 7, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit 1bd5dc4 (more details on the Dr. CI page):

Expand to see more

💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

lezcano added a commit that referenced this pull request Mar 7, 2022
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve.

ghstack-source-id: 6678fda
Pull Request resolved: #73878
@lezcano lezcano added the module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul label Mar 7, 2022
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. For the `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncomenting the commented one)
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_batched"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

    # Test
    LU, pivots = torch.linalg.lu_factor(A)
    P, L, U = torch.lu_unpack(LU, pivots)
    assert torch.allclose(P @ L @ U, A, rtol=1e-2, atol=1e-3)


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. For the `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncomenting the commented one)
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_batched"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

    # Test
    LU, pivots = torch.linalg.lu_factor(A)
    P, L, U = torch.lu_unpack(LU, pivots)
    assert torch.allclose(P @ L @ U, A, rtol=1e-2, atol=1e-3)


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
lezcano added a commit that referenced this pull request Mar 8, 2022
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve.

ghstack-source-id: 0fdfb74
Pull Request resolved: #73878
@IvanYashchuk
Copy link
Copy Markdown
Collaborator

You're running the benchmark only up to 256x256 matrices. It's important to test larger matrices as well, this is the regime where the "looped" variant should be faster.

@lezcano
Copy link
Copy Markdown
Collaborator Author

lezcano commented Mar 10, 2022

How large do you want the matrices to be? Do you reckon adding 512 and 1024 would do it?

@IvanYashchuk
Copy link
Copy Markdown
Collaborator

IvanYashchuk commented Mar 10, 2022

A comment here suggests that 512 is the breaking point where the cusolver looped variant is better:

// Heuristic: For small batch size or large matrix size, we use for-loop to iterate over the batches instead of
// calling the batched cublas routine.
if (batch_size <= 8 || /* batch_size > 8 && */ n >= 512) {

A few other examples of using 512:
// cuBLAS batched is faster than MAGMA batched up until 512x512, after that MAGMA is faster

if ((batch_size == 1 && m > 512) || (batch_size <= 8 && over_magma_dim_limit)) {

This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. For the `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncomenting the commented one)
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_batched"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

    # Test
    LU, pivots = torch.linalg.lu_factor(A)
    P, L, U = torch.lu_unpack(LU, pivots)
    assert torch.allclose(P @ L @ U, A, rtol=1e-2, atol=1e-3)


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
lezcano added 3 commits March 10, 2022 19:19
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. For the `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncomenting the commented one)
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_batched"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

    # Test
    LU, pivots = torch.linalg.lu_factor(A)
    P, L, U = torch.lu_unpack(LU, pivots)
    assert torch.allclose(P @ L @ U, A, rtol=1e-2, atol=1e-3)


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. For the `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncomenting the commented one)
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_batched"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

    # Test
    LU, pivots = torch.linalg.lu_factor(A)
    P, L, U = torch.lu_unpack(LU, pivots)
    assert torch.allclose(P @ L @ U, A, rtol=1e-2, atol=1e-3)


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. For the `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncomenting the commented one)
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_batched"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

    # Test
    LU, pivots = torch.linalg.lu_factor(A)
    P, L, U = torch.lu_unpack(LU, pivots)
    assert torch.allclose(P @ L @ U, A, rtol=1e-2, atol=1e-3)


compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
@lezcano lezcano changed the title Update and improve the heuristics for linalg.lu_solve Update and improve the heuristics for linalg.lu_factor Mar 10, 2022
lezcano added a commit that referenced this pull request May 18, 2022
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve.

ghstack-source-id: d768464
Pull Request resolved: #73878
This PR adds getrf_cublas to the functions considered in the heuristics
for `lu_factor`. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

<details>
<summary>
Benchmark Results all algorithms up to `n=2048`
</summary>

```
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
                                          |  lu_factor_magma_batched  |  lu_factor_cusolver_batched  |  lu_factor_cusolver_looped  |  lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
      shape torch.Size([1, 1, 1])         |               51          |                30            |                27           |            1390
      shape torch.Size([2, 1, 1])         |               42          |                20            |                26           |            2798
      shape torch.Size([4, 1, 1])         |               42          |                20            |                42           |            5589
      shape torch.Size([8, 1, 1])         |               42          |                20            |                72           |           11000
      shape torch.Size([16, 1, 1])        |               42          |                20            |               132           |           22400
      shape torch.Size([32, 1, 1])        |               42          |                20            |               253           |           44620
      shape torch.Size([64, 1, 1])        |               42          |                20            |               496           |           89200
      shape torch.Size([128, 1, 1])       |               42          |                20            |               980           |          180000
      shape torch.Size([512, 1, 1])       |               43          |                20            |              3868           |          714100
      shape torch.Size([1024, 1, 1])      |               44          |                20            |              7800           |         1430000
      shape torch.Size([1, 2, 2])         |               43          |                21            |                19           |            1400
      shape torch.Size([2, 2, 2])         |               42          |                21            |                27           |            2898
      shape torch.Size([4, 2, 2])         |               43          |                21            |                42           |            5800
      shape torch.Size([8, 2, 2])         |               43          |                21            |                73           |           11600
      shape torch.Size([16, 2, 2])        |               43          |                21            |               133           |           23170
      shape torch.Size([32, 2, 2])        |               43          |                21            |               254           |           46290
      shape torch.Size([64, 2, 2])        |               43          |                21            |               500           |           94000
      shape torch.Size([128, 2, 2])       |               43          |                21            |               980           |          190000
      shape torch.Size([512, 2, 2])       |               44          |                21            |              3860           |          741900
      shape torch.Size([1024, 2, 2])      |               44          |                21            |              7640           |         1484000
      shape torch.Size([1, 8, 8])         |               45          |                21            |                19           |            1450
      shape torch.Size([2, 8, 8])         |               45          |                21            |                27           |            2917
      shape torch.Size([4, 8, 8])         |               45          |                21            |                53           |            5800
      shape torch.Size([8, 8, 8])         |               45          |                21            |               105           |           11580
      shape torch.Size([16, 8, 8])        |               45          |                21            |               207           |           23160
      shape torch.Size([32, 8, 8])        |               46          |                21            |               413           |           46400
      shape torch.Size([64, 8, 8])        |               46          |                21            |               824           |           93000
      shape torch.Size([128, 8, 8])       |               46          |                21            |              1645           |          185000
      shape torch.Size([512, 8, 8])       |               47          |                21            |              6574           |          742000
      shape torch.Size([1024, 8, 8])      |               49          |                21            |             13150           |         1481000
      shape torch.Size([1, 16, 16])       |               49          |                21            |                24           |            1460
      shape torch.Size([2, 16, 16])       |               49          |                21            |                46           |            2902
      shape torch.Size([4, 16, 16])       |               49          |                21            |                90           |            5800
      shape torch.Size([8, 16, 16])       |               49          |                21            |               177           |           11600
      shape torch.Size([16, 16, 16])      |               49          |                21            |               352           |           23150
      shape torch.Size([32, 16, 16])      |               49          |                21            |               703           |           46300
      shape torch.Size([64, 16, 16])      |               49          |                21            |              1404           |           92700
      shape torch.Size([128, 16, 16])     |               50          |                21            |              2807           |          185000
      shape torch.Size([512, 16, 16])     |               55          |                29            |             11220           |          741700
      shape torch.Size([1024, 16, 16])    |               64          |                42            |             22440           |         1480000
      shape torch.Size([1, 32, 32])       |               55          |                56            |                58           |            1460
      shape torch.Size([2, 32, 32])       |               55          |                57            |               114           |            2920
      shape torch.Size([4, 32, 32])       |               55          |                57            |               225           |            5830
      shape torch.Size([8, 32, 32])       |               55          |                61            |               449           |           11700
      shape torch.Size([16, 32, 32])      |               56          |                61            |               896           |           23300
      shape torch.Size([32, 32, 32])      |               56          |                62            |              1791           |           46600
      shape torch.Size([64, 32, 32])      |               56          |                63            |              3581           |           93100
      shape torch.Size([128, 32, 32])     |               58          |                66            |              7156           |          186000
      shape torch.Size([512, 32, 32])     |              100          |               194            |             28700           |          742400
      shape torch.Size([1024, 32, 32])    |              169          |               335            |             57620           |         1485000
      shape torch.Size([1, 64, 64])       |              224          |               101            |               132           |            1500
      shape torch.Size([2, 64, 64])       |              227          |               100            |               262           |            2951
      shape torch.Size([4, 64, 64])       |              229          |               101            |               523           |            5890
      shape torch.Size([8, 64, 64])       |              231          |               102            |              1040           |           12000
      shape torch.Size([16, 64, 64])      |              237          |               109            |              2088           |           23530
      shape torch.Size([32, 64, 64])      |              242          |               127            |              4171           |           46900
      shape torch.Size([64, 64, 64])      |              247          |               156            |              8330           |           95000
      shape torch.Size([128, 64, 64])     |              293          |               244            |             16710           |          189000
      shape torch.Size([512, 64, 64])     |              685          |              1180            |             67000           |          750900
      shape torch.Size([1024, 64, 64])    |             1300          |              2076            |            134000           |         1505000
      shape torch.Size([1, 128, 128])     |              490          |               309            |               298           |            1560
      shape torch.Size([2, 128, 128])     |              503          |               309            |               594           |            3120
      shape torch.Size([4, 128, 128])     |              515          |               312            |              1185           |            6230
      shape torch.Size([8, 128, 128])     |              523          |               317            |              2370           |           12500
      shape torch.Size([16, 128, 128])    |              547          |               336            |              4734           |           24890
      shape torch.Size([32, 128, 128])    |              596          |               472            |              9491           |           49800
      shape torch.Size([64, 128, 128])    |              700          |               741            |             19000           |          100000
      shape torch.Size([128, 128, 128])   |              930          |              1770            |             37990           |          199000
      shape torch.Size([512, 128, 128])   |             2810          |             11000            |            152000           |          797100
      shape torch.Size([1024, 128, 128])  |             5430          |             22430            |            303900           |         1595000
      shape torch.Size([1, 256, 256])     |             1120          |              1580            |               666           |            1890
      shape torch.Size([2, 256, 256])     |             1160          |              1574            |              1330           |            3784
      shape torch.Size([4, 256, 256])     |             1190          |              1580            |              2658           |            7570
      shape torch.Size([8, 256, 256])     |             1250          |              1613            |              5325           |           15100
      shape torch.Size([16, 256, 256])    |             1394          |              1880            |             10700           |           30260
      shape torch.Size([32, 256, 256])    |             1633          |              3360            |             21300           |           61000
      shape torch.Size([64, 256, 256])    |             2258          |              6730            |             42600           |          120000
      shape torch.Size([128, 256, 256])   |             3639          |             19200            |             85170           |          242200
      shape torch.Size([512, 256, 256])   |            12600          |             87200            |            340600           |          969000
      shape torch.Size([1024, 256, 256])  |            24530          |            176000            |            681300           |         1943000
      shape torch.Size([1, 512, 512])     |             2557          |              9117            |              1724           |            2577
      shape torch.Size([2, 512, 512])     |             2691          |              9209            |              3464           |            5200
      shape torch.Size([4, 512, 512])     |             2853          |              9860            |              6940           |           10000
      shape torch.Size([8, 512, 512])     |             3153          |             11000            |             13900           |           20570
      shape torch.Size([16, 512, 512])    |             3765          |             13000            |             27720           |           41360
      shape torch.Size([32, 512, 512])    |             5500          |             21400            |             55420           |           82000
      shape torch.Size([64, 512, 512])    |             8790          |             44000            |            111000           |          165000
      shape torch.Size([128, 512, 512])   |            15300          |             98000            |            221700           |          329800
      shape torch.Size([512, 512, 512])   |            55400          |            424100            |            886600           |         1325000
      shape torch.Size([1024, 512, 512])  |           110000          |            856200            |           1773000           |         2691000
      shape torch.Size([1, 1024, 1024])   |            10350          |             69290            |              5020           |            5327
      shape torch.Size([2, 1024, 1024])   |            11200          |             74860            |             10040           |           11000
      shape torch.Size([4, 1024, 1024])   |            12200          |             78030            |             20080           |           21290
      shape torch.Size([8, 1024, 1024])   |            14000          |             81200            |             40160           |           42850
      shape torch.Size([16, 1024, 1024])  |            17700          |             96000            |             80300           |           85500
      shape torch.Size([32, 1024, 1024])  |            27740          |            150000            |            160700           |          171000
      shape torch.Size([64, 1024, 1024])  |            45940          |            233400            |            321200           |          344100
      shape torch.Size([1, 2048, 2048])   |            29860          |            579800            |             12920           |           13500
      shape torch.Size([2, 2048, 2048])   |            34000          |            585000            |             25840           |           26840
      shape torch.Size([4, 2048, 2048])   |            39770          |            593900            |             51670           |           54000
      shape torch.Size([8, 2048, 2048])   |            51720          |            632100            |            103000           |          109000
      shape torch.Size([16, 2048, 2048])  |            76900          |            845500            |            206600           |          218400
      shape torch.Size([32, 2048, 2048])  |           130000          |           1058000            |            413900           |          437300

Times are in microseconds (us).


```

</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. 
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    if n == 1024 and batch[0] >= 128:
        continue
    if n == 2048 and batch[0] >= 64:
        continue
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
lezcano added a commit that referenced this pull request May 18, 2022
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve.

ghstack-source-id: e44989e
Pull Request resolved: #73878
This PR adds getrf_cublas to the functions considered in the heuristics
for `lu_factor`. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

<details>
<summary>
Benchmark Results all algorithms up to `n=2048`
</summary>

```
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
                                          |  lu_factor_magma_batched  |  lu_factor_cusolver_batched  |  lu_factor_cusolver_looped  |  lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
      shape torch.Size([1, 1, 1])         |               51          |                30            |                27           |            1390
      shape torch.Size([2, 1, 1])         |               42          |                20            |                26           |            2798
      shape torch.Size([4, 1, 1])         |               42          |                20            |                42           |            5589
      shape torch.Size([8, 1, 1])         |               42          |                20            |                72           |           11000
      shape torch.Size([16, 1, 1])        |               42          |                20            |               132           |           22400
      shape torch.Size([32, 1, 1])        |               42          |                20            |               253           |           44620
      shape torch.Size([64, 1, 1])        |               42          |                20            |               496           |           89200
      shape torch.Size([128, 1, 1])       |               42          |                20            |               980           |          180000
      shape torch.Size([512, 1, 1])       |               43          |                20            |              3868           |          714100
      shape torch.Size([1024, 1, 1])      |               44          |                20            |              7800           |         1430000
      shape torch.Size([1, 2, 2])         |               43          |                21            |                19           |            1400
      shape torch.Size([2, 2, 2])         |               42          |                21            |                27           |            2898
      shape torch.Size([4, 2, 2])         |               43          |                21            |                42           |            5800
      shape torch.Size([8, 2, 2])         |               43          |                21            |                73           |           11600
      shape torch.Size([16, 2, 2])        |               43          |                21            |               133           |           23170
      shape torch.Size([32, 2, 2])        |               43          |                21            |               254           |           46290
      shape torch.Size([64, 2, 2])        |               43          |                21            |               500           |           94000
      shape torch.Size([128, 2, 2])       |               43          |                21            |               980           |          190000
      shape torch.Size([512, 2, 2])       |               44          |                21            |              3860           |          741900
      shape torch.Size([1024, 2, 2])      |               44          |                21            |              7640           |         1484000
      shape torch.Size([1, 8, 8])         |               45          |                21            |                19           |            1450
      shape torch.Size([2, 8, 8])         |               45          |                21            |                27           |            2917
      shape torch.Size([4, 8, 8])         |               45          |                21            |                53           |            5800
      shape torch.Size([8, 8, 8])         |               45          |                21            |               105           |           11580
      shape torch.Size([16, 8, 8])        |               45          |                21            |               207           |           23160
      shape torch.Size([32, 8, 8])        |               46          |                21            |               413           |           46400
      shape torch.Size([64, 8, 8])        |               46          |                21            |               824           |           93000
      shape torch.Size([128, 8, 8])       |               46          |                21            |              1645           |          185000
      shape torch.Size([512, 8, 8])       |               47          |                21            |              6574           |          742000
      shape torch.Size([1024, 8, 8])      |               49          |                21            |             13150           |         1481000
      shape torch.Size([1, 16, 16])       |               49          |                21            |                24           |            1460
      shape torch.Size([2, 16, 16])       |               49          |                21            |                46           |            2902
      shape torch.Size([4, 16, 16])       |               49          |                21            |                90           |            5800
      shape torch.Size([8, 16, 16])       |               49          |                21            |               177           |           11600
      shape torch.Size([16, 16, 16])      |               49          |                21            |               352           |           23150
      shape torch.Size([32, 16, 16])      |               49          |                21            |               703           |           46300
      shape torch.Size([64, 16, 16])      |               49          |                21            |              1404           |           92700
      shape torch.Size([128, 16, 16])     |               50          |                21            |              2807           |          185000
      shape torch.Size([512, 16, 16])     |               55          |                29            |             11220           |          741700
      shape torch.Size([1024, 16, 16])    |               64          |                42            |             22440           |         1480000
      shape torch.Size([1, 32, 32])       |               55          |                56            |                58           |            1460
      shape torch.Size([2, 32, 32])       |               55          |                57            |               114           |            2920
      shape torch.Size([4, 32, 32])       |               55          |                57            |               225           |            5830
      shape torch.Size([8, 32, 32])       |               55          |                61            |               449           |           11700
      shape torch.Size([16, 32, 32])      |               56          |                61            |               896           |           23300
      shape torch.Size([32, 32, 32])      |               56          |                62            |              1791           |           46600
      shape torch.Size([64, 32, 32])      |               56          |                63            |              3581           |           93100
      shape torch.Size([128, 32, 32])     |               58          |                66            |              7156           |          186000
      shape torch.Size([512, 32, 32])     |              100          |               194            |             28700           |          742400
      shape torch.Size([1024, 32, 32])    |              169          |               335            |             57620           |         1485000
      shape torch.Size([1, 64, 64])       |              224          |               101            |               132           |            1500
      shape torch.Size([2, 64, 64])       |              227          |               100            |               262           |            2951
      shape torch.Size([4, 64, 64])       |              229          |               101            |               523           |            5890
      shape torch.Size([8, 64, 64])       |              231          |               102            |              1040           |           12000
      shape torch.Size([16, 64, 64])      |              237          |               109            |              2088           |           23530
      shape torch.Size([32, 64, 64])      |              242          |               127            |              4171           |           46900
      shape torch.Size([64, 64, 64])      |              247          |               156            |              8330           |           95000
      shape torch.Size([128, 64, 64])     |              293          |               244            |             16710           |          189000
      shape torch.Size([512, 64, 64])     |              685          |              1180            |             67000           |          750900
      shape torch.Size([1024, 64, 64])    |             1300          |              2076            |            134000           |         1505000
      shape torch.Size([1, 128, 128])     |              490          |               309            |               298           |            1560
      shape torch.Size([2, 128, 128])     |              503          |               309            |               594           |            3120
      shape torch.Size([4, 128, 128])     |              515          |               312            |              1185           |            6230
      shape torch.Size([8, 128, 128])     |              523          |               317            |              2370           |           12500
      shape torch.Size([16, 128, 128])    |              547          |               336            |              4734           |           24890
      shape torch.Size([32, 128, 128])    |              596          |               472            |              9491           |           49800
      shape torch.Size([64, 128, 128])    |              700          |               741            |             19000           |          100000
      shape torch.Size([128, 128, 128])   |              930          |              1770            |             37990           |          199000
      shape torch.Size([512, 128, 128])   |             2810          |             11000            |            152000           |          797100
      shape torch.Size([1024, 128, 128])  |             5430          |             22430            |            303900           |         1595000
      shape torch.Size([1, 256, 256])     |             1120          |              1580            |               666           |            1890
      shape torch.Size([2, 256, 256])     |             1160          |              1574            |              1330           |            3784
      shape torch.Size([4, 256, 256])     |             1190          |              1580            |              2658           |            7570
      shape torch.Size([8, 256, 256])     |             1250          |              1613            |              5325           |           15100
      shape torch.Size([16, 256, 256])    |             1394          |              1880            |             10700           |           30260
      shape torch.Size([32, 256, 256])    |             1633          |              3360            |             21300           |           61000
      shape torch.Size([64, 256, 256])    |             2258          |              6730            |             42600           |          120000
      shape torch.Size([128, 256, 256])   |             3639          |             19200            |             85170           |          242200
      shape torch.Size([512, 256, 256])   |            12600          |             87200            |            340600           |          969000
      shape torch.Size([1024, 256, 256])  |            24530          |            176000            |            681300           |         1943000
      shape torch.Size([1, 512, 512])     |             2557          |              9117            |              1724           |            2577
      shape torch.Size([2, 512, 512])     |             2691          |              9209            |              3464           |            5200
      shape torch.Size([4, 512, 512])     |             2853          |              9860            |              6940           |           10000
      shape torch.Size([8, 512, 512])     |             3153          |             11000            |             13900           |           20570
      shape torch.Size([16, 512, 512])    |             3765          |             13000            |             27720           |           41360
      shape torch.Size([32, 512, 512])    |             5500          |             21400            |             55420           |           82000
      shape torch.Size([64, 512, 512])    |             8790          |             44000            |            111000           |          165000
      shape torch.Size([128, 512, 512])   |            15300          |             98000            |            221700           |          329800
      shape torch.Size([512, 512, 512])   |            55400          |            424100            |            886600           |         1325000
      shape torch.Size([1024, 512, 512])  |           110000          |            856200            |           1773000           |         2691000
      shape torch.Size([1, 1024, 1024])   |            10350          |             69290            |              5020           |            5327
      shape torch.Size([2, 1024, 1024])   |            11200          |             74860            |             10040           |           11000
      shape torch.Size([4, 1024, 1024])   |            12200          |             78030            |             20080           |           21290
      shape torch.Size([8, 1024, 1024])   |            14000          |             81200            |             40160           |           42850
      shape torch.Size([16, 1024, 1024])  |            17700          |             96000            |             80300           |           85500
      shape torch.Size([32, 1024, 1024])  |            27740          |            150000            |            160700           |          171000
      shape torch.Size([64, 1024, 1024])  |            45940          |            233400            |            321200           |          344100
      shape torch.Size([1, 2048, 2048])   |            29860          |            579800            |             12920           |           13500
      shape torch.Size([2, 2048, 2048])   |            34000          |            585000            |             25840           |           26840
      shape torch.Size([4, 2048, 2048])   |            39770          |            593900            |             51670           |           54000
      shape torch.Size([8, 2048, 2048])   |            51720          |            632100            |            103000           |          109000
      shape torch.Size([16, 2048, 2048])  |            76900          |            845500            |            206600           |          218400
      shape torch.Size([32, 2048, 2048])  |           130000          |           1058000            |            413900           |          437300

Times are in microseconds (us).


```

</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. 
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    if n == 1024 and batch[0] >= 128:
        continue
    if n == 2048 and batch[0] >= 64:
        continue
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
This PR adds getrf_cublas to the functions considered in the heuristics
for `lu_factor`. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

<details>
<summary>
Benchmark Results all algorithms up to `n=2048`
</summary>

```
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
                                          |  lu_factor_magma_batched  |  lu_factor_cusolver_batched  |  lu_factor_cusolver_looped  |  lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
      shape torch.Size([1, 1, 1])         |               51          |                30            |                27           |            1390
      shape torch.Size([2, 1, 1])         |               42          |                20            |                26           |            2798
      shape torch.Size([4, 1, 1])         |               42          |                20            |                42           |            5589
      shape torch.Size([8, 1, 1])         |               42          |                20            |                72           |           11000
      shape torch.Size([16, 1, 1])        |               42          |                20            |               132           |           22400
      shape torch.Size([32, 1, 1])        |               42          |                20            |               253           |           44620
      shape torch.Size([64, 1, 1])        |               42          |                20            |               496           |           89200
      shape torch.Size([128, 1, 1])       |               42          |                20            |               980           |          180000
      shape torch.Size([512, 1, 1])       |               43          |                20            |              3868           |          714100
      shape torch.Size([1024, 1, 1])      |               44          |                20            |              7800           |         1430000
      shape torch.Size([1, 2, 2])         |               43          |                21            |                19           |            1400
      shape torch.Size([2, 2, 2])         |               42          |                21            |                27           |            2898
      shape torch.Size([4, 2, 2])         |               43          |                21            |                42           |            5800
      shape torch.Size([8, 2, 2])         |               43          |                21            |                73           |           11600
      shape torch.Size([16, 2, 2])        |               43          |                21            |               133           |           23170
      shape torch.Size([32, 2, 2])        |               43          |                21            |               254           |           46290
      shape torch.Size([64, 2, 2])        |               43          |                21            |               500           |           94000
      shape torch.Size([128, 2, 2])       |               43          |                21            |               980           |          190000
      shape torch.Size([512, 2, 2])       |               44          |                21            |              3860           |          741900
      shape torch.Size([1024, 2, 2])      |               44          |                21            |              7640           |         1484000
      shape torch.Size([1, 8, 8])         |               45          |                21            |                19           |            1450
      shape torch.Size([2, 8, 8])         |               45          |                21            |                27           |            2917
      shape torch.Size([4, 8, 8])         |               45          |                21            |                53           |            5800
      shape torch.Size([8, 8, 8])         |               45          |                21            |               105           |           11580
      shape torch.Size([16, 8, 8])        |               45          |                21            |               207           |           23160
      shape torch.Size([32, 8, 8])        |               46          |                21            |               413           |           46400
      shape torch.Size([64, 8, 8])        |               46          |                21            |               824           |           93000
      shape torch.Size([128, 8, 8])       |               46          |                21            |              1645           |          185000
      shape torch.Size([512, 8, 8])       |               47          |                21            |              6574           |          742000
      shape torch.Size([1024, 8, 8])      |               49          |                21            |             13150           |         1481000
      shape torch.Size([1, 16, 16])       |               49          |                21            |                24           |            1460
      shape torch.Size([2, 16, 16])       |               49          |                21            |                46           |            2902
      shape torch.Size([4, 16, 16])       |               49          |                21            |                90           |            5800
      shape torch.Size([8, 16, 16])       |               49          |                21            |               177           |           11600
      shape torch.Size([16, 16, 16])      |               49          |                21            |               352           |           23150
      shape torch.Size([32, 16, 16])      |               49          |                21            |               703           |           46300
      shape torch.Size([64, 16, 16])      |               49          |                21            |              1404           |           92700
      shape torch.Size([128, 16, 16])     |               50          |                21            |              2807           |          185000
      shape torch.Size([512, 16, 16])     |               55          |                29            |             11220           |          741700
      shape torch.Size([1024, 16, 16])    |               64          |                42            |             22440           |         1480000
      shape torch.Size([1, 32, 32])       |               55          |                56            |                58           |            1460
      shape torch.Size([2, 32, 32])       |               55          |                57            |               114           |            2920
      shape torch.Size([4, 32, 32])       |               55          |                57            |               225           |            5830
      shape torch.Size([8, 32, 32])       |               55          |                61            |               449           |           11700
      shape torch.Size([16, 32, 32])      |               56          |                61            |               896           |           23300
      shape torch.Size([32, 32, 32])      |               56          |                62            |              1791           |           46600
      shape torch.Size([64, 32, 32])      |               56          |                63            |              3581           |           93100
      shape torch.Size([128, 32, 32])     |               58          |                66            |              7156           |          186000
      shape torch.Size([512, 32, 32])     |              100          |               194            |             28700           |          742400
      shape torch.Size([1024, 32, 32])    |              169          |               335            |             57620           |         1485000
      shape torch.Size([1, 64, 64])       |              224          |               101            |               132           |            1500
      shape torch.Size([2, 64, 64])       |              227          |               100            |               262           |            2951
      shape torch.Size([4, 64, 64])       |              229          |               101            |               523           |            5890
      shape torch.Size([8, 64, 64])       |              231          |               102            |              1040           |           12000
      shape torch.Size([16, 64, 64])      |              237          |               109            |              2088           |           23530
      shape torch.Size([32, 64, 64])      |              242          |               127            |              4171           |           46900
      shape torch.Size([64, 64, 64])      |              247          |               156            |              8330           |           95000
      shape torch.Size([128, 64, 64])     |              293          |               244            |             16710           |          189000
      shape torch.Size([512, 64, 64])     |              685          |              1180            |             67000           |          750900
      shape torch.Size([1024, 64, 64])    |             1300          |              2076            |            134000           |         1505000
      shape torch.Size([1, 128, 128])     |              490          |               309            |               298           |            1560
      shape torch.Size([2, 128, 128])     |              503          |               309            |               594           |            3120
      shape torch.Size([4, 128, 128])     |              515          |               312            |              1185           |            6230
      shape torch.Size([8, 128, 128])     |              523          |               317            |              2370           |           12500
      shape torch.Size([16, 128, 128])    |              547          |               336            |              4734           |           24890
      shape torch.Size([32, 128, 128])    |              596          |               472            |              9491           |           49800
      shape torch.Size([64, 128, 128])    |              700          |               741            |             19000           |          100000
      shape torch.Size([128, 128, 128])   |              930          |              1770            |             37990           |          199000
      shape torch.Size([512, 128, 128])   |             2810          |             11000            |            152000           |          797100
      shape torch.Size([1024, 128, 128])  |             5430          |             22430            |            303900           |         1595000
      shape torch.Size([1, 256, 256])     |             1120          |              1580            |               666           |            1890
      shape torch.Size([2, 256, 256])     |             1160          |              1574            |              1330           |            3784
      shape torch.Size([4, 256, 256])     |             1190          |              1580            |              2658           |            7570
      shape torch.Size([8, 256, 256])     |             1250          |              1613            |              5325           |           15100
      shape torch.Size([16, 256, 256])    |             1394          |              1880            |             10700           |           30260
      shape torch.Size([32, 256, 256])    |             1633          |              3360            |             21300           |           61000
      shape torch.Size([64, 256, 256])    |             2258          |              6730            |             42600           |          120000
      shape torch.Size([128, 256, 256])   |             3639          |             19200            |             85170           |          242200
      shape torch.Size([512, 256, 256])   |            12600          |             87200            |            340600           |          969000
      shape torch.Size([1024, 256, 256])  |            24530          |            176000            |            681300           |         1943000
      shape torch.Size([1, 512, 512])     |             2557          |              9117            |              1724           |            2577
      shape torch.Size([2, 512, 512])     |             2691          |              9209            |              3464           |            5200
      shape torch.Size([4, 512, 512])     |             2853          |              9860            |              6940           |           10000
      shape torch.Size([8, 512, 512])     |             3153          |             11000            |             13900           |           20570
      shape torch.Size([16, 512, 512])    |             3765          |             13000            |             27720           |           41360
      shape torch.Size([32, 512, 512])    |             5500          |             21400            |             55420           |           82000
      shape torch.Size([64, 512, 512])    |             8790          |             44000            |            111000           |          165000
      shape torch.Size([128, 512, 512])   |            15300          |             98000            |            221700           |          329800
      shape torch.Size([512, 512, 512])   |            55400          |            424100            |            886600           |         1325000
      shape torch.Size([1024, 512, 512])  |           110000          |            856200            |           1773000           |         2691000
      shape torch.Size([1, 1024, 1024])   |            10350          |             69290            |              5020           |            5327
      shape torch.Size([2, 1024, 1024])   |            11200          |             74860            |             10040           |           11000
      shape torch.Size([4, 1024, 1024])   |            12200          |             78030            |             20080           |           21290
      shape torch.Size([8, 1024, 1024])   |            14000          |             81200            |             40160           |           42850
      shape torch.Size([16, 1024, 1024])  |            17700          |             96000            |             80300           |           85500
      shape torch.Size([32, 1024, 1024])  |            27740          |            150000            |            160700           |          171000
      shape torch.Size([64, 1024, 1024])  |            45940          |            233400            |            321200           |          344100
      shape torch.Size([1, 2048, 2048])   |            29860          |            579800            |             12920           |           13500
      shape torch.Size([2, 2048, 2048])   |            34000          |            585000            |             25840           |           26840
      shape torch.Size([4, 2048, 2048])   |            39770          |            593900            |             51670           |           54000
      shape torch.Size([8, 2048, 2048])   |            51720          |            632100            |            103000           |          109000
      shape torch.Size([16, 2048, 2048])  |            76900          |            845500            |            206600           |          218400
      shape torch.Size([32, 2048, 2048])  |           130000          |           1058000            |            413900           |          437300

Times are in microseconds (us).


```

</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. 
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    if n == 1024 and batch[0] >= 128:
        continue
    if n == 2048 and batch[0] >= 64:
        continue
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
This PR adds getrf_cublas to the functions considered in the heuristics
for `lu_factor`. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

<details>
<summary>
Benchmark Results all algorithms up to `n=2048`
</summary>

```
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
                                          |  lu_factor_magma_batched  |  lu_factor_cusolver_batched  |  lu_factor_cusolver_looped  |  lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
      shape torch.Size([1, 1, 1])         |               51          |                30            |                27           |            1390
      shape torch.Size([2, 1, 1])         |               42          |                20            |                26           |            2798
      shape torch.Size([4, 1, 1])         |               42          |                20            |                42           |            5589
      shape torch.Size([8, 1, 1])         |               42          |                20            |                72           |           11000
      shape torch.Size([16, 1, 1])        |               42          |                20            |               132           |           22400
      shape torch.Size([32, 1, 1])        |               42          |                20            |               253           |           44620
      shape torch.Size([64, 1, 1])        |               42          |                20            |               496           |           89200
      shape torch.Size([128, 1, 1])       |               42          |                20            |               980           |          180000
      shape torch.Size([512, 1, 1])       |               43          |                20            |              3868           |          714100
      shape torch.Size([1024, 1, 1])      |               44          |                20            |              7800           |         1430000
      shape torch.Size([1, 2, 2])         |               43          |                21            |                19           |            1400
      shape torch.Size([2, 2, 2])         |               42          |                21            |                27           |            2898
      shape torch.Size([4, 2, 2])         |               43          |                21            |                42           |            5800
      shape torch.Size([8, 2, 2])         |               43          |                21            |                73           |           11600
      shape torch.Size([16, 2, 2])        |               43          |                21            |               133           |           23170
      shape torch.Size([32, 2, 2])        |               43          |                21            |               254           |           46290
      shape torch.Size([64, 2, 2])        |               43          |                21            |               500           |           94000
      shape torch.Size([128, 2, 2])       |               43          |                21            |               980           |          190000
      shape torch.Size([512, 2, 2])       |               44          |                21            |              3860           |          741900
      shape torch.Size([1024, 2, 2])      |               44          |                21            |              7640           |         1484000
      shape torch.Size([1, 8, 8])         |               45          |                21            |                19           |            1450
      shape torch.Size([2, 8, 8])         |               45          |                21            |                27           |            2917
      shape torch.Size([4, 8, 8])         |               45          |                21            |                53           |            5800
      shape torch.Size([8, 8, 8])         |               45          |                21            |               105           |           11580
      shape torch.Size([16, 8, 8])        |               45          |                21            |               207           |           23160
      shape torch.Size([32, 8, 8])        |               46          |                21            |               413           |           46400
      shape torch.Size([64, 8, 8])        |               46          |                21            |               824           |           93000
      shape torch.Size([128, 8, 8])       |               46          |                21            |              1645           |          185000
      shape torch.Size([512, 8, 8])       |               47          |                21            |              6574           |          742000
      shape torch.Size([1024, 8, 8])      |               49          |                21            |             13150           |         1481000
      shape torch.Size([1, 16, 16])       |               49          |                21            |                24           |            1460
      shape torch.Size([2, 16, 16])       |               49          |                21            |                46           |            2902
      shape torch.Size([4, 16, 16])       |               49          |                21            |                90           |            5800
      shape torch.Size([8, 16, 16])       |               49          |                21            |               177           |           11600
      shape torch.Size([16, 16, 16])      |               49          |                21            |               352           |           23150
      shape torch.Size([32, 16, 16])      |               49          |                21            |               703           |           46300
      shape torch.Size([64, 16, 16])      |               49          |                21            |              1404           |           92700
      shape torch.Size([128, 16, 16])     |               50          |                21            |              2807           |          185000
      shape torch.Size([512, 16, 16])     |               55          |                29            |             11220           |          741700
      shape torch.Size([1024, 16, 16])    |               64          |                42            |             22440           |         1480000
      shape torch.Size([1, 32, 32])       |               55          |                56            |                58           |            1460
      shape torch.Size([2, 32, 32])       |               55          |                57            |               114           |            2920
      shape torch.Size([4, 32, 32])       |               55          |                57            |               225           |            5830
      shape torch.Size([8, 32, 32])       |               55          |                61            |               449           |           11700
      shape torch.Size([16, 32, 32])      |               56          |                61            |               896           |           23300
      shape torch.Size([32, 32, 32])      |               56          |                62            |              1791           |           46600
      shape torch.Size([64, 32, 32])      |               56          |                63            |              3581           |           93100
      shape torch.Size([128, 32, 32])     |               58          |                66            |              7156           |          186000
      shape torch.Size([512, 32, 32])     |              100          |               194            |             28700           |          742400
      shape torch.Size([1024, 32, 32])    |              169          |               335            |             57620           |         1485000
      shape torch.Size([1, 64, 64])       |              224          |               101            |               132           |            1500
      shape torch.Size([2, 64, 64])       |              227          |               100            |               262           |            2951
      shape torch.Size([4, 64, 64])       |              229          |               101            |               523           |            5890
      shape torch.Size([8, 64, 64])       |              231          |               102            |              1040           |           12000
      shape torch.Size([16, 64, 64])      |              237          |               109            |              2088           |           23530
      shape torch.Size([32, 64, 64])      |              242          |               127            |              4171           |           46900
      shape torch.Size([64, 64, 64])      |              247          |               156            |              8330           |           95000
      shape torch.Size([128, 64, 64])     |              293          |               244            |             16710           |          189000
      shape torch.Size([512, 64, 64])     |              685          |              1180            |             67000           |          750900
      shape torch.Size([1024, 64, 64])    |             1300          |              2076            |            134000           |         1505000
      shape torch.Size([1, 128, 128])     |              490          |               309            |               298           |            1560
      shape torch.Size([2, 128, 128])     |              503          |               309            |               594           |            3120
      shape torch.Size([4, 128, 128])     |              515          |               312            |              1185           |            6230
      shape torch.Size([8, 128, 128])     |              523          |               317            |              2370           |           12500
      shape torch.Size([16, 128, 128])    |              547          |               336            |              4734           |           24890
      shape torch.Size([32, 128, 128])    |              596          |               472            |              9491           |           49800
      shape torch.Size([64, 128, 128])    |              700          |               741            |             19000           |          100000
      shape torch.Size([128, 128, 128])   |              930          |              1770            |             37990           |          199000
      shape torch.Size([512, 128, 128])   |             2810          |             11000            |            152000           |          797100
      shape torch.Size([1024, 128, 128])  |             5430          |             22430            |            303900           |         1595000
      shape torch.Size([1, 256, 256])     |             1120          |              1580            |               666           |            1890
      shape torch.Size([2, 256, 256])     |             1160          |              1574            |              1330           |            3784
      shape torch.Size([4, 256, 256])     |             1190          |              1580            |              2658           |            7570
      shape torch.Size([8, 256, 256])     |             1250          |              1613            |              5325           |           15100
      shape torch.Size([16, 256, 256])    |             1394          |              1880            |             10700           |           30260
      shape torch.Size([32, 256, 256])    |             1633          |              3360            |             21300           |           61000
      shape torch.Size([64, 256, 256])    |             2258          |              6730            |             42600           |          120000
      shape torch.Size([128, 256, 256])   |             3639          |             19200            |             85170           |          242200
      shape torch.Size([512, 256, 256])   |            12600          |             87200            |            340600           |          969000
      shape torch.Size([1024, 256, 256])  |            24530          |            176000            |            681300           |         1943000
      shape torch.Size([1, 512, 512])     |             2557          |              9117            |              1724           |            2577
      shape torch.Size([2, 512, 512])     |             2691          |              9209            |              3464           |            5200
      shape torch.Size([4, 512, 512])     |             2853          |              9860            |              6940           |           10000
      shape torch.Size([8, 512, 512])     |             3153          |             11000            |             13900           |           20570
      shape torch.Size([16, 512, 512])    |             3765          |             13000            |             27720           |           41360
      shape torch.Size([32, 512, 512])    |             5500          |             21400            |             55420           |           82000
      shape torch.Size([64, 512, 512])    |             8790          |             44000            |            111000           |          165000
      shape torch.Size([128, 512, 512])   |            15300          |             98000            |            221700           |          329800
      shape torch.Size([512, 512, 512])   |            55400          |            424100            |            886600           |         1325000
      shape torch.Size([1024, 512, 512])  |           110000          |            856200            |           1773000           |         2691000
      shape torch.Size([1, 1024, 1024])   |            10350          |             69290            |              5020           |            5327
      shape torch.Size([2, 1024, 1024])   |            11200          |             74860            |             10040           |           11000
      shape torch.Size([4, 1024, 1024])   |            12200          |             78030            |             20080           |           21290
      shape torch.Size([8, 1024, 1024])   |            14000          |             81200            |             40160           |           42850
      shape torch.Size([16, 1024, 1024])  |            17700          |             96000            |             80300           |           85500
      shape torch.Size([32, 1024, 1024])  |            27740          |            150000            |            160700           |          171000
      shape torch.Size([64, 1024, 1024])  |            45940          |            233400            |            321200           |          344100
      shape torch.Size([1, 2048, 2048])   |            29860          |            579800            |             12920           |           13500
      shape torch.Size([2, 2048, 2048])   |            34000          |            585000            |             25840           |           26840
      shape torch.Size([4, 2048, 2048])   |            39770          |            593900            |             51670           |           54000
      shape torch.Size([8, 2048, 2048])   |            51720          |            632100            |            103000           |          109000
      shape torch.Size([16, 2048, 2048])  |            76900          |            845500            |            206600           |          218400
      shape torch.Size([32, 2048, 2048])  |           130000          |           1058000            |            413900           |          437300

Times are in microseconds (us).


```

</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. 
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    if n == 1024 and batch[0] >= 128:
        continue
    if n == 2048 and batch[0] >= 64:
        continue
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
lezcano added a commit that referenced this pull request May 25, 2022
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve.

ghstack-source-id: 72febec
Pull Request resolved: #73878
This PR adds getrf_cublas to the functions considered in the heuristics
for `lu_factor`. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

<details>
<summary>
Benchmark Results all algorithms up to `n=2048`
</summary>

```
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
                                          |  lu_factor_magma_batched  |  lu_factor_cusolver_batched  |  lu_factor_cusolver_looped  |  lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
      shape torch.Size([1, 1, 1])         |               51          |                30            |                27           |            1390
      shape torch.Size([2, 1, 1])         |               42          |                20            |                26           |            2798
      shape torch.Size([4, 1, 1])         |               42          |                20            |                42           |            5589
      shape torch.Size([8, 1, 1])         |               42          |                20            |                72           |           11000
      shape torch.Size([16, 1, 1])        |               42          |                20            |               132           |           22400
      shape torch.Size([32, 1, 1])        |               42          |                20            |               253           |           44620
      shape torch.Size([64, 1, 1])        |               42          |                20            |               496           |           89200
      shape torch.Size([128, 1, 1])       |               42          |                20            |               980           |          180000
      shape torch.Size([512, 1, 1])       |               43          |                20            |              3868           |          714100
      shape torch.Size([1024, 1, 1])      |               44          |                20            |              7800           |         1430000
      shape torch.Size([1, 2, 2])         |               43          |                21            |                19           |            1400
      shape torch.Size([2, 2, 2])         |               42          |                21            |                27           |            2898
      shape torch.Size([4, 2, 2])         |               43          |                21            |                42           |            5800
      shape torch.Size([8, 2, 2])         |               43          |                21            |                73           |           11600
      shape torch.Size([16, 2, 2])        |               43          |                21            |               133           |           23170
      shape torch.Size([32, 2, 2])        |               43          |                21            |               254           |           46290
      shape torch.Size([64, 2, 2])        |               43          |                21            |               500           |           94000
      shape torch.Size([128, 2, 2])       |               43          |                21            |               980           |          190000
      shape torch.Size([512, 2, 2])       |               44          |                21            |              3860           |          741900
      shape torch.Size([1024, 2, 2])      |               44          |                21            |              7640           |         1484000
      shape torch.Size([1, 8, 8])         |               45          |                21            |                19           |            1450
      shape torch.Size([2, 8, 8])         |               45          |                21            |                27           |            2917
      shape torch.Size([4, 8, 8])         |               45          |                21            |                53           |            5800
      shape torch.Size([8, 8, 8])         |               45          |                21            |               105           |           11580
      shape torch.Size([16, 8, 8])        |               45          |                21            |               207           |           23160
      shape torch.Size([32, 8, 8])        |               46          |                21            |               413           |           46400
      shape torch.Size([64, 8, 8])        |               46          |                21            |               824           |           93000
      shape torch.Size([128, 8, 8])       |               46          |                21            |              1645           |          185000
      shape torch.Size([512, 8, 8])       |               47          |                21            |              6574           |          742000
      shape torch.Size([1024, 8, 8])      |               49          |                21            |             13150           |         1481000
      shape torch.Size([1, 16, 16])       |               49          |                21            |                24           |            1460
      shape torch.Size([2, 16, 16])       |               49          |                21            |                46           |            2902
      shape torch.Size([4, 16, 16])       |               49          |                21            |                90           |            5800
      shape torch.Size([8, 16, 16])       |               49          |                21            |               177           |           11600
      shape torch.Size([16, 16, 16])      |               49          |                21            |               352           |           23150
      shape torch.Size([32, 16, 16])      |               49          |                21            |               703           |           46300
      shape torch.Size([64, 16, 16])      |               49          |                21            |              1404           |           92700
      shape torch.Size([128, 16, 16])     |               50          |                21            |              2807           |          185000
      shape torch.Size([512, 16, 16])     |               55          |                29            |             11220           |          741700
      shape torch.Size([1024, 16, 16])    |               64          |                42            |             22440           |         1480000
      shape torch.Size([1, 32, 32])       |               55          |                56            |                58           |            1460
      shape torch.Size([2, 32, 32])       |               55          |                57            |               114           |            2920
      shape torch.Size([4, 32, 32])       |               55          |                57            |               225           |            5830
      shape torch.Size([8, 32, 32])       |               55          |                61            |               449           |           11700
      shape torch.Size([16, 32, 32])      |               56          |                61            |               896           |           23300
      shape torch.Size([32, 32, 32])      |               56          |                62            |              1791           |           46600
      shape torch.Size([64, 32, 32])      |               56          |                63            |              3581           |           93100
      shape torch.Size([128, 32, 32])     |               58          |                66            |              7156           |          186000
      shape torch.Size([512, 32, 32])     |              100          |               194            |             28700           |          742400
      shape torch.Size([1024, 32, 32])    |              169          |               335            |             57620           |         1485000
      shape torch.Size([1, 64, 64])       |              224          |               101            |               132           |            1500
      shape torch.Size([2, 64, 64])       |              227          |               100            |               262           |            2951
      shape torch.Size([4, 64, 64])       |              229          |               101            |               523           |            5890
      shape torch.Size([8, 64, 64])       |              231          |               102            |              1040           |           12000
      shape torch.Size([16, 64, 64])      |              237          |               109            |              2088           |           23530
      shape torch.Size([32, 64, 64])      |              242          |               127            |              4171           |           46900
      shape torch.Size([64, 64, 64])      |              247          |               156            |              8330           |           95000
      shape torch.Size([128, 64, 64])     |              293          |               244            |             16710           |          189000
      shape torch.Size([512, 64, 64])     |              685          |              1180            |             67000           |          750900
      shape torch.Size([1024, 64, 64])    |             1300          |              2076            |            134000           |         1505000
      shape torch.Size([1, 128, 128])     |              490          |               309            |               298           |            1560
      shape torch.Size([2, 128, 128])     |              503          |               309            |               594           |            3120
      shape torch.Size([4, 128, 128])     |              515          |               312            |              1185           |            6230
      shape torch.Size([8, 128, 128])     |              523          |               317            |              2370           |           12500
      shape torch.Size([16, 128, 128])    |              547          |               336            |              4734           |           24890
      shape torch.Size([32, 128, 128])    |              596          |               472            |              9491           |           49800
      shape torch.Size([64, 128, 128])    |              700          |               741            |             19000           |          100000
      shape torch.Size([128, 128, 128])   |              930          |              1770            |             37990           |          199000
      shape torch.Size([512, 128, 128])   |             2810          |             11000            |            152000           |          797100
      shape torch.Size([1024, 128, 128])  |             5430          |             22430            |            303900           |         1595000
      shape torch.Size([1, 256, 256])     |             1120          |              1580            |               666           |            1890
      shape torch.Size([2, 256, 256])     |             1160          |              1574            |              1330           |            3784
      shape torch.Size([4, 256, 256])     |             1190          |              1580            |              2658           |            7570
      shape torch.Size([8, 256, 256])     |             1250          |              1613            |              5325           |           15100
      shape torch.Size([16, 256, 256])    |             1394          |              1880            |             10700           |           30260
      shape torch.Size([32, 256, 256])    |             1633          |              3360            |             21300           |           61000
      shape torch.Size([64, 256, 256])    |             2258          |              6730            |             42600           |          120000
      shape torch.Size([128, 256, 256])   |             3639          |             19200            |             85170           |          242200
      shape torch.Size([512, 256, 256])   |            12600          |             87200            |            340600           |          969000
      shape torch.Size([1024, 256, 256])  |            24530          |            176000            |            681300           |         1943000
      shape torch.Size([1, 512, 512])     |             2557          |              9117            |              1724           |            2577
      shape torch.Size([2, 512, 512])     |             2691          |              9209            |              3464           |            5200
      shape torch.Size([4, 512, 512])     |             2853          |              9860            |              6940           |           10000
      shape torch.Size([8, 512, 512])     |             3153          |             11000            |             13900           |           20570
      shape torch.Size([16, 512, 512])    |             3765          |             13000            |             27720           |           41360
      shape torch.Size([32, 512, 512])    |             5500          |             21400            |             55420           |           82000
      shape torch.Size([64, 512, 512])    |             8790          |             44000            |            111000           |          165000
      shape torch.Size([128, 512, 512])   |            15300          |             98000            |            221700           |          329800
      shape torch.Size([512, 512, 512])   |            55400          |            424100            |            886600           |         1325000
      shape torch.Size([1024, 512, 512])  |           110000          |            856200            |           1773000           |         2691000
      shape torch.Size([1, 1024, 1024])   |            10350          |             69290            |              5020           |            5327
      shape torch.Size([2, 1024, 1024])   |            11200          |             74860            |             10040           |           11000
      shape torch.Size([4, 1024, 1024])   |            12200          |             78030            |             20080           |           21290
      shape torch.Size([8, 1024, 1024])   |            14000          |             81200            |             40160           |           42850
      shape torch.Size([16, 1024, 1024])  |            17700          |             96000            |             80300           |           85500
      shape torch.Size([32, 1024, 1024])  |            27740          |            150000            |            160700           |          171000
      shape torch.Size([64, 1024, 1024])  |            45940          |            233400            |            321200           |          344100
      shape torch.Size([1, 2048, 2048])   |            29860          |            579800            |             12920           |           13500
      shape torch.Size([2, 2048, 2048])   |            34000          |            585000            |             25840           |           26840
      shape torch.Size([4, 2048, 2048])   |            39770          |            593900            |             51670           |           54000
      shape torch.Size([8, 2048, 2048])   |            51720          |            632100            |            103000           |          109000
      shape torch.Size([16, 2048, 2048])  |            76900          |            845500            |            206600           |          218400
      shape torch.Size([32, 2048, 2048])  |           130000          |           1058000            |            413900           |          437300

Times are in microseconds (us).


```

</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. 
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    if n == 1024 and batch[0] >= 128:
        continue
    if n == 2048 and batch[0] >= 64:
        continue
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
lezcano added a commit that referenced this pull request Jun 7, 2022
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve.

ghstack-source-id: 776519d
Pull Request resolved: #73878
This PR adds getrf_cublas to the functions considered in the heuristics
for `lu_factor`. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

<details>
<summary>
Benchmark Results all algorithms up to `n=2048`
</summary>

```
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
                                          |  lu_factor_magma_batched  |  lu_factor_cusolver_batched  |  lu_factor_cusolver_looped  |  lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
      shape torch.Size([1, 1, 1])         |               51          |                30            |                27           |            1390
      shape torch.Size([2, 1, 1])         |               42          |                20            |                26           |            2798
      shape torch.Size([4, 1, 1])         |               42          |                20            |                42           |            5589
      shape torch.Size([8, 1, 1])         |               42          |                20            |                72           |           11000
      shape torch.Size([16, 1, 1])        |               42          |                20            |               132           |           22400
      shape torch.Size([32, 1, 1])        |               42          |                20            |               253           |           44620
      shape torch.Size([64, 1, 1])        |               42          |                20            |               496           |           89200
      shape torch.Size([128, 1, 1])       |               42          |                20            |               980           |          180000
      shape torch.Size([512, 1, 1])       |               43          |                20            |              3868           |          714100
      shape torch.Size([1024, 1, 1])      |               44          |                20            |              7800           |         1430000
      shape torch.Size([1, 2, 2])         |               43          |                21            |                19           |            1400
      shape torch.Size([2, 2, 2])         |               42          |                21            |                27           |            2898
      shape torch.Size([4, 2, 2])         |               43          |                21            |                42           |            5800
      shape torch.Size([8, 2, 2])         |               43          |                21            |                73           |           11600
      shape torch.Size([16, 2, 2])        |               43          |                21            |               133           |           23170
      shape torch.Size([32, 2, 2])        |               43          |                21            |               254           |           46290
      shape torch.Size([64, 2, 2])        |               43          |                21            |               500           |           94000
      shape torch.Size([128, 2, 2])       |               43          |                21            |               980           |          190000
      shape torch.Size([512, 2, 2])       |               44          |                21            |              3860           |          741900
      shape torch.Size([1024, 2, 2])      |               44          |                21            |              7640           |         1484000
      shape torch.Size([1, 8, 8])         |               45          |                21            |                19           |            1450
      shape torch.Size([2, 8, 8])         |               45          |                21            |                27           |            2917
      shape torch.Size([4, 8, 8])         |               45          |                21            |                53           |            5800
      shape torch.Size([8, 8, 8])         |               45          |                21            |               105           |           11580
      shape torch.Size([16, 8, 8])        |               45          |                21            |               207           |           23160
      shape torch.Size([32, 8, 8])        |               46          |                21            |               413           |           46400
      shape torch.Size([64, 8, 8])        |               46          |                21            |               824           |           93000
      shape torch.Size([128, 8, 8])       |               46          |                21            |              1645           |          185000
      shape torch.Size([512, 8, 8])       |               47          |                21            |              6574           |          742000
      shape torch.Size([1024, 8, 8])      |               49          |                21            |             13150           |         1481000
      shape torch.Size([1, 16, 16])       |               49          |                21            |                24           |            1460
      shape torch.Size([2, 16, 16])       |               49          |                21            |                46           |            2902
      shape torch.Size([4, 16, 16])       |               49          |                21            |                90           |            5800
      shape torch.Size([8, 16, 16])       |               49          |                21            |               177           |           11600
      shape torch.Size([16, 16, 16])      |               49          |                21            |               352           |           23150
      shape torch.Size([32, 16, 16])      |               49          |                21            |               703           |           46300
      shape torch.Size([64, 16, 16])      |               49          |                21            |              1404           |           92700
      shape torch.Size([128, 16, 16])     |               50          |                21            |              2807           |          185000
      shape torch.Size([512, 16, 16])     |               55          |                29            |             11220           |          741700
      shape torch.Size([1024, 16, 16])    |               64          |                42            |             22440           |         1480000
      shape torch.Size([1, 32, 32])       |               55          |                56            |                58           |            1460
      shape torch.Size([2, 32, 32])       |               55          |                57            |               114           |            2920
      shape torch.Size([4, 32, 32])       |               55          |                57            |               225           |            5830
      shape torch.Size([8, 32, 32])       |               55          |                61            |               449           |           11700
      shape torch.Size([16, 32, 32])      |               56          |                61            |               896           |           23300
      shape torch.Size([32, 32, 32])      |               56          |                62            |              1791           |           46600
      shape torch.Size([64, 32, 32])      |               56          |                63            |              3581           |           93100
      shape torch.Size([128, 32, 32])     |               58          |                66            |              7156           |          186000
      shape torch.Size([512, 32, 32])     |              100          |               194            |             28700           |          742400
      shape torch.Size([1024, 32, 32])    |              169          |               335            |             57620           |         1485000
      shape torch.Size([1, 64, 64])       |              224          |               101            |               132           |            1500
      shape torch.Size([2, 64, 64])       |              227          |               100            |               262           |            2951
      shape torch.Size([4, 64, 64])       |              229          |               101            |               523           |            5890
      shape torch.Size([8, 64, 64])       |              231          |               102            |              1040           |           12000
      shape torch.Size([16, 64, 64])      |              237          |               109            |              2088           |           23530
      shape torch.Size([32, 64, 64])      |              242          |               127            |              4171           |           46900
      shape torch.Size([64, 64, 64])      |              247          |               156            |              8330           |           95000
      shape torch.Size([128, 64, 64])     |              293          |               244            |             16710           |          189000
      shape torch.Size([512, 64, 64])     |              685          |              1180            |             67000           |          750900
      shape torch.Size([1024, 64, 64])    |             1300          |              2076            |            134000           |         1505000
      shape torch.Size([1, 128, 128])     |              490          |               309            |               298           |            1560
      shape torch.Size([2, 128, 128])     |              503          |               309            |               594           |            3120
      shape torch.Size([4, 128, 128])     |              515          |               312            |              1185           |            6230
      shape torch.Size([8, 128, 128])     |              523          |               317            |              2370           |           12500
      shape torch.Size([16, 128, 128])    |              547          |               336            |              4734           |           24890
      shape torch.Size([32, 128, 128])    |              596          |               472            |              9491           |           49800
      shape torch.Size([64, 128, 128])    |              700          |               741            |             19000           |          100000
      shape torch.Size([128, 128, 128])   |              930          |              1770            |             37990           |          199000
      shape torch.Size([512, 128, 128])   |             2810          |             11000            |            152000           |          797100
      shape torch.Size([1024, 128, 128])  |             5430          |             22430            |            303900           |         1595000
      shape torch.Size([1, 256, 256])     |             1120          |              1580            |               666           |            1890
      shape torch.Size([2, 256, 256])     |             1160          |              1574            |              1330           |            3784
      shape torch.Size([4, 256, 256])     |             1190          |              1580            |              2658           |            7570
      shape torch.Size([8, 256, 256])     |             1250          |              1613            |              5325           |           15100
      shape torch.Size([16, 256, 256])    |             1394          |              1880            |             10700           |           30260
      shape torch.Size([32, 256, 256])    |             1633          |              3360            |             21300           |           61000
      shape torch.Size([64, 256, 256])    |             2258          |              6730            |             42600           |          120000
      shape torch.Size([128, 256, 256])   |             3639          |             19200            |             85170           |          242200
      shape torch.Size([512, 256, 256])   |            12600          |             87200            |            340600           |          969000
      shape torch.Size([1024, 256, 256])  |            24530          |            176000            |            681300           |         1943000
      shape torch.Size([1, 512, 512])     |             2557          |              9117            |              1724           |            2577
      shape torch.Size([2, 512, 512])     |             2691          |              9209            |              3464           |            5200
      shape torch.Size([4, 512, 512])     |             2853          |              9860            |              6940           |           10000
      shape torch.Size([8, 512, 512])     |             3153          |             11000            |             13900           |           20570
      shape torch.Size([16, 512, 512])    |             3765          |             13000            |             27720           |           41360
      shape torch.Size([32, 512, 512])    |             5500          |             21400            |             55420           |           82000
      shape torch.Size([64, 512, 512])    |             8790          |             44000            |            111000           |          165000
      shape torch.Size([128, 512, 512])   |            15300          |             98000            |            221700           |          329800
      shape torch.Size([512, 512, 512])   |            55400          |            424100            |            886600           |         1325000
      shape torch.Size([1024, 512, 512])  |           110000          |            856200            |           1773000           |         2691000
      shape torch.Size([1, 1024, 1024])   |            10350          |             69290            |              5020           |            5327
      shape torch.Size([2, 1024, 1024])   |            11200          |             74860            |             10040           |           11000
      shape torch.Size([4, 1024, 1024])   |            12200          |             78030            |             20080           |           21290
      shape torch.Size([8, 1024, 1024])   |            14000          |             81200            |             40160           |           42850
      shape torch.Size([16, 1024, 1024])  |            17700          |             96000            |             80300           |           85500
      shape torch.Size([32, 1024, 1024])  |            27740          |            150000            |            160700           |          171000
      shape torch.Size([64, 1024, 1024])  |            45940          |            233400            |            321200           |          344100
      shape torch.Size([1, 2048, 2048])   |            29860          |            579800            |             12920           |           13500
      shape torch.Size([2, 2048, 2048])   |            34000          |            585000            |             25840           |           26840
      shape torch.Size([4, 2048, 2048])   |            39770          |            593900            |             51670           |           54000
      shape torch.Size([8, 2048, 2048])   |            51720          |            632100            |            103000           |          109000
      shape torch.Size([16, 2048, 2048])  |            76900          |            845500            |            206600           |          218400
      shape torch.Size([32, 2048, 2048])  |           130000          |           1058000            |            413900           |          437300

Times are in microseconds (us).


```

</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. 
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    if n == 1024 and batch[0] >= 128:
        continue
    if n == 2048 and batch[0] >= 64:
        continue
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
lezcano added a commit that referenced this pull request Jun 8, 2022
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve.

ghstack-source-id: 631d7af
Pull Request resolved: #73878
This PR adds getrf_cublas to the functions considered in the heuristics
for `lu_factor`. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

<details>
<summary>
Benchmark Results all algorithms up to `n=2048`
</summary>

```
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
                                          |  lu_factor_magma_batched  |  lu_factor_cusolver_batched  |  lu_factor_cusolver_looped  |  lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
      shape torch.Size([1, 1, 1])         |               51          |                30            |                27           |            1390
      shape torch.Size([2, 1, 1])         |               42          |                20            |                26           |            2798
      shape torch.Size([4, 1, 1])         |               42          |                20            |                42           |            5589
      shape torch.Size([8, 1, 1])         |               42          |                20            |                72           |           11000
      shape torch.Size([16, 1, 1])        |               42          |                20            |               132           |           22400
      shape torch.Size([32, 1, 1])        |               42          |                20            |               253           |           44620
      shape torch.Size([64, 1, 1])        |               42          |                20            |               496           |           89200
      shape torch.Size([128, 1, 1])       |               42          |                20            |               980           |          180000
      shape torch.Size([512, 1, 1])       |               43          |                20            |              3868           |          714100
      shape torch.Size([1024, 1, 1])      |               44          |                20            |              7800           |         1430000
      shape torch.Size([1, 2, 2])         |               43          |                21            |                19           |            1400
      shape torch.Size([2, 2, 2])         |               42          |                21            |                27           |            2898
      shape torch.Size([4, 2, 2])         |               43          |                21            |                42           |            5800
      shape torch.Size([8, 2, 2])         |               43          |                21            |                73           |           11600
      shape torch.Size([16, 2, 2])        |               43          |                21            |               133           |           23170
      shape torch.Size([32, 2, 2])        |               43          |                21            |               254           |           46290
      shape torch.Size([64, 2, 2])        |               43          |                21            |               500           |           94000
      shape torch.Size([128, 2, 2])       |               43          |                21            |               980           |          190000
      shape torch.Size([512, 2, 2])       |               44          |                21            |              3860           |          741900
      shape torch.Size([1024, 2, 2])      |               44          |                21            |              7640           |         1484000
      shape torch.Size([1, 8, 8])         |               45          |                21            |                19           |            1450
      shape torch.Size([2, 8, 8])         |               45          |                21            |                27           |            2917
      shape torch.Size([4, 8, 8])         |               45          |                21            |                53           |            5800
      shape torch.Size([8, 8, 8])         |               45          |                21            |               105           |           11580
      shape torch.Size([16, 8, 8])        |               45          |                21            |               207           |           23160
      shape torch.Size([32, 8, 8])        |               46          |                21            |               413           |           46400
      shape torch.Size([64, 8, 8])        |               46          |                21            |               824           |           93000
      shape torch.Size([128, 8, 8])       |               46          |                21            |              1645           |          185000
      shape torch.Size([512, 8, 8])       |               47          |                21            |              6574           |          742000
      shape torch.Size([1024, 8, 8])      |               49          |                21            |             13150           |         1481000
      shape torch.Size([1, 16, 16])       |               49          |                21            |                24           |            1460
      shape torch.Size([2, 16, 16])       |               49          |                21            |                46           |            2902
      shape torch.Size([4, 16, 16])       |               49          |                21            |                90           |            5800
      shape torch.Size([8, 16, 16])       |               49          |                21            |               177           |           11600
      shape torch.Size([16, 16, 16])      |               49          |                21            |               352           |           23150
      shape torch.Size([32, 16, 16])      |               49          |                21            |               703           |           46300
      shape torch.Size([64, 16, 16])      |               49          |                21            |              1404           |           92700
      shape torch.Size([128, 16, 16])     |               50          |                21            |              2807           |          185000
      shape torch.Size([512, 16, 16])     |               55          |                29            |             11220           |          741700
      shape torch.Size([1024, 16, 16])    |               64          |                42            |             22440           |         1480000
      shape torch.Size([1, 32, 32])       |               55          |                56            |                58           |            1460
      shape torch.Size([2, 32, 32])       |               55          |                57            |               114           |            2920
      shape torch.Size([4, 32, 32])       |               55          |                57            |               225           |            5830
      shape torch.Size([8, 32, 32])       |               55          |                61            |               449           |           11700
      shape torch.Size([16, 32, 32])      |               56          |                61            |               896           |           23300
      shape torch.Size([32, 32, 32])      |               56          |                62            |              1791           |           46600
      shape torch.Size([64, 32, 32])      |               56          |                63            |              3581           |           93100
      shape torch.Size([128, 32, 32])     |               58          |                66            |              7156           |          186000
      shape torch.Size([512, 32, 32])     |              100          |               194            |             28700           |          742400
      shape torch.Size([1024, 32, 32])    |              169          |               335            |             57620           |         1485000
      shape torch.Size([1, 64, 64])       |              224          |               101            |               132           |            1500
      shape torch.Size([2, 64, 64])       |              227          |               100            |               262           |            2951
      shape torch.Size([4, 64, 64])       |              229          |               101            |               523           |            5890
      shape torch.Size([8, 64, 64])       |              231          |               102            |              1040           |           12000
      shape torch.Size([16, 64, 64])      |              237          |               109            |              2088           |           23530
      shape torch.Size([32, 64, 64])      |              242          |               127            |              4171           |           46900
      shape torch.Size([64, 64, 64])      |              247          |               156            |              8330           |           95000
      shape torch.Size([128, 64, 64])     |              293          |               244            |             16710           |          189000
      shape torch.Size([512, 64, 64])     |              685          |              1180            |             67000           |          750900
      shape torch.Size([1024, 64, 64])    |             1300          |              2076            |            134000           |         1505000
      shape torch.Size([1, 128, 128])     |              490          |               309            |               298           |            1560
      shape torch.Size([2, 128, 128])     |              503          |               309            |               594           |            3120
      shape torch.Size([4, 128, 128])     |              515          |               312            |              1185           |            6230
      shape torch.Size([8, 128, 128])     |              523          |               317            |              2370           |           12500
      shape torch.Size([16, 128, 128])    |              547          |               336            |              4734           |           24890
      shape torch.Size([32, 128, 128])    |              596          |               472            |              9491           |           49800
      shape torch.Size([64, 128, 128])    |              700          |               741            |             19000           |          100000
      shape torch.Size([128, 128, 128])   |              930          |              1770            |             37990           |          199000
      shape torch.Size([512, 128, 128])   |             2810          |             11000            |            152000           |          797100
      shape torch.Size([1024, 128, 128])  |             5430          |             22430            |            303900           |         1595000
      shape torch.Size([1, 256, 256])     |             1120          |              1580            |               666           |            1890
      shape torch.Size([2, 256, 256])     |             1160          |              1574            |              1330           |            3784
      shape torch.Size([4, 256, 256])     |             1190          |              1580            |              2658           |            7570
      shape torch.Size([8, 256, 256])     |             1250          |              1613            |              5325           |           15100
      shape torch.Size([16, 256, 256])    |             1394          |              1880            |             10700           |           30260
      shape torch.Size([32, 256, 256])    |             1633          |              3360            |             21300           |           61000
      shape torch.Size([64, 256, 256])    |             2258          |              6730            |             42600           |          120000
      shape torch.Size([128, 256, 256])   |             3639          |             19200            |             85170           |          242200
      shape torch.Size([512, 256, 256])   |            12600          |             87200            |            340600           |          969000
      shape torch.Size([1024, 256, 256])  |            24530          |            176000            |            681300           |         1943000
      shape torch.Size([1, 512, 512])     |             2557          |              9117            |              1724           |            2577
      shape torch.Size([2, 512, 512])     |             2691          |              9209            |              3464           |            5200
      shape torch.Size([4, 512, 512])     |             2853          |              9860            |              6940           |           10000
      shape torch.Size([8, 512, 512])     |             3153          |             11000            |             13900           |           20570
      shape torch.Size([16, 512, 512])    |             3765          |             13000            |             27720           |           41360
      shape torch.Size([32, 512, 512])    |             5500          |             21400            |             55420           |           82000
      shape torch.Size([64, 512, 512])    |             8790          |             44000            |            111000           |          165000
      shape torch.Size([128, 512, 512])   |            15300          |             98000            |            221700           |          329800
      shape torch.Size([512, 512, 512])   |            55400          |            424100            |            886600           |         1325000
      shape torch.Size([1024, 512, 512])  |           110000          |            856200            |           1773000           |         2691000
      shape torch.Size([1, 1024, 1024])   |            10350          |             69290            |              5020           |            5327
      shape torch.Size([2, 1024, 1024])   |            11200          |             74860            |             10040           |           11000
      shape torch.Size([4, 1024, 1024])   |            12200          |             78030            |             20080           |           21290
      shape torch.Size([8, 1024, 1024])   |            14000          |             81200            |             40160           |           42850
      shape torch.Size([16, 1024, 1024])  |            17700          |             96000            |             80300           |           85500
      shape torch.Size([32, 1024, 1024])  |            27740          |            150000            |            160700           |          171000
      shape torch.Size([64, 1024, 1024])  |            45940          |            233400            |            321200           |          344100
      shape torch.Size([1, 2048, 2048])   |            29860          |            579800            |             12920           |           13500
      shape torch.Size([2, 2048, 2048])   |            34000          |            585000            |             25840           |           26840
      shape torch.Size([4, 2048, 2048])   |            39770          |            593900            |             51670           |           54000
      shape torch.Size([8, 2048, 2048])   |            51720          |            632100            |            103000           |          109000
      shape torch.Size([16, 2048, 2048])  |            76900          |            845500            |            206600           |          218400
      shape torch.Size([32, 2048, 2048])  |           130000          |           1058000            |            413900           |          437300

Times are in microseconds (us).


```

</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. 
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    if n == 1024 and batch[0] >= 128:
        continue
    if n == 2048 and batch[0] >= 64:
        continue
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
lezcano added a commit that referenced this pull request Jun 9, 2022
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve.

ghstack-source-id: c2138db
Pull Request resolved: #73878
Copy link
Copy Markdown
Collaborator

@mruberry mruberry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stamped!

@lezcano
Copy link
Copy Markdown
Collaborator Author

lezcano commented Jun 10, 2022

@pytorchbot merge

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge failed due to Command git -C /home/runner/actions-runner/_work/pytorch/pytorch cherry-pick -x a7a8b2ecc09d3e8cdad8dcb6f3c37f83f9387218 returned non-zero exit code 1

Auto-merging aten/src/ATen/native/BatchLinearAlgebra.cpp
Auto-merging aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp
Auto-merging aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h
Auto-merging test/test_linalg.py
CONFLICT (content): Merge conflict in test/test_linalg.py
error: could not apply a7a8b2ecc0... Update and improve the heuristics for linalg.lu_solve
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".

Raised by https://github.com/pytorch/pytorch/actions/runs/2476746342

This PR adds getrf_cublas to the functions considered in the heuristics
for `lu_factor`. It also updates the heuristics of the function.

## Benchmark

I'm omitting form the benchmarks the looped versions of the functions as they are much slower than the non-looped ones. The only exception to this is cusolver's looped variant, which is faster when applied to a batch of size one.

<details>
<summary>
Benchmark Results
</summary>

```
[------------------------------------------------- linalg.lu_factor CUDA -------------------------------------------------]                                                                                          
                                          |  lu_factor_heuristic  |  lu_factor_magma_batched  |  lu_factor_cusolver_batched                                                                                          
1 threads: ----------------------------------------------------------------------------------------------------------------                                                                                          
      shape torch.Size([1, 1, 1])         |            26         |              47           |                26                                                                                                    
      shape torch.Size([2, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([4, 1, 1])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 1, 1])         |            20         |              38           |                18                                                                                                    
      shape torch.Size([16, 1, 1])        |            20         |              38           |                17                                                                                                    
      shape torch.Size([32, 1, 1])        |            18         |              38           |                17                                                                                                    
      shape torch.Size([64, 1, 1])        |            18         |              39           |                17                                                                                                    
      shape torch.Size([128, 1, 1])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 1, 1])       |            18         |              39           |                18                                                                                                    
      shape torch.Size([1024, 1, 1])      |            18         |              40           |                18                                                                                                    
      shape torch.Size([1, 2, 2])         |            18         |              38           |                17                                                                                                    
      shape torch.Size([2, 2, 2])         |            17         |              37           |                17                                                                                                    
      shape torch.Size([4, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([8, 2, 2])         |            17         |              38           |                17                                                                                                    
      shape torch.Size([16, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([32, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([64, 2, 2])        |            17         |              38           |                17                                                                                                    
      shape torch.Size([128, 2, 2])       |            17         |              38           |                17                                                                                                    
      shape torch.Size([512, 2, 2])       |            17         |              39           |                17                                                                                                    
      shape torch.Size([1024, 2, 2])      |            17         |              40           |                17                                                                                                    
      shape torch.Size([1, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([2, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([4, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([8, 8, 8])         |            17         |              40           |                17                                                                                                    
      shape torch.Size([16, 8, 8])        |            17         |              41           |                17                                                                                                    
      shape torch.Size([32, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([64, 8, 8])        |            17         |              40           |                17                                                                                                    
      shape torch.Size([128, 8, 8])       |            17         |              40           |                17                                                                                                    
      shape torch.Size([512, 8, 8])       |            17         |              42           |                17                                                                                                    
      shape torch.Size([1024, 8, 8])      |            17         |              44           |                17                                                                                                    
      shape torch.Size([1, 16, 16])       |            24         |              44           |                18                                                                                                    
      shape torch.Size([2, 16, 16])       |            18         |              44           |                18                                                                                                    
      shape torch.Size([4, 16, 16])       |            18         |              45           |                18          
      shape torch.Size([8, 16, 16])       |            19         |              44           |                19          
      shape torch.Size([16, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([32, 16, 16])      |            20         |              45           |                20          
      shape torch.Size([64, 16, 16])      |            20         |              44           |                20          
      shape torch.Size([128, 16, 16])     |            20         |              45           |                20          
      shape torch.Size([512, 16, 16])     |            28         |              50           |                28          
      shape torch.Size([1024, 16, 16])    |            41         |              59           |                41          
      shape torch.Size([1, 32, 32])       |            58         |              50           |                56          
      shape torch.Size([2, 32, 32])       |            56         |              50           |                56          
      shape torch.Size([4, 32, 32])       |            56         |              50           |                57          
      shape torch.Size([8, 32, 32])       |            60         |              50           |                60          
      shape torch.Size([16, 32, 32])      |            60         |              51           |                60          
      shape torch.Size([32, 32, 32])      |           247         |              51           |                61          
      shape torch.Size([64, 32, 32])      |           233         |              51           |                63          
      shape torch.Size([128, 32, 32])     |           236         |              53           |                66          
      shape torch.Size([512, 32, 32])     |           268         |              97           |               193          
      shape torch.Size([1024, 32, 32])    |           317         |             167           |               333          
      shape torch.Size([1, 64, 64])       |           131         |             216           |                99          
      shape torch.Size([2, 64, 64])       |            99         |             220           |                99          
      shape torch.Size([4, 64, 64])       |            99         |             225           |               101          
      shape torch.Size([8, 64, 64])       |           101         |             225           |               102          
      shape torch.Size([16, 64, 64])      |           107         |             230           |               108          
      shape torch.Size([32, 64, 64])      |           440         |             235           |               126          
      shape torch.Size([64, 64, 64])      |           447         |             240           |               155          
      shape torch.Size([128, 64, 64])     |           470         |             289           |               240          
      shape torch.Size([512, 64, 64])     |           793         |             678           |              1180          
      shape torch.Size([1024, 64, 64])    |          1000         |            1300           |              2112          
      shape torch.Size([1, 128, 128])     |           296         |             482           |               309          
      shape torch.Size([2, 128, 128])     |           308         |             499           |               307          
      shape torch.Size([4, 128, 128])     |           311         |             510           |               310          
      shape torch.Size([8, 128, 128])     |           314         |             522           |               314          
      shape torch.Size([16, 128, 128])    |           334         |             541           |               334          
      shape torch.Size([32, 128, 128])    |           770         |             591           |               467          
      shape torch.Size([64, 128, 128])    |           860         |             694           |               733          
      shape torch.Size([128, 128, 128])   |          1040         |             925           |              1980          
      shape torch.Size([512, 128, 128])   |          2883         |            2809           |             11000          
      shape torch.Size([1024, 128, 128])  |          5421         |            5430           |             22360          
      shape torch.Size([1, 256, 256])     |          1310         |            1109           |              1556          
      shape torch.Size([2, 256, 256])     |          1360         |            1150           |              1560          
      shape torch.Size([4, 256, 256])     |          1390         |            1188           |              1569          
      shape torch.Size([8, 256, 256])     |          1440         |            1250           |              1604          
      shape torch.Size([16, 256, 256])    |          1550         |            1390           |              1850          
      shape torch.Size([32, 256, 256])    |          1750         |            1620           |              3332          
      shape torch.Size([64, 256, 256])    |          2327         |            2246           |              6700          
      shape torch.Size([128, 256, 256])   |          3697         |            3638           |             19100          
      shape torch.Size([512, 256, 256])   |         12530         |           12500           |             87300          
      shape torch.Size([1024, 256, 256])  |         24380         |           24420           |            176000          
```

</details>

<details>
<summary>
Benchmark Results all algorithms up to `n=2048`
</summary>

```
[----------------------------------------------------------------- linalg.lu_factor CUDA ------------------------------------------------------------------]
                                          |  lu_factor_magma_batched  |  lu_factor_cusolver_batched  |  lu_factor_cusolver_looped  |  lu_factor_magma_looped
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
      shape torch.Size([1, 1, 1])         |               51          |                30            |                27           |            1390
      shape torch.Size([2, 1, 1])         |               42          |                20            |                26           |            2798
      shape torch.Size([4, 1, 1])         |               42          |                20            |                42           |            5589
      shape torch.Size([8, 1, 1])         |               42          |                20            |                72           |           11000
      shape torch.Size([16, 1, 1])        |               42          |                20            |               132           |           22400
      shape torch.Size([32, 1, 1])        |               42          |                20            |               253           |           44620
      shape torch.Size([64, 1, 1])        |               42          |                20            |               496           |           89200
      shape torch.Size([128, 1, 1])       |               42          |                20            |               980           |          180000
      shape torch.Size([512, 1, 1])       |               43          |                20            |              3868           |          714100
      shape torch.Size([1024, 1, 1])      |               44          |                20            |              7800           |         1430000
      shape torch.Size([1, 2, 2])         |               43          |                21            |                19           |            1400
      shape torch.Size([2, 2, 2])         |               42          |                21            |                27           |            2898
      shape torch.Size([4, 2, 2])         |               43          |                21            |                42           |            5800
      shape torch.Size([8, 2, 2])         |               43          |                21            |                73           |           11600
      shape torch.Size([16, 2, 2])        |               43          |                21            |               133           |           23170
      shape torch.Size([32, 2, 2])        |               43          |                21            |               254           |           46290
      shape torch.Size([64, 2, 2])        |               43          |                21            |               500           |           94000
      shape torch.Size([128, 2, 2])       |               43          |                21            |               980           |          190000
      shape torch.Size([512, 2, 2])       |               44          |                21            |              3860           |          741900
      shape torch.Size([1024, 2, 2])      |               44          |                21            |              7640           |         1484000
      shape torch.Size([1, 8, 8])         |               45          |                21            |                19           |            1450
      shape torch.Size([2, 8, 8])         |               45          |                21            |                27           |            2917
      shape torch.Size([4, 8, 8])         |               45          |                21            |                53           |            5800
      shape torch.Size([8, 8, 8])         |               45          |                21            |               105           |           11580
      shape torch.Size([16, 8, 8])        |               45          |                21            |               207           |           23160
      shape torch.Size([32, 8, 8])        |               46          |                21            |               413           |           46400
      shape torch.Size([64, 8, 8])        |               46          |                21            |               824           |           93000
      shape torch.Size([128, 8, 8])       |               46          |                21            |              1645           |          185000
      shape torch.Size([512, 8, 8])       |               47          |                21            |              6574           |          742000
      shape torch.Size([1024, 8, 8])      |               49          |                21            |             13150           |         1481000
      shape torch.Size([1, 16, 16])       |               49          |                21            |                24           |            1460
      shape torch.Size([2, 16, 16])       |               49          |                21            |                46           |            2902
      shape torch.Size([4, 16, 16])       |               49          |                21            |                90           |            5800
      shape torch.Size([8, 16, 16])       |               49          |                21            |               177           |           11600
      shape torch.Size([16, 16, 16])      |               49          |                21            |               352           |           23150
      shape torch.Size([32, 16, 16])      |               49          |                21            |               703           |           46300
      shape torch.Size([64, 16, 16])      |               49          |                21            |              1404           |           92700
      shape torch.Size([128, 16, 16])     |               50          |                21            |              2807           |          185000
      shape torch.Size([512, 16, 16])     |               55          |                29            |             11220           |          741700
      shape torch.Size([1024, 16, 16])    |               64          |                42            |             22440           |         1480000
      shape torch.Size([1, 32, 32])       |               55          |                56            |                58           |            1460
      shape torch.Size([2, 32, 32])       |               55          |                57            |               114           |            2920
      shape torch.Size([4, 32, 32])       |               55          |                57            |               225           |            5830
      shape torch.Size([8, 32, 32])       |               55          |                61            |               449           |           11700
      shape torch.Size([16, 32, 32])      |               56          |                61            |               896           |           23300
      shape torch.Size([32, 32, 32])      |               56          |                62            |              1791           |           46600
      shape torch.Size([64, 32, 32])      |               56          |                63            |              3581           |           93100
      shape torch.Size([128, 32, 32])     |               58          |                66            |              7156           |          186000
      shape torch.Size([512, 32, 32])     |              100          |               194            |             28700           |          742400
      shape torch.Size([1024, 32, 32])    |              169          |               335            |             57620           |         1485000
      shape torch.Size([1, 64, 64])       |              224          |               101            |               132           |            1500
      shape torch.Size([2, 64, 64])       |              227          |               100            |               262           |            2951
      shape torch.Size([4, 64, 64])       |              229          |               101            |               523           |            5890
      shape torch.Size([8, 64, 64])       |              231          |               102            |              1040           |           12000
      shape torch.Size([16, 64, 64])      |              237          |               109            |              2088           |           23530
      shape torch.Size([32, 64, 64])      |              242          |               127            |              4171           |           46900
      shape torch.Size([64, 64, 64])      |              247          |               156            |              8330           |           95000
      shape torch.Size([128, 64, 64])     |              293          |               244            |             16710           |          189000
      shape torch.Size([512, 64, 64])     |              685          |              1180            |             67000           |          750900
      shape torch.Size([1024, 64, 64])    |             1300          |              2076            |            134000           |         1505000
      shape torch.Size([1, 128, 128])     |              490          |               309            |               298           |            1560
      shape torch.Size([2, 128, 128])     |              503          |               309            |               594           |            3120
      shape torch.Size([4, 128, 128])     |              515          |               312            |              1185           |            6230
      shape torch.Size([8, 128, 128])     |              523          |               317            |              2370           |           12500
      shape torch.Size([16, 128, 128])    |              547          |               336            |              4734           |           24890
      shape torch.Size([32, 128, 128])    |              596          |               472            |              9491           |           49800
      shape torch.Size([64, 128, 128])    |              700          |               741            |             19000           |          100000
      shape torch.Size([128, 128, 128])   |              930          |              1770            |             37990           |          199000
      shape torch.Size([512, 128, 128])   |             2810          |             11000            |            152000           |          797100
      shape torch.Size([1024, 128, 128])  |             5430          |             22430            |            303900           |         1595000
      shape torch.Size([1, 256, 256])     |             1120          |              1580            |               666           |            1890
      shape torch.Size([2, 256, 256])     |             1160          |              1574            |              1330           |            3784
      shape torch.Size([4, 256, 256])     |             1190          |              1580            |              2658           |            7570
      shape torch.Size([8, 256, 256])     |             1250          |              1613            |              5325           |           15100
      shape torch.Size([16, 256, 256])    |             1394          |              1880            |             10700           |           30260
      shape torch.Size([32, 256, 256])    |             1633          |              3360            |             21300           |           61000
      shape torch.Size([64, 256, 256])    |             2258          |              6730            |             42600           |          120000
      shape torch.Size([128, 256, 256])   |             3639          |             19200            |             85170           |          242200
      shape torch.Size([512, 256, 256])   |            12600          |             87200            |            340600           |          969000
      shape torch.Size([1024, 256, 256])  |            24530          |            176000            |            681300           |         1943000
      shape torch.Size([1, 512, 512])     |             2557          |              9117            |              1724           |            2577
      shape torch.Size([2, 512, 512])     |             2691          |              9209            |              3464           |            5200
      shape torch.Size([4, 512, 512])     |             2853          |              9860            |              6940           |           10000
      shape torch.Size([8, 512, 512])     |             3153          |             11000            |             13900           |           20570
      shape torch.Size([16, 512, 512])    |             3765          |             13000            |             27720           |           41360
      shape torch.Size([32, 512, 512])    |             5500          |             21400            |             55420           |           82000
      shape torch.Size([64, 512, 512])    |             8790          |             44000            |            111000           |          165000
      shape torch.Size([128, 512, 512])   |            15300          |             98000            |            221700           |          329800
      shape torch.Size([512, 512, 512])   |            55400          |            424100            |            886600           |         1325000
      shape torch.Size([1024, 512, 512])  |           110000          |            856200            |           1773000           |         2691000
      shape torch.Size([1, 1024, 1024])   |            10350          |             69290            |              5020           |            5327
      shape torch.Size([2, 1024, 1024])   |            11200          |             74860            |             10040           |           11000
      shape torch.Size([4, 1024, 1024])   |            12200          |             78030            |             20080           |           21290
      shape torch.Size([8, 1024, 1024])   |            14000          |             81200            |             40160           |           42850
      shape torch.Size([16, 1024, 1024])  |            17700          |             96000            |             80300           |           85500
      shape torch.Size([32, 1024, 1024])  |            27740          |            150000            |            160700           |          171000
      shape torch.Size([64, 1024, 1024])  |            45940          |            233400            |            321200           |          344100
      shape torch.Size([1, 2048, 2048])   |            29860          |            579800            |             12920           |           13500
      shape torch.Size([2, 2048, 2048])   |            34000          |            585000            |             25840           |           26840
      shape torch.Size([4, 2048, 2048])   |            39770          |            593900            |             51670           |           54000
      shape torch.Size([8, 2048, 2048])   |            51720          |            632100            |            103000           |          109000
      shape torch.Size([16, 2048, 2048])  |            76900          |            845500            |            206600           |          218400
      shape torch.Size([32, 2048, 2048])  |           130000          |           1058000            |            413900           |          437300

Times are in microseconds (us).


```

</details>
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I run the following script, changing the variable `name`. 
<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

benchmark_name = "linalg.lu_factor CUDA"
name = "magma_looped"
label = "lu_factor_{}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


for n, batch in itertools.product(shapes, batches):
    if n == 1024 and batch[0] >= 128:
        continue
    if n == 2048 and batch[0] >= 64:
        continue
    A = make_arg(batch + (n, n))
    print(A.shape)
    stmt = "torch.linalg.lu_factor_ex(A)"
    timer = Timer(stmt,
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"shape {A.shape}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{label}.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

See #72935 (comment) for the script to join the results.

[ghstack-poisoned]
lezcano added a commit that referenced this pull request Jun 10, 2022
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve.

ghstack-source-id: d3ed8a0
Pull Request resolved: #73878
@lezcano
Copy link
Copy Markdown
Collaborator Author

lezcano commented Jun 11, 2022

@pytorchbot merge

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@github-actions
Copy link
Copy Markdown
Contributor

Hey @lezcano.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@lezcano lezcano added the topic: not user facing topic category label Jun 11, 2022
@facebook-github-bot facebook-github-bot deleted the gh/Lezcano/52/head branch June 14, 2022 14:16
facebook-github-bot pushed a commit that referenced this pull request Jun 14, 2022
Summary:
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve.

Pull Request resolved: #73878

Approved by: https://github.com/nikitaved, https://github.com/IvanYashchuk, https://github.com/mruberry

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/9fc2518a8a24e07fe3a9c049ece6dc5bfc6d8641

Reviewed By: osalpekar

Differential Revision: D37089127

Pulled By: osalpekar

fbshipit-source-id: b9346418f87bd0c41c325f992e0f5675a85828fc
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
This PR adds getrf_cublas to the functions considered in the heuristics
for lu_solve.

Pull Request resolved: pytorch#73878

Approved by: https://github.com/nikitaved, https://github.com/IvanYashchuk, https://github.com/mruberry
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed Merged module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul module: performance Issues related to performance, either of kernel code or framework glue open source topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants