
Cusolver handle may decrease MAGMA performance on GPU #55122

@xwang233


🐛 Bug

There is some weird interaction between the cusolver handle and MAGMA performance. 🤔

I have uploaded related files to https://github.com/xwang233/code-snippet/tree/master/magma-cusolver-handle. All scripts were tested in the NGC pytorch 20.12 environment with CUDA 11.1.

NGC pytorch 20.12 uses MAGMA 2.5.2, while my local environment uses MAGMA 2.5.3.

What happened?

Currently in pytorch, the matrix inverse of a batched matrix with batch size <= 2 is dispatched to cusolver, while the matrix inverse of a batched matrix with batch size > 2 is dispatched to MAGMA. This heuristic was added to pick the better-performing backend for different matrix sizes.
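Roughly, the heuristic looks like the sketch below (the real dispatch lives in ATen's batched linear-algebra code and checks more than just the batch count; the helper name here is hypothetical):

```cpp
#include <cstdint>

enum class InverseBackend { Cusolver, Magma };

// Hypothetical helper mirroring the heuristic described above:
// batches of at most 2 matrices go to cusolver, larger batches go to MAGMA.
inline InverseBackend pick_inverse_backend(int64_t batch_count) {
  return batch_count <= 2 ? InverseBackend::Cusolver : InverseBackend::Magma;
}
```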

Cusolver handles are created when a cusolver call is needed. After the cusolver call, the handle is released into a handle pool instead of being destroyed. This is because cusolver handle creation and destruction take a very long time, so reusing handles gives better performance in most cases. The cublas, cusparse, and cudnn handles in pytorch follow the same "handle pool" implementation.
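As a rough sketch of that pooling pattern (this is not PyTorch's actual implementation, which is per-device and thread-safe; the class and method names here are made up for illustration):

```cpp
#include <vector>
#include <cusolverDn.h>

// Minimal sketch of a cusolver handle pool: creating/destroying a handle is
// expensive, so released handles are parked in a free list and handed back
// out on the next request instead of being destroyed.
class CusolverHandlePool {
 public:
  cusolverDnHandle_t acquire() {
    if (!free_.empty()) {            // fast path: reuse a released handle
      cusolverDnHandle_t h = free_.back();
      free_.pop_back();
      return h;
    }
    cusolverDnHandle_t h;
    cusolverDnCreate(&h);            // slow path: create a brand-new handle
    return h;
  }

  // "Releasing" only returns the handle to the pool; it stays alive.
  void release(cusolverDnHandle_t h) { free_.push_back(h); }

  ~CusolverHandlePool() {
    for (auto h : free_) cusolverDnDestroy(h);
  }

 private:
  std::vector<cusolverDnHandle_t> free_;
};
```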

Ideally, the two backends shouldn't affect each other's performance. However, it turns out that if someone runs matrix inverse in the following way:

magma inverse (A), then cusolver inverse (B), then magma inverse (C)

They will see that the MAGMA inverse performance in (C) is worse than that in (A), even for exactly the same matrix and exactly the same Python statements.

Run this script (https://github.com/xwang233/code-snippet/blob/master/magma-cusolver-handle/a.py) to profile torch.inverse, and see the results yourself! You may need pytorch version >= 1.7.0 to see it. Cusolver inverse was initially added in #42403.

Example output (per iteration time in ms, smaller is faster)

1.9.0a0+gitfbaad8c
time_magma_1 =  0.184
time_magma_2 =  0.411

The MAGMA inverse call after cusolver regresses by about 0.23 ms per iteration.

Why did this happen?

I've checked the performance with NVIDIA Nsight Systems, and here are the two different MAGMA inverse traces:

MAGMA inverse call in (A):

[Nsight Systems trace: "first magma"]

MAGMA inverse call in (C):

[Nsight Systems trace: "second magma"]

The cudaMalloc and cudaFree in the second MAGMA call (C) take longer than in the first MAGMA call (A). There are also extra OSRT library calls to ioctl. The sum of the 4 ioctl blocks gives ~210 us, which is on the same order of magnitude as the regression. (However, we don't know whether there is a causal relation.)

Where does this ioctl come from?

See the stack trace:

[Screenshot: ioctl stack trace]

It probably comes from a constructor call here:

magma_queue_create_from_cuda(
    device_id,
    at::cuda::getCurrentCUDAStream(),
    handle,
    at::cuda::getCurrentCUDASparseHandle(),
    &magma_queue_);

FYI, the third argument here (handle) is a cublas handle. So MAGMA queue creation takes a CUDA stream, a cublas handle, and a cusparse handle, and has nothing to do with the cusolver handle.
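For reference, here is a minimal standalone sketch of the same queue creation outside of PyTorch (assuming MAGMA 2.5.x and CUDA 11.x are available); a cusolver handle appears nowhere in it:

```cpp
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cusparse.h>
#include <magma_v2.h>

int main() {
  magma_init();

  // The three ingredients a MAGMA queue is built from: a CUDA stream,
  // a cublas handle, and a cusparse handle.
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cublasHandle_t cublas_handle;
  cublasCreate(&cublas_handle);
  cusparseHandle_t cusparse_handle;
  cusparseCreate(&cusparse_handle);

  magma_queue_t queue;
  magma_queue_create_from_cuda(/*device=*/0, stream, cublas_handle,
                               cusparse_handle, &queue);

  // ... launch MAGMA work on `queue` here ...

  magma_queue_destroy(queue);
  cusparseDestroy(cusparse_handle);
  cublasDestroy(cublas_handle);
  cudaStreamDestroy(stream);
  magma_finalize();
  return 0;
}
```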

Why is cusolver related to this MAGMA queue creation?

Good question, and a tough one. To save you some time on debugging, I wrote a short C++ repro program:

https://github.com/xwang233/code-snippet/blob/dcc269de7d1fba131ffd747ba15b1feb5f09a527/magma-cusolver-handle/cpp/main.cpp#L274-L283

The program simply runs and profiles MAGMA magma_sgetrf_batched + magma_sgetri_outofplace_batched calls in three different "environments":

  • normal (no cusolver handle exists)
  • after a cusolver handle is created (but never used in any cusolver function)
  • after that cusolver handle is destroyed

The three environments correspond to the (A), (C), (A?) cases above. Here are sample outputs:

magma time elapsed 0.063466 ms
magma time elapsed 0.272094 ms
magma time elapsed 0.063150 ms

This shows that, somehow, merely creating a cusolver handle can decrease MAGMA performance.
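For completeness, below is a condensed sketch of that three-environment experiment. It is not the linked main.cpp: the matrix sizes, the diagonal test input, and the single-shot event timing are simplified for illustration, so absolute numbers will differ from the repro output above.

```cpp
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cusparse.h>
#include <cusolverDn.h>
#include <magma_v2.h>

static magma_queue_t g_queue;

// Time one batched LU factorization + out-of-place inverse on
// `batch` n-by-n matrices, in milliseconds.
static float time_magma_inverse(int n, int batch) {
  magma_int_t ldda = n;
  float *dA, *dinvA;
  magma_int_t *dipiv, *dinfo;
  magma_smalloc(&dA,    (size_t)batch * n * ldda);
  magma_smalloc(&dinvA, (size_t)batch * n * ldda);
  magma_imalloc(&dipiv, (size_t)batch * n);
  magma_imalloc(&dinfo, batch);

  // Batched MAGMA routines take device arrays of per-matrix pointers.
  std::vector<float*> hA(batch), hInv(batch);
  std::vector<magma_int_t*> hPiv(batch);
  for (int i = 0; i < batch; ++i) {
    hA[i]   = dA    + (size_t)i * n * ldda;
    hInv[i] = dinvA + (size_t)i * n * ldda;
    hPiv[i] = dipiv + (size_t)i * n;
  }
  float **dA_array, **dinvA_array;
  magma_int_t **dipiv_array;
  cudaMalloc(&dA_array,    batch * sizeof(float*));
  cudaMalloc(&dinvA_array, batch * sizeof(float*));
  cudaMalloc(&dipiv_array, batch * sizeof(magma_int_t*));
  cudaMemcpy(dA_array,    hA.data(),   batch * sizeof(float*),       cudaMemcpyHostToDevice);
  cudaMemcpy(dinvA_array, hInv.data(), batch * sizeof(float*),       cudaMemcpyHostToDevice);
  cudaMemcpy(dipiv_array, hPiv.data(), batch * sizeof(magma_int_t*), cudaMemcpyHostToDevice);

  // Well-conditioned diagonal input so the factorization/inverse is trivial.
  std::vector<float> hMat((size_t)batch * n * ldda, 0.f);
  for (int i = 0; i < batch; ++i)
    for (int j = 0; j < n; ++j)
      hMat[(size_t)i * n * ldda + (size_t)j * ldda + j] = 1.f + j;
  magma_ssetvector((magma_int_t)hMat.size(), hMat.data(), 1, dA, 1, g_queue);

  cudaStream_t stream = magma_queue_get_cuda_stream(g_queue);
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start, stream);
  magma_sgetrf_batched(n, n, dA_array, ldda, dipiv_array, dinfo, batch, g_queue);
  magma_sgetri_outofplace_batched(n, dA_array, ldda, dipiv_array,
                                  dinvA_array, ldda, dinfo, batch, g_queue);
  cudaEventRecord(stop, stream);
  cudaEventSynchronize(stop);
  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  cudaFree(dA_array); cudaFree(dinvA_array); cudaFree(dipiv_array);
  magma_free(dA); magma_free(dinvA); magma_free(dipiv); magma_free(dinfo);
  return ms;
}

int main() {
  magma_init();
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cublasHandle_t blas;     cublasCreate(&blas);
  cusparseHandle_t sparse; cusparseCreate(&sparse);
  magma_queue_create_from_cuda(/*device=*/0, stream, blas, sparse, &g_queue);

  time_magma_inverse(64, 8);  // warm-up

  printf("magma time elapsed %f ms\n", time_magma_inverse(64, 8));  // (A) normal

  cusolverDnHandle_t solver;
  cusolverDnCreate(&solver);  // (B) create a cusolver handle, never use it
  printf("magma time elapsed %f ms\n", time_magma_inverse(64, 8));  // (C)

  cusolverDnDestroy(solver);
  printf("magma time elapsed %f ms\n", time_magma_inverse(64, 8));  // back to (A)?

  magma_queue_destroy(g_queue);
  cusparseDestroy(sparse);
  cublasDestroy(blas);
  cudaStreamDestroy(stream);
  magma_finalize();
  return 0;
}
```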

Which pytorch operators are affected?

  • inverse
  • cholesky
  • cholesky_solve (PR ongoing)
  • cholesky_inverse (PR ongoing, it calls cholesky_solve as its implementation)

There may be other linear algebra operators that use the MAGMA backend and are affected, but I haven't benchmarked them yet. Also note that step (B) doesn't have to be the same linear algebra call as (A) or (C); it could be any cusolver call that creates a cusolver handle.

This regression seems to be a fixed overhead of ~0.2 ms on my machine, which means it's mostly negligible for large matrix operations.

Additional context

See also #42403 #42666 #47953

cc @ngimel @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk @VitalyFedyunin @ptrblck

    Labels

    module: cuda, module: linear algebra, module: magma, module: performance, triaged
