
cuSOLVER path for LU factorization in CUDA.#56887

Closed
v0dro wants to merge 30 commits into master from ci-all/v0dro-cusolver-lu
Conversation

Contributor

@v0dro v0dro commented Apr 25, 2021

This PR adds cuSOLVER path for torch.lu.

Performance comparison results: #53879 (comment)

Code for reproducing performance results: #56887 (comment)

The following heuristics are used for choosing cuSOLVER over MAGMA:

  • If batch size == 1, OR (batch size <= 8 AND matrix dimension <= 16), choose cuSOLVER over MAGMA.
  • For all other cases use MAGMA.
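For illustration, the heuristic above can be sketched as a small dispatch predicate. This is only a sketch: the real selection happens in ATen's C++ dispatch code, and the function name below is hypothetical.

```python
def use_cusolver(batch_size: int, n: int) -> bool:
    """Return True if the cuSOLVER path should be preferred over MAGMA.

    Mirrors the heuristic described above:
      - batch size == 1, or
      - batch size <= 8 and matrix dimension <= 16.
    """
    return batch_size == 1 or (batch_size <= 8 and n <= 16)
```

For example, a single large matrix (`use_cusolver(1, 1024)`) goes to cuSOLVER, while a large batch of small matrices (`use_cusolver(128, 2)`) goes to MAGMA.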

See also #47953.

The following performance results compare the master branch with the current changes:

Details
[-------------------------- LU factorization (ATen) torch.float64 ---------------------------]
                                     |  lu_factorize CURRENT |  lu_factorize MASTER
1 threads: -----------------------------------------------------------------------------------
      torch.Size([1, 1, 1])          |              363.9          |             284.1        
      torch.Size([2, 1, 1])          |              354.8          |             271.8        
      torch.Size([4, 1, 1])          |              393.7          |             278.0        
      torch.Size([8, 1, 1])          |              459.3          |             279.1        
      torch.Size([16, 1, 1])         |              524.2          |             288.9        
      torch.Size([32, 1, 1])         |              525.1          |             281.2        
      torch.Size([64, 1, 1])         |              524.5          |             281.7        
      torch.Size([128, 1, 1])        |              522.8          |             285.2        
      torch.Size([1, 2, 2])          |              360.4          |             277.7        
      torch.Size([2, 2, 2])          |              372.9          |             279.2        
      torch.Size([4, 2, 2])          |              419.4          |             278.3        
      torch.Size([8, 2, 2])          |              475.7          |             279.2        
      torch.Size([16, 2, 2])         |              530.0          |             299.5        
      torch.Size([32, 2, 2])         |              530.0          |             294.5        
      torch.Size([64, 2, 2])         |              531.0          |             291.5        
      torch.Size([128, 2, 2])        |              544.4          |             292.3        
      torch.Size([1, 8, 8])          |              372.6          |             292.8        
      torch.Size([2, 8, 8])          |              380.9          |             296.2        
      torch.Size([4, 8, 8])          |              420.0          |             293.4        
      torch.Size([8, 8, 8])          |              490.6          |             294.6        
      torch.Size([16, 8, 8])         |              535.6          |             296.5        
      torch.Size([32, 8, 8])         |              534.7          |             302.1        
      torch.Size([64, 8, 8])         |              539.1          |             305.5        
      torch.Size([128, 8, 8])        |              540.7          |             296.5        
      torch.Size([1, 16, 16])        |              345.0          |             303.2        
      torch.Size([2, 16, 16])        |              405.0          |             306.3        
      torch.Size([4, 16, 16])        |              482.8          |             305.6        
      torch.Size([8, 16, 16])        |              596.3          |             305.9        
      torch.Size([16, 16, 16])       |              539.6          |             304.4        
      torch.Size([32, 16, 16])       |              542.2          |             305.8        
      torch.Size([64, 16, 16])       |              556.1          |             311.0        
      torch.Size([128, 16, 16])      |              545.1          |             308.1        
      torch.Size([1, 32, 32])        |              432.7          |             342.4        
      torch.Size([2, 32, 32])        |              582.6          |             341.8        
      torch.Size([4, 32, 32])        |              580.4          |             344.4        
      torch.Size([8, 32, 32])        |              586.5          |             343.8        
      torch.Size([16, 32, 32])       |              582.9          |             346.0        
      torch.Size([32, 32, 32])       |              574.4          |             343.7        
      torch.Size([64, 32, 32])       |              562.8          |             350.8        
      torch.Size([128, 32, 32])      |              568.3          |             349.8        
      torch.Size([1, 64, 64])        |              537.1          |             518.4        
      torch.Size([2, 64, 64])        |              766.5          |             539.1        
      torch.Size([4, 64, 64])        |              771.6          |             551.9        
      torch.Size([8, 64, 64])        |              783.4          |             556.0        
      torch.Size([16, 64, 64])       |              798.8          |             555.3        
      torch.Size([32, 64, 64])       |              795.6          |             548.6        
      torch.Size([64, 64, 64])       |              804.2          |             580.4        
      torch.Size([128, 64, 64])      |              837.6          |             616.9        
      torch.Size([1, 128, 128])      |              844.7          |             848.9        
      torch.Size([2, 128, 128])      |             1096.7          |             873.3        
      torch.Size([4, 128, 128])      |             1117.9          |             884.8        
      torch.Size([8, 128, 128])      |             1138.1          |             903.6        
      torch.Size([16, 128, 128])     |             1169.1          |             943.9        
      torch.Size([32, 128, 128])     |             1204.8          |             981.4        
      torch.Size([64, 128, 128])     |             1336.6          |            1105.8        
      torch.Size([128, 128, 128])    |             1639.4          |            1473.3        
      torch.Size([1, 512, 512])      |             3714.3          |            3928.6        
      torch.Size([2, 512, 512])      |             4388.3          |            4179.7        
      torch.Size([4, 512, 512])      |             4765.4          |            4536.9        
      torch.Size([8, 512, 512])      |             5615.2          |            5441.1        
      torch.Size([16, 512, 512])     |             7203.6          |            7130.2        
      torch.Size([32, 512, 512])     |            10580.5          |           10503.9        
      torch.Size([64, 512, 512])     |            17374.8          |           17349.6        
      torch.Size([128, 512, 512])    |            32542.3          |           32548.8        
      torch.Size([1, 1024, 1024])    |            10041.5          |           14292.3        
      torch.Size([2, 1024, 1024])    |            17126.6          |           16971.0        
      torch.Size([4, 1024, 1024])    |            20591.0          |           20490.8        
      torch.Size([8, 1024, 1024])    |            27682.8          |           27560.7        
      torch.Size([16, 1024, 1024])   |            41035.2          |           41035.8        
      torch.Size([32, 1024, 1024])   |            67091.8          |           67345.9        
      torch.Size([64, 1024, 1024])   |           119612.3          |          119782.3        
      torch.Size([128, 1024, 1024])  |           230095.5          |          230766.2        

Times are in microseconds (us).
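The exact script used for these numbers is linked above (#56887 (comment)); the following is only a minimal sketch of the kind of harness that produces such a table, assuming `torch.utils.benchmark.Timer`. It uses `torch.linalg.lu_factor` (the modern name for the operation this PR targets as `torch.lu`) and falls back to CPU when CUDA is unavailable; the size sweep is truncated for brevity.

```python
import torch
from torch.utils import benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.lu was the API at the time of this PR; torch.linalg.lu_factor is the
# modern equivalent and is used here so the sketch runs on current releases.
lu = torch.linalg.lu_factor

results = []
for batch in (1, 2, 4):       # the PR sweeps batch sizes up to 128
    for n in (1, 8, 64):      # and matrix sizes up to 1024
        a = torch.randn(batch, n, n, dtype=torch.float64, device=device)
        timer = benchmark.Timer(
            stmt="lu(a)",
            globals={"lu": lu, "a": a},
            description=f"torch.Size([{batch}, {n}, {n}])",
        )
        m = timer.blocked_autorange(min_run_time=0.1)
        results.append((batch, n, m.median * 1e6))  # median time in microseconds

for batch, n, us in results:
    print(f"torch.Size([{batch}, {n}, {n}])  {us:10.1f} us")
```

Running the same harness on master and on the PR branch, with the full sweep of batch sizes and matrix dimensions, yields a table of the shape shown above.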

The main reason a performance regression can be seen here is issue #55122, and there seems to be no easy way to fix it (at least not in this PR).

Contributor

facebook-github-bot commented Apr 25, 2021

💊 CI failures summary and remediations

As of commit 661be93 (more details on the Dr. CI page and at hud.pytorch.org/pr/56887):



🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build binary_windows_wheel_3_7_cu102_build (1/1)

Step: "Build" (full log | diagnosis details | 🔁 rerun)

ModuleNotFoundError: No module named 'tools'
    File "C:\w\b\windows\conda\envs\py37\lib\site-packages\pip\_vendor\pep517\in_process\_in_process.py", line 114, in get_requires_for_build_wheel
      return hook(config_settings)
    File "C:\Users\circleci\AppData\Local\Temp\pip-build-env-_wn7htq4\overlay\Lib\site-packages\setuptools\build_meta.py", line 155, in get_requires_for_build_wheel
      config_settings, requirements=['wheel'])
    File "C:\Users\circleci\AppData\Local\Temp\pip-build-env-_wn7htq4\overlay\Lib\site-packages\setuptools\build_meta.py", line 135, in _get_build_requires
      self.run_setup()
    File "C:\Users\circleci\AppData\Local\Temp\pip-build-env-_wn7htq4\overlay\Lib\site-packages\setuptools\build_meta.py", line 150, in run_setup
      exec(compile(code, __file__, 'exec'), locals())
    File "setup.py", line 219, in <module>
      from tools.build_pytorch_libs import build_caffe2
  ModuleNotFoundError: No module named 'tools'
  Getting requirements to build wheel: finished with status 'error'
WARNING: Discarding file:///C:/w/b/windows/pytorch. Command errored out with exit status 1: 'C:\w\b\windows\conda\envs\py37\python.exe' 'C:\w\b\windows\conda\envs\py37\lib\site-packages\pip\_vendor\pep517\in_process\_in_process.py' get_requires_for_build_wheel 'C:\Users\circleci\AppData\Local\Temp\tmpb8ke54hw' Check the logs for full command output.
ERROR: Command errored out with exit status 1: 'C:\w\b\windows\conda\envs\py37\python.exe' 'C:\w\b\windows\conda\envs\py37\lib\site-packages\pip\_vendor\pep517\in_process\_in_process.py' get_requires_for_build_wheel 'C:\Users\circleci\AppData\Local\Temp\tmpb8ke54hw' Check the logs for full command output.
Exception information:
Traceback (most recent call last):
  File "C:\w\b\windows\conda\envs\py37\lib\site-packages\pip\_internal\cli\base_command.py", line 180, in _main
    status = self.run(options, args)
  File "C:\w\b\windows\conda\envs\py37\lib\site-packages\pip\_internal\cli\req_command.py", line 204, in wrapper
    return func(self, options, args)
  File "C:\w\b\windows\conda\envs\py37\lib\site-packages\pip\_internal\commands\wheel.py", line 143, in run

1 failure not recognized by patterns:

Job: CircleCI pytorch_linux_xenial_cuda11_3_cudnn8_py3_gcc7_test | Step: Report results | Action: 🔁 rerun

1 job timed out:

  • pytorch_linux_xenial_cuda11_3_cudnn8_py3_gcc7_test

🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch (expand for instructions)

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Preview docs built from this PR



codecov Bot commented Apr 25, 2021

Codecov Report

Merging #56887 (6622f5f) into master (44cc873) will increase coverage by 27.19%.
The diff coverage is n/a.

❗ Current head 6622f5f differs from pull request most recent head 3318c33. Consider uploading reports for the commit 3318c33 to get more accurate results

@@             Coverage Diff             @@
##           master   #56887       +/-   ##
===========================================
+ Coverage   50.49%   77.68%   +27.19%     
===========================================
  Files         596     1954     +1358     
  Lines       76023   194549   +118526     
===========================================
+ Hits        38387   151134   +112747     
- Misses      37636    43415     +5779     

@IvanYashchuk IvanYashchuk self-requested a review April 29, 2021 14:58
Contributor Author

v0dro commented May 1, 2021

@IvanYashchuk I'm leaving the 64-bit cuSOLVER API out of this PR, since it malfunctions for singular matrices (different results than MAGMA/SciPy) and the pivot arrays used by other methods such as lu_solve need to be 32-bit for now. I will try again after adding 64-bit functionality for lu_solve. I'll use ghstack.

@IvanYashchuk
Collaborator

Sure, the 64-bit functions are not critical to use now; we can add them later, after the 1.9 release. In my experience there is no performance difference between the 32-bit and 64-bit variants for other functions, and the 32-bit functions are not even deprecated yet. So no need to worry about it for a while. Could you please add this to the issue where we track cuSOLVER & cuBLAS bugs, #53879?

@v0dro v0dro force-pushed the ci-all/v0dro-cusolver-lu branch from 38b2d32 to 7eebf16 Compare May 1, 2021 13:15
start updating cusolver LU solve functions

update cuSOLVER routine

convert nan to 0

updating kernel to work with complex dtypes

add 64bit cusolver API

fix minor problems in CUDA solver 64 bit API

remove 64 bit usage of cusolver

add 64 bit CUSOLVER interface

Add cuSOLVER 64 bit API for single batched LU factorization.

Add cuSOLVER 64 bit API for single batched LU factorization.

remove 64 bit API due to wrong pivot arrays being returned for singular matrix and breaking other routines

Add cuSOLVER path to single batch LU factorization.

update cuSOLVER routine

resolve cherry pick

resolve cherry pick conflict

update files for cherry pick

update commits

remove 64 bit usage of cusolver

add 64 bit CUSOLVER interface

remove 64 bit API due to wrong pivot arrays being returned for singular matrix and breaking other routines
@v0dro v0dro force-pushed the ci-all/v0dro-cusolver-lu branch from 7eebf16 to 3f9acc9 Compare May 1, 2021 13:19
@v0dro v0dro marked this pull request as ready for review May 1, 2021 15:22
@v0dro v0dro changed the title from Ci all/v0dro cusolver lu to cuSOLVER path for LU factorization in CUDA. May 1, 2021
@anjali411 anjali411 added the labels triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and module: linear algebra (issues related to specialized linear algebra operations in PyTorch; includes matrix multiply) May 3, 2021
@anjali411 anjali411 requested a review from mruberry May 3, 2021 16:20
Collaborator

rgommers commented May 4, 2021

@v0dro can you please add benchmarks?

@IvanYashchuk
Collaborator

The results are in pdf linked in this comment #53879 (comment).

Collaborator

rgommers commented May 4, 2021

Thanks Ivan. That's 8 figures with no code to produce them. It would be really useful to put the figures, plus the code to produce them, in the PR description here. That's not only helpful for review but also for future follow-up: if a regression is detected post-merge, no one will be able to find a PDF buried in a comment on an issue that isn't linked to this PR.

Collaborator

@IvanYashchuk IvanYashchuk left a comment


@v0dro, thank you for this pull request and for working on comparing the MAGMA/cuBLAS/cuSOLVER implementations.

I left a few suggestions inline.

As Ralf mentioned, could you please add a comment with the code to produce the performance results, and link it from the main description of the PR?
Could you please also edit the PR description to state the chosen conditions for when cuSOLVER is used and when batched MAGMA is used?

Comment thread aten/src/ATen/native/cuda/BatchLinearAlgebra.cu Outdated
Comment thread aten/src/ATen/native/cuda/BatchLinearAlgebra.cu Outdated
Comment thread aten/src/ATen/native/cuda/BatchLinearAlgebra.cu Outdated
Comment thread aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu Outdated
}

void apply_lu_cusolver_looped(Tensor& self, Tensor& pivots, Tensor& infos, bool get_pivots) {
auto infos_data = infos.data_ptr<int>();
Collaborator


Can we move this inside the lambda function?

Contributor Author


This makes it more DRY and readable, IMO.

Comment thread aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu Outdated
Comment thread aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu Outdated
Comment thread test/test_linalg.py Outdated
Comment thread aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu Outdated
Comment thread aten/src/ATen/cuda/CUDASolver.cpp Outdated
Collaborator

@xwang233 xwang233 left a comment


Thanks for the work on cusolver @v0dro ! I've left some comments on the PR.

Would it be convenient to post a benchmark table comparing the run times of MAGMA, cuSOLVER, and cuBLAS (if needed) for getrf? With that, we can refer to it in the future. Also, I can report this to the cuSOLVER team and let them check and improve the performance to benefit us all. Thank you!

Comment thread aten/src/ATen/native/cuda/BatchLinearAlgebra.cu Outdated
Comment thread aten/src/ATen/native/cuda/BatchLinearAlgebra.cu Outdated
Comment thread aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu Outdated
Comment thread aten/src/ATen/native/cuda/BatchLinearAlgebra.cu Outdated
self_data + batch * self_stride,
lda,
pivots_data + batch * pivots_matrix_stride,
infos_data + batch
Collaborator


If both pivots_tensor and infos_tensor are of shape req_size, do they have the same stride, which is 1 here?

Contributor Author


pivots_tensor has a size of N (the number of matrix rows) and infos_tensor has size batch_size; each LU factorization returns a single info. If you perform an LU factorization of a scalar then yes, they will both have size 1.
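To make those shapes concrete, here is a small illustration. It uses the modern `torch.linalg.lu_factor_ex` (an assumption for current releases; the PR itself targets `torch.lu`), which exposes the same pivots/infos layout: one pivot entry per row per matrix, and one info entry per batch element.

```python
import torch

batch, n = 4, 8
a = torch.randn(batch, n, n, dtype=torch.float64)

# lu_factor_ex returns (LU, pivots, info) and does not raise on singular inputs.
LU, pivots, infos = torch.linalg.lu_factor_ex(a)

print(pivots.shape)  # torch.Size([4, 8]) -- N pivots per matrix
print(infos.shape)   # torch.Size([4])   -- one info per factorization
```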

Comment thread aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu Outdated
Contributor Author

v0dro commented May 7, 2021

Thank you for your comments. I will get back to you soon with a review.

@v0dro v0dro force-pushed the ci-all/v0dro-cusolver-lu branch from 13859d7 to 2401be7 Compare June 11, 2021 10:49
Contributor Author

v0dro commented Jun 11, 2021

@jithunnair-amd can you check now?

@v0dro v0dro marked this pull request as ready for review June 12, 2021 13:01
@mruberry mruberry requested a review from lezcano June 13, 2021 05:51
Comment thread aten/src/ATen/native/cuda/ScatterGatherKernel.cu
Comment thread test/test_linalg.py
Comment thread aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu Outdated
Comment thread aten/src/ATen/native/BatchLinearAlgebra.cpp Outdated
@v0dro v0dro force-pushed the ci-all/v0dro-cusolver-lu branch from 21ba864 to 3355359 Compare June 14, 2021 16:11
@mruberry
Collaborator

@jithunnair-amd can you check now?

Ping @jithunnair-amd -- is your issue resolved?

Contributor Author

v0dro commented Jun 24, 2021

@jithunnair-amd ping!

Contributor Author

v0dro commented Jun 25, 2021

I'm getting this failure on XLA:

Jun 24 16:38:58 ======================================================================
Jun 24 16:38:58 FAIL [0.015s]: test_nll_loss_invalid_weights_xla (__main__.TestNNDeviceTypeXLA)
Jun 24 16:38:58 ----------------------------------------------------------------------
Jun 24 16:38:58 RuntimeError: /var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/tensorflow/compiler/xla/xla_client/debug_macros.h:27 : Check failed: status.status() == ::tensorflow::Status::OK() (Invalid argument: Input dimension should be either 1 or equal to the output dimension it is broadcasting into; the 0th operand dimension is 4, the 1th output dimension is 3. vs. OK)
Jun 24 16:38:58 *** Begin stack trace ***
Jun 24 16:38:58 	tensorflow::CurrentStackTrace[abi:cxx11]()
Jun 24 16:38:58 	xla::Shape const* ConsumeValue<xla::Shape const*>(tensorflow::StatusOr<xla::Shape const*>&&)
Jun 24 16:38:58 	torch_xla::XlaHelpers::ShapeOfXlaOp(xla::XlaOp)
Jun 24 16:38:58 	torch_xla::ir::ops::InferOutputShape(absl::lts_20210324::Span<xla::Shape const>, std::function<xla::XlaOp (absl::lts_20210324::Span<xla::XlaOp const>)> const&)
Jun 24 16:38:58 	
Jun 24 16:38:58 	torch_xla::ir::Node::GetOpShape(std::function<xla::Shape ()> const&) const
Jun 24 16:38:58 	torch_xla::ir::Node::Node(torch_xla::ir::OpKind, absl::lts_20210324::Span<torch_xla::ir::Value const>, std::function<xla::Shape ()> const&, unsigned long, absl::lts_20210324::uint128)
Jun 24 16:38:58 	torch_xla::ir::ops::NllLoss::NllLoss(torch_xla::ir::Value const&, torch_xla::ir::Value const&, absl::lts_20210324::optional<torch_xla::ir::Value> const&, torch_xla::ReductionMode, int)
Jun 24 16:38:58 	torch_xla::XLATensor::nll_loss(torch_xla::XLATensor const&, torch_xla::XLATensor const&, torch_xla::XLATensor const&, long, int)
Jun 24 16:38:58 	torch_xla::nll_loss_forward(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, long, long)
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	at::redispatch::nll_loss_forward(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, long, long)
Jun 24 16:38:58 	
Jun 24 16:38:58 	at::nll_loss_forward(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, long, long)
Jun 24 16:38:58 	at::native::nll_loss(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, long, long)
Jun 24 16:38:58 	
Jun 24 16:38:58 	at::nll_loss(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, long, long)
Jun 24 16:38:58 	at::native::nll_loss_nd(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, long, long)
Jun 24 16:38:58 	
Jun 24 16:38:58 	at::nll_loss_nd(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, long, long)
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	PyObject_Call
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	PyObject_Call
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	PyObject_Call
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyObject_FastCallDict
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	PyObject_Call
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	PyObject_Call
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyObject_FastCallDict
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	PyObject_Call
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	PyObject_Call
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyObject_FastCallDict
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	PyObject_Call
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyObject_FastCallDict
Jun 24 16:38:58 	_PyObject_FastCallKeywords
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	
Jun 24 16:38:58 	_PyEval_EvalFrameDefault
Jun 24 16:38:58 	
Jun 24 16:38:58 	PyEval_EvalCode
Jun 24 16:38:58 	
Jun 24 16:38:58 	PyRun_FileExFlags
Jun 24 16:38:58 	PyRun_SimpleFileExFlags
Jun 24 16:38:58 	Py_Main
Jun 24 16:38:58 	main
Jun 24 16:38:58 	__libc_start_main
Jun 24 16:38:58 	
Jun 24 16:38:58 *** End stack trace ***
Jun 24 16:38:58 
Jun 24 16:38:58 
Jun 24 16:38:58 During handling of the above exception, another exception occurred:
Jun 24 16:38:58 
Jun 24 16:38:58 Traceback (most recent call last):
Jun 24 16:38:58   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 397, in instantiated_test
Jun 24 16:38:58     result = test_fn(self, *args)
Jun 24 16:38:58   File "/var/lib/jenkins/workspace/xla/test/../../test/test_nn.py", line 16007, in test_nll_loss_invalid_weights
Jun 24 16:38:58     F.nll_loss(x, t, weight=weight)
Jun 24 16:38:58 AssertionError: "weight tensor should be defined either for all 3 classes or no classes" does not match "/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/tensorflow/compiler/xla/xla_client/debug_macros.h:27 : Check failed: status.status() == ::tensorflow::Status::OK() (Invalid argument: Input dimension should be either 1 or equal to the output dimension it is broadcasting into; the 0th operand dimension is 4, the 1th output dimension is 3. vs. OK)
Jun 24 16:38:58 [stack trace identical to the one above]
Jun 24 16:38:58 "
Jun 24 16:38:58 
Jun 24 16:38:58 ----------------------------------------------------------------------
Jun 24 16:38:58 Ran 243 tests in 1028.999s
Jun 24 16:38:58 
Jun 24 16:38:58 FAILED (failures=2, skipped=135)
Jun 24 16:38:58 
Jun 24 16:38:58 Generating XML reports...
Jun 24 16:38:58 Generated XML report: test-reports/python-unittest/test.......test.test_nn/TEST-TestNNDeviceTypeXLA-20210624162149.xml
Jun 24 16:38:58 + cleanup
Jun 24 16:38:58 + retcode=1
Jun 24 16:38:58 + set +x
Jun 24 16:38:58 =================== sccache compilation log ===================
Jun 24 16:38:58 =========== If your build fails, please take a look at the log above for possible reasons ===========
Jun 24 16:38:58 Compile requests                      0
Jun 24 16:38:58 Compile requests executed             0
Jun 24 16:38:58 Cache hits                            0
Jun 24 16:38:58 Cache misses                          0
Jun 24 16:38:58 Cache timeouts                        0
Jun 24 16:38:58 Cache read errors                     0
Jun 24 16:38:58 Forced recaches                       0
Jun 24 16:38:58 Cache write errors                    0
Jun 24 16:38:58 Compilation failures                  0
Jun 24 16:38:58 Cache errors                          0
Jun 24 16:38:58 Non-cacheable compilations            0
Jun 24 16:38:58 Non-cacheable calls                   0
Jun 24 16:38:58 Non-compilation calls                 0
Jun 24 16:38:58 Unsupported compiler calls            0
Jun 24 16:38:58 Average cache write               0.000 s
Jun 24 16:38:58 Average cache read miss           0.000 s
Jun 24 16:38:58 Average cache read hit            0.000 s
Jun 24 16:38:58 Failed distributed compilations       0
Jun 24 16:38:58 Cache location                  S3, bucket: Bucket(name=ossci-compiler-cache-circleci-v2, base_url=http://ossci-compiler-cache-circleci-v2.s3.amazonaws.com/)
Jun 24 16:38:58 Stopping sccache server...

Exited with code exit status 1

Might this be related to the issue that AMD is facing?

Collaborator

@lezcano lezcano left a comment


Arrived late to the party, but just wanted to say that this LGTM!

[&self,
&pivots,
&infos,
&get_pivots]() {
Collaborator

@lezcano lezcano Jun 25, 2021


Suggested change
&get_pivots]() {
get_pivots]() {

Nit: since we are capturing explicitly, it is better to capture the bool by value.

@IvanYashchuk
Collaborator

The XLA failure is not related to your PR and is probably already fixed on master.
And the ROCm CI is green, so that problem is no longer present.

@IvanYashchuk
Collaborator

@mruberry, could you please import this PR?

@facebook-github-bot
Contributor

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mruberry mruberry self-requested a review June 30, 2021 08:25
Collaborator

@mruberry mruberry left a comment


Cool! Thanks @v0dro! And thanks @lezcano and @IvanYashchuk for reviewing.

@jithunnair-amd
Collaborator

@v0dro @mruberry Apologies for being MIA. Yes, since ROCm CI is green, this change resolves my concerns. Thanks for checking!

@facebook-github-bot
Contributor

@mruberry merged this pull request in e133801.

@github-actions github-actions Bot deleted the ci-all/v0dro-cusolver-lu branch February 12, 2024 21:15
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
Summary:
This PR adds cuSOLVER path for `torch.lu`.

Performance comparison results: pytorch#53879 (comment)

Code for reproducing performance results: pytorch#56887 (comment)

The following heuristics are used for choosing cuSOLVER over MAGMA:
* If batch size == 1 OR (batch size <= 8 AND shape <= 16), choose cuSOLVER over MAGMA.
* For all other cases use MAGMA.

See also pytorch#47953.

Following are the performance results between the MASTER branch and the current changes:

<details>

```
[-------------------------- LU factorization (ATen) torch.float64 ---------------------------]
                                     |  lu_factorize CURRENT |  lu_factorize MASTER
1 threads: -----------------------------------------------------------------------------------
      torch.Size([1, 1, 1])          |              363.9          |             284.1
      torch.Size([2, 1, 1])          |              354.8          |             271.8
      torch.Size([4, 1, 1])          |              393.7          |             278.0
      torch.Size([8, 1, 1])          |              459.3          |             279.1
      torch.Size([16, 1, 1])         |              524.2          |             288.9
      torch.Size([32, 1, 1])         |              525.1          |             281.2
      torch.Size([64, 1, 1])         |              524.5          |             281.7
      torch.Size([128, 1, 1])        |              522.8          |             285.2
      torch.Size([1, 2, 2])          |              360.4          |             277.7
      torch.Size([2, 2, 2])          |              372.9          |             279.2
      torch.Size([4, 2, 2])          |              419.4          |             278.3
      torch.Size([8, 2, 2])          |              475.7          |             279.2
      torch.Size([16, 2, 2])         |              530.0          |             299.5
      torch.Size([32, 2, 2])         |              530.0          |             294.5
      torch.Size([64, 2, 2])         |              531.0          |             291.5
      torch.Size([128, 2, 2])        |              544.4          |             292.3
      torch.Size([1, 8, 8])          |              372.6          |             292.8
      torch.Size([2, 8, 8])          |              380.9          |             296.2
      torch.Size([4, 8, 8])          |              420.0          |             293.4
      torch.Size([8, 8, 8])          |              490.6          |             294.6
      torch.Size([16, 8, 8])         |              535.6          |             296.5
      torch.Size([32, 8, 8])         |              534.7          |             302.1
      torch.Size([64, 8, 8])         |              539.1          |             305.5
      torch.Size([128, 8, 8])        |              540.7          |             296.5
      torch.Size([1, 16, 16])        |              345.0          |             303.2
      torch.Size([2, 16, 16])        |              405.0          |             306.3
      torch.Size([4, 16, 16])        |              482.8          |             305.6
      torch.Size([8, 16, 16])        |              596.3          |             305.9
      torch.Size([16, 16, 16])       |              539.6          |             304.4
      torch.Size([32, 16, 16])       |              542.2          |             305.8
      torch.Size([64, 16, 16])       |              556.1          |             311.0
      torch.Size([128, 16, 16])      |              545.1          |             308.1
      torch.Size([1, 32, 32])        |              432.7          |             342.4
      torch.Size([2, 32, 32])        |              582.6          |             341.8
      torch.Size([4, 32, 32])        |              580.4          |             344.4
      torch.Size([8, 32, 32])        |              586.5          |             343.8
      torch.Size([16, 32, 32])       |              582.9          |             346.0
      torch.Size([32, 32, 32])       |              574.4          |             343.7
      torch.Size([64, 32, 32])       |              562.8          |             350.8
      torch.Size([128, 32, 32])      |              568.3          |             349.8
      torch.Size([1, 64, 64])        |              537.1          |             518.4
      torch.Size([2, 64, 64])        |              766.5          |             539.1
      torch.Size([4, 64, 64])        |              771.6          |             551.9
      torch.Size([8, 64, 64])        |              783.4          |             556.0
      torch.Size([16, 64, 64])       |              798.8          |             555.3
      torch.Size([32, 64, 64])       |              795.6          |             548.6
      torch.Size([64, 64, 64])       |              804.2          |             580.4
      torch.Size([128, 64, 64])      |              837.6          |             616.9
      torch.Size([1, 128, 128])      |              844.7          |             848.9
      torch.Size([2, 128, 128])      |             1096.7          |             873.3
      torch.Size([4, 128, 128])      |             1117.9          |             884.8
      torch.Size([8, 128, 128])      |             1138.1          |             903.6
      torch.Size([16, 128, 128])     |             1169.1          |             943.9
      torch.Size([32, 128, 128])     |             1204.8          |             981.4
      torch.Size([64, 128, 128])     |             1336.6          |            1105.8
      torch.Size([128, 128, 128])    |             1639.4          |            1473.3
      torch.Size([1, 512, 512])      |             3714.3          |            3928.6
      torch.Size([2, 512, 512])      |             4388.3          |            4179.7
      torch.Size([4, 512, 512])      |             4765.4          |            4536.9
      torch.Size([8, 512, 512])      |             5615.2          |            5441.1
      torch.Size([16, 512, 512])     |             7203.6          |            7130.2
      torch.Size([32, 512, 512])     |            10580.5          |           10503.9
      torch.Size([64, 512, 512])     |            17374.8          |           17349.6
      torch.Size([128, 512, 512])    |            32542.3          |           32548.8
      torch.Size([1, 1024, 1024])    |            10041.5          |           14292.3
      torch.Size([2, 1024, 1024])    |            17126.6          |           16971.0
      torch.Size([4, 1024, 1024])    |            20591.0          |           20490.8
      torch.Size([8, 1024, 1024])    |            27682.8          |           27560.7
      torch.Size([16, 1024, 1024])   |            41035.2          |           41035.8
      torch.Size([32, 1024, 1024])   |            67091.8          |           67345.9
      torch.Size([64, 1024, 1024])   |           119612.3          |          119782.3
      torch.Size([128, 1024, 1024])  |           230095.5          |          230766.2

Times are in microseconds (us).

```
</details>

The main reason a performance regression can be seen for some shapes is issue pytorch#55122, and there seems to be no easy way to fix it (at least in this PR).

Pull Request resolved: pytorch#56887

Reviewed By: ngimel

Differential Revision: D29482342

Pulled By: mruberry

fbshipit-source-id: 4fdedf21b0d5597b289e168dff61d5f5d7727fb1

Labels

- cla signed
- Merged
- module: linear algebra — Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul
- open source
- triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module


10 participants