Add cuBLAS path for batched torch.geqrf #56253

IvanYashchuk wants to merge 12 commits into gh/ivanyashchuk/14/base
Conversation
`geqrfBatched` from cuBLAS is used if

```cpp
(input.size(-2) <= 256 && batchCount(input) >= std::max<int64_t>(2, input.size(-2) / 16))
```

Ref. #47953
xwang233 left a comment:
Thanks for the PR! This overall looks good. I have left some comments.
```cpp
// cuBLAS batched geqrf requires input to be the device array of pointers to device single matrices
Tensor input_ptr_array = get_device_pointers<scalar_t>(input);
Tensor tau_ptr_array = get_device_pointers<scalar_t>(tau.unsqueeze(-1));
```
> Hmm, why is there an unsqueeze call?

That's a workaround: get_device_pointers works only for arrays of matrices, and tau is a vector.
```cpp
at::cuda::blas::geqrfBatched(handle, m, n, input_ptr_array_data, lda, tau_ptr_array_data, &info, batch_size);

// info only indicates wrong arguments to geqrfBatched call
TORCH_INTERNAL_ASSERT(info == 0);
```
> Maybe add a comment here saying info is a host variable?
Sure, will do.
```cpp
template <class Dtype>
void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
  TORCH_CHECK(
```
TORCH_INTERNAL_ASSERT? Otherwise the function name should be user-facing
mruberry left a comment:

Just one small macro/error message nit.
```cpp
void geqrf_kernel(const Tensor& input, const Tensor& tau, int64_t m, int64_t n) {
  // If the number of rows is smaller than 32, the batched path is always
  // faster for batch size > 1; for a larger number of rows, the batch-count
  // condition below decides.
  if (input.size(-2) <= 256 && batchCount(input) >= std::max<int64_t>(2, input.size(-2) / 16)) {
    return geqrf_batched(input, tau, m, n);
  } else {
    return geqrf_looped(input, tau, m, n);
  }
}
```
> Hi @IvanYashchuk, do you have a benchmark table for this heuristic? I can report to the cuSOLVER team to check the performance of cublas&lt;T&gt;geqrfBatched compared with looped cusolver geqrf.

