Add cuBLAS path for batched torch.geqrf #56253

Closed
IvanYashchuk wants to merge 12 commits into gh/ivanyashchuk/14/base from gh/ivanyashchuk/14/head

Conversation

Collaborator

@IvanYashchuk IvanYashchuk commented Apr 16, 2021

Stack from ghstack:

`geqrfBatched` from cuBLAS is used if
```
(input.size(-2) <= 256 && batchCount(input) >= std::max<int64_t>(2, input.size(-2) / 16))
```

Differential Revision: [D27960156](https://our.internmc.facebook.com/intern/diff/D27960156)
Contributor

facebook-github-bot commented Apr 16, 2021

💊 CI failures summary and remediations

As of commit 7d029cf (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



IvanYashchuk added a commit to IvanYashchuk/pytorch that referenced this pull request Apr 16, 2021
ghstack-source-id: e4add76
Pull Request resolved: pytorch#56253
@IvanYashchuk IvanYashchuk requested a review from xwang233 April 16, 2021 10:02
@IvanYashchuk IvanYashchuk added the module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul label Apr 16, 2021
@IvanYashchuk IvanYashchuk requested a review from mruberry April 16, 2021 10:02
`geqrfBatched` from cuBLAS is used if
```
(input.size(-2) <= 256 && batchCount(input) >= std::max<int64_t>(2, input.size(-2) / 16))
```

Ref. #47953

[ghstack-poisoned]
IvanYashchuk added a commit to IvanYashchuk/pytorch that referenced this pull request Apr 16, 2021
ghstack-source-id: 88039ff
Pull Request resolved: pytorch#56253
Collaborator

@xwang233 xwang233 left a comment


Thanks for the PR! This overall looks good. I have left some comments.


// cuBLAS batched geqrf requires the input to be a device array of pointers to individual device matrices
Tensor input_ptr_array = get_device_pointers<scalar_t>(input);
Tensor tau_ptr_array = get_device_pointers<scalar_t>(tau.unsqueeze(-1));
Collaborator


Hmm, why is there an unsqueeze call?

Collaborator Author


That's a workaround: `get_device_pointers` works only for arrays of matrices, and `tau` is a batch of vectors, not matrices.

at::cuda::blas::geqrfBatched(handle, m, n, input_ptr_array_data, lda, tau_ptr_array_data, &info, batch_size);

// info only indicates wrong arguments to geqrfBatched call
TORCH_INTERNAL_ASSERT(info == 0);
Collaborator


Maybe add a comment here saying info is a host variable?

Collaborator Author


Sure, will do.

IvanYashchuk added a commit to IvanYashchuk/pytorch that referenced this pull request Apr 19, 2021
ghstack-source-id: b5c8e4f
Pull Request resolved: pytorch#56253
IvanYashchuk added a commit to IvanYashchuk/pytorch that referenced this pull request Apr 19, 2021
ghstack-source-id: f42d3c3
Pull Request resolved: pytorch#56253
IvanYashchuk added a commit to IvanYashchuk/pytorch that referenced this pull request Apr 20, 2021
ghstack-source-id: 4c5917b
Pull Request resolved: pytorch#56253
Comment thread on aten/src/ATen/cuda/CUDABlas.h (outdated)

template <class Dtype>
void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
TORCH_CHECK(
Collaborator


TORCH_INTERNAL_ASSERT? Otherwise the function name should be user-facing.

Collaborator

@mruberry mruberry left a comment


Just one small macro/error message nit

IvanYashchuk added a commit to IvanYashchuk/pytorch that referenced this pull request Apr 26, 2021
ghstack-source-id: 59d9c99
Pull Request resolved: pytorch#56253
@facebook-github-bot

@mruberry merged this pull request in 5b1f0ef.

Comment on lines +2068 to +2076
void geqrf_kernel(const Tensor& input, const Tensor& tau, int64_t m, int64_t n) {
  // if the number of rows is smaller than 32, batched is always faster for batch size > 1
  // for a larger number of rows, the batch count condition below decides
  if (input.size(-2) <= 256 && batchCount(input) >= std::max<int64_t>(2, input.size(-2) / 16)) {
    return geqrf_batched(input, tau, m, n);
  } else {
    return geqrf_looped(input, tau, m, n);
  }
}
Collaborator


Hi @IvanYashchuk, do you have a benchmark table for this heuristic? I can report it to the cuSOLVER team so they can compare the performance of cublas<T>geqrfBatched with looped cusolver geqrf.

Collaborator Author


Sure, here are the comparison tables. All results are for float64.

Batches of 2-256 for sizes 2x2 to 1024x1024
Times are in microseconds (us).
|                               | cuBLAS    | cuSOLVER  |
|-------------------------------|-----------|-----------|
| torch.Size([2, 2])            | 37.5      | 44.1      |
| torch.Size([2, 2, 2])         | 24.1      | 30.3      |
| torch.Size([32, 2, 2])        | 24.8      | 421.3     |
| torch.Size([64, 2, 2])        | 25.2      | 840.2     |
| torch.Size([128, 2, 2])       | 24.7      | 1675.9    |
| torch.Size([256, 2, 2])       | 24.3      | 3348.5    |
| torch.Size([8, 8])            | 44.4      | 62.9      |
| torch.Size([2, 8, 8])         | 44.9      | 122.7     |
| torch.Size([32, 8, 8])        | 60.5      | 1909.0    |
| torch.Size([64, 8, 8])        | 60.5      | 3813.4    |
| torch.Size([128, 8, 8])       | 60.5      | 7625.6    |
| torch.Size([256, 8, 8])       | 60.6      | 15240.0   |
| torch.Size([16, 16])          | 156.5     | 126.5     |
| torch.Size([2, 16, 16])       | 157.6     | 249.4     |
| torch.Size([32, 16, 16])      | 206.0     | 3929.7    |
| torch.Size([64, 16, 16])      | 206.1     | 7856.0    |
| torch.Size([128, 16, 16])     | 206.2     | 15708.0   |
| torch.Size([256, 16, 16])     | 206.7     | 31410.1   |
| torch.Size([32, 32])          | 619.6     | 257.9     |
| torch.Size([2, 32, 32])       | 629.4     | 512.1     |
| torch.Size([32, 32, 32])      | 784.6     | 8136.3    |
| torch.Size([64, 32, 32])      | 784.9     | 16266.3   |
| torch.Size([128, 32, 32])     | 786.4     | 32655.6   |
| torch.Size([256, 32, 32])     | 789.3     | 65579.5   |
| torch.Size([64, 64])          | 2381.0    | 661.9     |
| torch.Size([2, 64, 64])       | 2448.6    | 1319.5    |
| torch.Size([32, 64, 64])      | 3096.4    | 21052.8   |
| torch.Size([64, 64, 64])      | 3101.0    | 42113.2   |
| torch.Size([128, 64, 64])     | 3115.0    | 84217.7   |
| torch.Size([256, 64, 64])     | 3466.1    | 168483.9  |
| torch.Size([128, 128])        | 11816.0   | 1572.4    |
| torch.Size([2, 128, 128])     | 12990.9   | 3140.5    |
| torch.Size([32, 128, 128])    | 14793.2   | 50225.2   |
| torch.Size([64, 128, 128])    | 16164.0   | 100416.3  |
| torch.Size([128, 128, 128])   | 17637.3   | 200814.2  |
| torch.Size([256, 128, 128])   | 18670.9   | 401571.3  |
| torch.Size([256, 256])        | 56018.7   | 2377.1    |
| torch.Size([2, 256, 256])     | 63733.5   | 4744.8    |
| torch.Size([32, 256, 256])    | 84046.0   | 76058.5   |
| torch.Size([64, 256, 256])    | 87702.1   | 152151.9  |
| torch.Size([128, 256, 256])   | 92019.4   | 304297.9  |
| torch.Size([256, 256, 256])   | 99104.9   | 609614.2  |
| torch.Size([512, 512])        | 302886.9  | 5869.6    |
| torch.Size([2, 512, 512])     | 352219.2  | 11742.1   |
| torch.Size([32, 512, 512])    | 556534.2  | 187686.5  |
| torch.Size([64, 512, 512])    | 571038.7  | 375346.6  |
| torch.Size([128, 512, 512])   | 597674.4  | 750685.9  |
| torch.Size([256, 512, 512])   | 679707.7  | 1501277.7 |
| torch.Size([1024, 1024])      | 1842582.2 | 15375.0   |
| torch.Size([2, 1024, 1024])   | 2598219.3 | 30746.9   |
| torch.Size([32, 1024, 1024])  | 4038294.7 | 491840.1  |
| torch.Size([64, 1024, 1024])  | 4096143.4 | 983641.5  |
| torch.Size([128, 1024, 1024]) | 4216505.4 | 1978362.1 |
| torch.Size([256, 1024, 1024]) | 5407082.8 | 3957738.8 |
Comparison for batches of 512x512
Times are in milliseconds (ms).
|                             | cuBLAS | cuSOLVER |
|-----------------------------|--------|----------|
| torch.Size([2, 512, 512])   |  347.3 | 13.7     |
| torch.Size([4, 512, 512])   | 393.9  | 23.3     |
| torch.Size([8, 512, 512])   | 523.6  | 46.5     |
| torch.Size([16, 512, 512])  | 536.5  | 93.0     |
| torch.Size([32, 512, 512])  | 549.9  | 186.0    |
| torch.Size([64, 512, 512])  | 564.3  | 374.8    |
| torch.Size([96, 512, 512])  | 576.9  | 562.1    |
| torch.Size([128, 512, 512]) | 594.0  | 749.3    |
| torch.Size([256, 512, 512]) | 676.1  | 1498.5   |
| torch.Size([96, 512, 512])  | 573.2  | 555.0    |
| torch.Size([97, 512, 512])  | 574.3  | 560.7    |
| torch.Size([98, 512, 512])  | 575.2  | 566.5    |
| torch.Size([99, 512, 512])  | 576.2  | 572.3    |
| torch.Size([100, 512, 512]) | 576.0  | 581.7    |
| torch.Size([101, 512, 512]) | 577.0  | 587.3    |
| torch.Size([102, 512, 512]) | 576.9  | 593.1    |
| torch.Size([103, 512, 512]) | 589.2  | 598.9    |
| torch.Size([104, 512, 512]) | 579.9  | 605.0    |
| torch.Size([105, 512, 512]) | 580.3  | 610.8    |
| torch.Size([106, 512, 512]) | 581.0  | 616.6    |
| torch.Size([107, 512, 512]) | 581.6  | 622.4    |
| torch.Size([108, 512, 512]) | 581.8  | 628.1    |
| torch.Size([109, 512, 512]) | 582.4  | 634.0    |
Comparison for batches of 256x256
Times are in milliseconds (ms).

|                            | cuBLAS | cuSOLVER |
|----------------------------|--------|----------|
| torch.Size([2, 256, 256])  | 72.7   | 6.5      |
| torch.Size([3, 256, 256])  | 63.4   | 7.0      |
| torch.Size([4, 256, 256])  | 63.9   | 9.4      |
| torch.Size([5, 256, 256])  | 65.1   | 11.7     |
| torch.Size([6, 256, 256])  | 66.9   | 14.0     |
| torch.Size([7, 256, 256])  | 69.5   | 16.4     |
| torch.Size([8, 256, 256])  | 71.6   | 18.8     |
| torch.Size([9, 256, 256])  | 71.5   | 21.1     |
| torch.Size([10, 256, 256]) | 70.8   | 23.5     |
| torch.Size([11, 256, 256]) | 71.7   | 26.0     |
| torch.Size([12, 256, 256]) | 72.0   | 28.4     |
| torch.Size([13, 256, 256]) | 74.6   | 30.7     |
| torch.Size([14, 256, 256]) | 76.4   | 33.1     |
| torch.Size([15, 256, 256]) | 77.3   | 35.4     |
| torch.Size([16, 256, 256]) | 78.1   | 37.8     |
| torch.Size([17, 256, 256]) | 78.5   | 40.2     |
| torch.Size([18, 256, 256]) | 79.2   | 42.5     |
| torch.Size([19, 256, 256]) | 79.6   | 44.9     |
| torch.Size([20, 256, 256]) | 80.4   | 47.2     |
| torch.Size([21, 256, 256]) | 80.6   | 49.6     |
| torch.Size([22, 256, 256]) | 81.5   | 52.0     |
| torch.Size([23, 256, 256]) | 81.8   | 54.3     |
| torch.Size([24, 256, 256]) | 81.8   | 56.7     |
| torch.Size([25, 256, 256]) | 82.4   | 59.0     |
| torch.Size([26, 256, 256]) | 81.7   | 61.4     |
| torch.Size([27, 256, 256]) | 82.1   | 63.7     |
| torch.Size([28, 256, 256]) | 81.7   | 66.1     |
| torch.Size([29, 256, 256]) | 82.8   | 68.4     |
| torch.Size([30, 256, 256]) | 82.8   | 70.8     |
| torch.Size([31, 256, 256]) | 83.2   | 73.2     |
| torch.Size([32, 256, 256]) | 83.5   | 75.6     |
| torch.Size([33, 256, 256]) | 83.5   | 77.3     |
| torch.Size([34, 256, 256]) | 82.9   | 79.5     |
| torch.Size([35, 256, 256]) | 83.0   | 81.9     |
| torch.Size([36, 256, 256]) | 83.1   | 84.3     |
| torch.Size([37, 256, 256]) | 83.3   | 86.6     |
| torch.Size([38, 256, 256]) | 83.4   | 88.9     |
| torch.Size([39, 256, 256]) | 83.5   | 91.3     |
| torch.Size([40, 256, 256]) | 84.2   | 93.6     |
| torch.Size([41, 256, 256]) | 84.0   | 95.9     |
| torch.Size([42, 256, 256]) | 84.0   | 98.3     |
| torch.Size([43, 256, 256]) | 84.0   | 100.6    |
| torch.Size([44, 256, 256]) | 84.1   | 103.4    |
| torch.Size([45, 256, 256]) | 84.4   | 106.1    |
| torch.Size([46, 256, 256]) | 84.9   | 108.5    |
| torch.Size([47, 256, 256]) | 85.2   | 110.8    |
| torch.Size([48, 256, 256]) | 85.3   | 113.3    |
| torch.Size([49, 256, 256]) | 85.3   | 115.6    |
Comparison for batches of 128x128
Times are in milliseconds (ms).

|                            | cuBLAS | cuSOLVER |
|----------------------------|--------|----------|
| torch.Size([2, 128, 128])  | 17.7   | 4.3      |
| torch.Size([3, 128, 128])  | 12.9   | 4.6      |
| torch.Size([4, 128, 128])  | 13.2   | 6.2      |
| torch.Size([5, 128, 128])  | 13.5   | 7.7      |
| torch.Size([6, 128, 128])  | 13.7   | 9.3      |
| torch.Size([7, 128, 128])  | 14.1   | 10.8     |
| torch.Size([8, 128, 128])  | 14.5   | 12.4     |
| torch.Size([9, 128, 128])  | 14.5   | 13.9     |
| torch.Size([10, 128, 128]) | 14.5   | 15.5     |
| torch.Size([11, 128, 128]) | 14.5   | 17.0     |
| torch.Size([12, 128, 128]) | 14.5   | 18.6     |
| torch.Size([13, 128, 128]) | 14.5   | 20.1     |
| torch.Size([14, 128, 128]) | 14.5   | 21.6     |
| torch.Size([15, 128, 128]) | 14.5   | 23.2     |

Collaborator


Thanks! That's very helpful.

@facebook-github-bot facebook-github-bot deleted the gh/ivanyashchuk/14/head branch April 30, 2021 14:16
crcrpar pushed a commit to crcrpar/pytorch that referenced this pull request May 7, 2021
Summary:
Pull Request resolved: pytorch#56253

`geqrfBatched` from cuBLAS is used if
```
(input.size(-2) <= 256 && batchCount(input) >= std::max<int64_t>(2, input.size(-2) / 16))
```

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D27960156

Pulled By: mruberry

fbshipit-source-id: 3e438eff01cbf7c7e075fb7aef709b97698a4650
krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
Labels

cla signed · Merged · module: linear algebra · open source

5 participants