Sure, here are the comparison tables. All results are for float64.
Batches of 2-256 for sizes 2x2 to 1024x1024
Times are in microseconds (us).
| | cuBLAS | cuSOLVER |
|-------------------------------|-----------|-----------|
| torch.Size([2, 2]) | 37.5 | 44.1 |
| torch.Size([2, 2, 2]) | 24.1 | 30.3 |
| torch.Size([32, 2, 2]) | 24.8 | 421.3 |
| torch.Size([64, 2, 2]) | 25.2 | 840.2 |
| torch.Size([128, 2, 2]) | 24.7 | 1675.9 |
| torch.Size([256, 2, 2]) | 24.3 | 3348.5 |
| torch.Size([8, 8]) | 44.4 | 62.9 |
| torch.Size([2, 8, 8]) | 44.9 | 122.7 |
| torch.Size([32, 8, 8]) | 60.5 | 1909.0 |
| torch.Size([64, 8, 8]) | 60.5 | 3813.4 |
| torch.Size([128, 8, 8]) | 60.5 | 7625.6 |
| torch.Size([256, 8, 8]) | 60.6 | 15240.0 |
| torch.Size([16, 16]) | 156.5 | 126.5 |
| torch.Size([2, 16, 16]) | 157.6 | 249.4 |
| torch.Size([32, 16, 16]) | 206.0 | 3929.7 |
| torch.Size([64, 16, 16]) | 206.1 | 7856.0 |
| torch.Size([128, 16, 16]) | 206.2 | 15708.0 |
| torch.Size([256, 16, 16]) | 206.7 | 31410.1 |
| torch.Size([32, 32]) | 619.6 | 257.9 |
| torch.Size([2, 32, 32]) | 629.4 | 512.1 |
| torch.Size([32, 32, 32]) | 784.6 | 8136.3 |
| torch.Size([64, 32, 32]) | 784.9 | 16266.3 |
| torch.Size([128, 32, 32]) | 786.4 | 32655.6 |
| torch.Size([256, 32, 32]) | 789.3 | 65579.5 |
| torch.Size([64, 64]) | 2381.0 | 661.9 |
| torch.Size([2, 64, 64]) | 2448.6 | 1319.5 |
| torch.Size([32, 64, 64]) | 3096.4 | 21052.8 |
| torch.Size([64, 64, 64]) | 3101.0 | 42113.2 |
| torch.Size([128, 64, 64]) | 3115.0 | 84217.7 |
| torch.Size([256, 64, 64]) | 3466.1 | 168483.9 |
| torch.Size([128, 128]) | 11816.0 | 1572.4 |
| torch.Size([2, 128, 128]) | 12990.9 | 3140.5 |
| torch.Size([32, 128, 128]) | 14793.2 | 50225.2 |
| torch.Size([64, 128, 128]) | 16164.0 | 100416.3 |
| torch.Size([128, 128, 128]) | 17637.3 | 200814.2 |
| torch.Size([256, 128, 128]) | 18670.9 | 401571.3 |
| torch.Size([256, 256]) | 56018.7 | 2377.1 |
| torch.Size([2, 256, 256]) | 63733.5 | 4744.8 |
| torch.Size([32, 256, 256]) | 84046.0 | 76058.5 |
| torch.Size([64, 256, 256]) | 87702.1 | 152151.9 |
| torch.Size([128, 256, 256]) | 92019.4 | 304297.9 |
| torch.Size([256, 256, 256]) | 99104.9 | 609614.2 |
| torch.Size([512, 512]) | 302886.9 | 5869.6 |
| torch.Size([2, 512, 512]) | 352219.2 | 11742.1 |
| torch.Size([32, 512, 512]) | 556534.2 | 187686.5 |
| torch.Size([64, 512, 512]) | 571038.7 | 375346.6 |
| torch.Size([128, 512, 512]) | 597674.4 | 750685.9 |
| torch.Size([256, 512, 512]) | 679707.7 | 1501277.7 |
| torch.Size([1024, 1024]) | 1842582.2 | 15375.0 |
| torch.Size([2, 1024, 1024]) | 2598219.3 | 30746.9 |
| torch.Size([32, 1024, 1024]) | 4038294.7 | 491840.1 |
| torch.Size([64, 1024, 1024]) | 4096143.4 | 983641.5 |
| torch.Size([128, 1024, 1024]) | 4216505.4 | 1978362.1 |
| torch.Size([256, 1024, 1024]) | 5407082.8 | 3957738.8 |
Comparison for batches of 512x512
Times are in milliseconds (ms).
| | cuBLAS | cuSOLVER |
|-----------------------------|--------|----------|
| torch.Size([2, 512, 512]) | 347.3 | 13.7 |
| torch.Size([4, 512, 512]) | 393.9 | 23.3 |
| torch.Size([8, 512, 512]) | 523.6 | 46.5 |
| torch.Size([16, 512, 512]) | 536.5 | 93.0 |
| torch.Size([32, 512, 512]) | 549.9 | 186.0 |
| torch.Size([64, 512, 512]) | 564.3 | 374.8 |
| torch.Size([96, 512, 512]) | 576.9 | 562.1 |
| torch.Size([128, 512, 512]) | 594.0 | 749.3 |
| torch.Size([256, 512, 512]) | 676.1 | 1498.5 |
| torch.Size([96, 512, 512]) | 573.2 | 555.0 |
| torch.Size([97, 512, 512]) | 574.3 | 560.7 |
| torch.Size([98, 512, 512]) | 575.2 | 566.5 |
| torch.Size([99, 512, 512]) | 576.2 | 572.3 |
| torch.Size([100, 512, 512]) | 576.0 | 581.7 |
| torch.Size([101, 512, 512]) | 577.0 | 587.3 |
| torch.Size([102, 512, 512]) | 576.9 | 593.1 |
| torch.Size([103, 512, 512]) | 589.2 | 598.9 |
| torch.Size([104, 512, 512]) | 579.9 | 605.0 |
| torch.Size([105, 512, 512]) | 580.3 | 610.8 |
| torch.Size([106, 512, 512]) | 581.0 | 616.6 |
| torch.Size([107, 512, 512]) | 581.6 | 622.4 |
| torch.Size([108, 512, 512]) | 581.8 | 628.1 |
| torch.Size([109, 512, 512]) | 582.4 | 634.0 |
Comparison for batches of 256x256
Times are in milliseconds (ms).
| | cuBLAS | cuSOLVER |
|----------------------------|--------|----------|
| torch.Size([2, 256, 256]) | 72.7 | 6.5 |
| torch.Size([3, 256, 256]) | 63.4 | 7.0 |
| torch.Size([4, 256, 256]) | 63.9 | 9.4 |
| torch.Size([5, 256, 256]) | 65.1 | 11.7 |
| torch.Size([6, 256, 256]) | 66.9 | 14.0 |
| torch.Size([7, 256, 256]) | 69.5 | 16.4 |
| torch.Size([8, 256, 256]) | 71.6 | 18.8 |
| torch.Size([9, 256, 256]) | 71.5 | 21.1 |
| torch.Size([10, 256, 256]) | 70.8 | 23.5 |
| torch.Size([11, 256, 256]) | 71.7 | 26.0 |
| torch.Size([12, 256, 256]) | 72.0 | 28.4 |
| torch.Size([13, 256, 256]) | 74.6 | 30.7 |
| torch.Size([14, 256, 256]) | 76.4 | 33.1 |
| torch.Size([15, 256, 256]) | 77.3 | 35.4 |
| torch.Size([16, 256, 256]) | 78.1 | 37.8 |
| torch.Size([17, 256, 256]) | 78.5 | 40.2 |
| torch.Size([18, 256, 256]) | 79.2 | 42.5 |
| torch.Size([19, 256, 256]) | 79.6 | 44.9 |
| torch.Size([20, 256, 256]) | 80.4 | 47.2 |
| torch.Size([21, 256, 256]) | 80.6 | 49.6 |
| torch.Size([22, 256, 256]) | 81.5 | 52.0 |
| torch.Size([23, 256, 256]) | 81.8 | 54.3 |
| torch.Size([24, 256, 256]) | 81.8 | 56.7 |
| torch.Size([25, 256, 256]) | 82.4 | 59.0 |
| torch.Size([26, 256, 256]) | 81.7 | 61.4 |
| torch.Size([27, 256, 256]) | 82.1 | 63.7 |
| torch.Size([28, 256, 256]) | 81.7 | 66.1 |
| torch.Size([29, 256, 256]) | 82.8 | 68.4 |
| torch.Size([30, 256, 256]) | 82.8 | 70.8 |
| torch.Size([31, 256, 256]) | 83.2 | 73.2 |
| torch.Size([32, 256, 256]) | 83.5 | 75.6 |
| torch.Size([33, 256, 256]) | 83.5 | 77.3 |
| torch.Size([34, 256, 256]) | 82.9 | 79.5 |
| torch.Size([35, 256, 256]) | 83.0 | 81.9 |
| torch.Size([36, 256, 256]) | 83.1 | 84.3 |
| torch.Size([37, 256, 256]) | 83.3 | 86.6 |
| torch.Size([38, 256, 256]) | 83.4 | 88.9 |
| torch.Size([39, 256, 256]) | 83.5 | 91.3 |
| torch.Size([40, 256, 256]) | 84.2 | 93.6 |
| torch.Size([41, 256, 256]) | 84.0 | 95.9 |
| torch.Size([42, 256, 256]) | 84.0 | 98.3 |
| torch.Size([43, 256, 256]) | 84.0 | 100.6 |
| torch.Size([44, 256, 256]) | 84.1 | 103.4 |
| torch.Size([45, 256, 256]) | 84.4 | 106.1 |
| torch.Size([46, 256, 256]) | 84.9 | 108.5 |
| torch.Size([47, 256, 256]) | 85.2 | 110.8 |
| torch.Size([48, 256, 256]) | 85.3 | 113.3 |
| torch.Size([49, 256, 256]) | 85.3 | 115.6 |
Comparison for batches of 128x128
Times are in milliseconds (ms).
| | cuBLAS | cuSOLVER |
|----------------------------|--------|----------|
| torch.Size([2, 128, 128]) | 17.7 | 4.3 |
| torch.Size([3, 128, 128]) | 12.9 | 4.6 |
| torch.Size([4, 128, 128]) | 13.2 | 6.2 |
| torch.Size([5, 128, 128]) | 13.5 | 7.7 |
| torch.Size([6, 128, 128]) | 13.7 | 9.3 |
| torch.Size([7, 128, 128]) | 14.1 | 10.8 |
| torch.Size([8, 128, 128]) | 14.5 | 12.4 |
| torch.Size([9, 128, 128]) | 14.5 | 13.9 |
| torch.Size([10, 128, 128]) | 14.5 | 15.5 |
| torch.Size([11, 128, 128]) | 14.5 | 17.0 |
| torch.Size([12, 128, 128]) | 14.5 | 18.6 |
| torch.Size([13, 128, 128]) | 14.5 | 20.1 |
| torch.Size([14, 128, 128]) | 14.5 | 21.6 |
| torch.Size([15, 128, 128]) | 14.5 | 23.2 |
Thanks! That's very helpful.
Summary: Pull Request resolved: pytorch#56253

`geqrfBatched` from cuBLAS is used if

```cpp
(input.size(-2) <= 256 && batchCount(input) >= std::max<int64_t>(2, input.size(-2) / 16))
```

Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D27960156
Pulled By: mruberry
fbshipit-source-id: 3e438eff01cbf7c7e075fb7aef709b97698a4650
Stack from ghstack:

`geqrfBatched` from cuBLAS is used if `(input.size(-2) <= 256 && batchCount(input) >= std::max<int64_t>(2, input.size(-2) / 16))`.

Differential Revision: D27960156