[PyTorch] Thread parallel bmm across batch dim #59596
kimishpatel wants to merge 6 commits into gh/kimishpatel/69/base
Conversation
Parallelize batch matmul across batch dim. This was found to improve perf for some usecases on mobile. Differential Revision: [D26833417](https://our.internmc.facebook.com/intern/diff/D26833417/) [ghstack-poisoned]
💊 CI status (Dr. CI): as of commit 53d36ed, there are no failures.
Parallelize batch matmul across batch dim. This was found to improve perf for some usecases on mobile. Differential Revision: [D26833417](https://our.internmc.facebook.com/intern/diff/D26833417/) ghstack-source-id: 130770723 Pull Request resolved: #59596
ailzhang
left a comment
Stamping since this PR guarantees correctness, although we should keep in mind at::parallel_for in its current shape doesn't support preserving TLS state. So this hack should be discouraged in new call sites.
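The TLS caveat above can be seen with plain standard C++, independent of `at::parallel_for`: a `thread_local` flag set on the calling thread is not visible in a freshly spawned worker, because each thread gets its own copy. This is a minimal illustrative sketch (the variable name is a stand-in, not PyTorch's actual InferenceMode TLS):

```cpp
#include <thread>

thread_local bool inference_mode_tls = false;  // stand-in for the real TLS flag

// Set the flag on the calling thread, then check whether a freshly
// spawned worker can see it. It cannot: thread_local gives the worker
// its own default-initialized copy.
bool tls_visible_in_worker() {
    inference_mode_tls = true;  // "entered" on this thread only
    bool seen = true;
    std::thread worker([&seen] { seen = inference_mode_tls; });
    worker.join();
    return seen;
}
```

This is why the PR has to re-establish the relevant thread-local state inside the parallel region by hand.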
```cpp
auto r = self_or_result.select(0, b);
addmm_impl_cpu_(r, r, batch1.select(0, b), batch2.select(0, b), 0, 1);
/*
 * Inference mode multithreading is done because various thread local
```
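The quoted diff runs one `addmm` per batch index; the PR's idea is to hand those per-batch matmuls to the existing threadpool. A self-contained sketch of the same batch-dim split, using plain `std::thread` instead of `at::parallel_for` (the flat `std::vector` layout and all names here are illustrative, not ATen's):

```cpp
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// C[b] = A[b] * B[b] for one batch index; each A[b] is MxK, each B[b] is KxN.
void matmul_one_batch(const std::vector<float>& A, const std::vector<float>& B,
                      std::vector<float>& C, std::size_t b,
                      std::size_t M, std::size_t K, std::size_t N) {
    const float* a = A.data() + b * M * K;
    const float* bm = B.data() + b * K * N;
    float* c = C.data() + b * M * N;
    for (std::size_t i = 0; i < M; ++i)
        for (std::size_t j = 0; j < N; ++j) {
            float acc = 0.f;
            for (std::size_t k = 0; k < K; ++k)
                acc += a[i * K + k] * bm[k * N + j];
            c[i * N + j] = acc;
        }
}

// Batch matmul parallelized across the batch dimension: one worker per
// batch entry (a real implementation would chunk over a fixed pool).
std::vector<float> bmm_parallel(const std::vector<float>& A,
                                const std::vector<float>& B,
                                std::size_t Bn, std::size_t M,
                                std::size_t K, std::size_t N) {
    std::vector<float> C(Bn * M * N, 0.f);
    std::vector<std::thread> workers;
    for (std::size_t b = 0; b < Bn; ++b)
        workers.emplace_back(matmul_one_batch, std::cref(A), std::cref(B),
                             std::ref(C), b, M, K, N);
    for (auto& t : workers) t.join();
    return C;
}
```

Each worker writes a disjoint `M*N` slice of the output, so no synchronization beyond the final joins is needed; that disjointness is what makes the batch dimension the natural split point.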
Would you mind also opening an issue documenting what the new API would do there? :D
Also I'd appreciate a big warning here: if someone wants to add another instance of this, they should do it properly by using the new API. :D
The BLAS library that is called inside... Also, is this issue due to the fact that the BLAS used on mobile is not multithreaded? If that's the case, wouldn't we prefer to add this to that BLAS only and not at such a high level?
@albanD is this for the OSS build? If so, can you point me to the cmake config? We can disable this for the server side if you want. Also, we most likely use the MKL path. On the other hand, having BLAS run its own threadpool alongside PyTorch's own threadpool will already be detrimental to performance.
On the mobile side, yes, we don't use BLAS's multithreading for the same reason: too many threads can cause a perf regression, as threads may compete for scheduling.
At least for OSS, yes; I don't know which BLAS lib is used internally.
I am not sure what you're looking for. The user controls that parallelism with
That depends on what you compiled with.
That's a known issue. But in practice, the PyTorch threadpool is rarely used, and so the BLAS threadpool is the only one doing actual work.
Would you be able to make this change inside the mobile BLAS binding directly, to avoid having divergent call paths for mobile/non-mobile in this higher-level function?
@albanD I don't follow this. In the case of mobile, 1) we don't want a multithreaded BLAS lib, and 2) we want to use the threadpool that is used by other ops, specifically convs/linear layers.
Oh, sorry, I was wondering if there isn't a better place to put this code in this stack. In particular, if you do it lower (where there are no more Tensor ops), you don't need all the special logic to handle the fact that parallel_for does not propagate TLS.
I tried an implementation which does not require propagation of TLS in D28873227. But it excludes other things that can be parallelized, for example the select op. That cuts the benefit in half. Also, from the discussion it seems that we don't want parallelization on non-mobile builds.
I don't think we do, but that would still simplify the mobile-only code :)
So I can make this mobile-only. But what would still simplify the mobile-only code? As I said, doing this at a lower level is not going to be as much of a perf win.
Pull Request resolved: #59596 Parallelize batch matmul across batch dim. This was found to improve perf for some usecases on mobile. ghstack-source-id: 130899667 Differential Revision: [D26833417](https://our.internmc.facebook.com/intern/diff/D26833417/)
Parallelize batch matmul across batch dim. This was found to improve perf for some usecases on mobile.

Benchmarking results. BMM benchmark via operator benchmark on a Samsung S8 US phone:

| B | M | N | K | | No Threading | Threads=4 |
| -- | -- | -- | -- | -- | -- | -- |
| 2 | 8 | 256 | 16 | 8448 | 0.062 | 0.095 |
| 2 | 8 | 256 | 32 | 16896 | 0.077 | 0.099 |
| 2 | 8 | 16 | 16 | 768 | 0.048 | 0.07 |
| 2 | 8 | 16 | 32 | 1536 | 0.048 | 0.075 |
| 2 | 256 | 256 | 16 | 16384 | 0.476 | 0.27 |
| 2 | 256 | 256 | 32 | 32768 | 0.79 | 0.477 |
| 2 | 256 | 16 | 16 | 8704 | 0.08 | 0.101 |
| 2 | 256 | 16 | 32 | 17408 | 0.11 | 0.134 |
| 100 | 8 | 256 | 16 | 422400 | 2.269 | 0.6945 |
| 100 | 8 | 256 | 32 | 844800 | 3.111 | 0.92 |
| 100 | 8 | 16 | 16 | 38400 | 0.936 | 0.425 |
| 100 | 8 | 16 | 32 | 76800 | 0.987 | 0.45 |
| 100 | 256 | 256 | 16 | 819200 | 27.601 | 10.7545 |
| 100 | 256 | 256 | 32 | 1638400 | 43.088 | 13.756 |
| 100 | 256 | 16 | 16 | 435200 | 2.734 | 0.802 |
| 100 | 256 | 16 | 32 | 870400 | 3.9125 | 1.174 |
| 4 | 32 | 32 | 16 | 4096 | 0.082 | 0.08 |
| 4 | 32 | 32 | 32 | 8192 | 0.091 | 0.082 |
| 4 | 32 | 64 | 16 | 6144 | 0.092 | 0.101 |
| 4 | 32 | 64 | 32 | 12288 | 0.114 | 0.093 |
| 4 | 64 | 32 | 16 | 6144 | 0.0985 | 0.0845 |
| 4 | 64 | 32 | 32 | 12288 | 0.116 | 0.101 |
| 4 | 64 | 64 | 16 | 8192 | 0.119 | 0.093 |
| 4 | 64 | 64 | 32 | 16384 | 0.161 | 0.127 |
| 8 | 32 | 32 | 16 | 8192 | 0.129 | 0.1 |
| 8 | 32 | 32 | 32 | 16384 | 0.153 | 0.103 |
| 8 | 32 | 64 | 16 | 12288 | 0.155 | 0.117 |
| 8 | 32 | 64 | 32 | 24576 | 0.199 | 0.1465 |
| 8 | 64 | 32 | 16 | 12288 | 0.159 | 0.116 |
| 8 | 64 | 32 | 32 | 24576 | 0.204 | 0.138 |
| 8 | 64 | 64 | 16 | 16384 | 0.209 | 0.156 |
| 8 | 64 | 64 | 32 | 32768 | 0.292 | 0.256 |

Differential Revision: [D26833417](https://our.internmc.facebook.com/intern/diff/D26833417/)
Pull Request resolved: #59596 Parallelize batch matmul across batch dim. This was found to improve perf for some usecases on mobile. ghstack-source-id: 130946745 Differential Revision: [D26833417](https://our.internmc.facebook.com/intern/diff/D26833417/)
albanD
left a comment
Thanks for the update.
I think we can remove a bit of code duplication, but the overall idea (except the `&=`) looks OK to me.
test/test_linalg.py (outdated)
```python
@onlyCPU
@torch.inference_mode()
def test_bmm_multithreaded(self, device):
```
Out of curiosity, how long does this test take to run?
I didn't quite check that. I can report back once I do.
If it takes more than 15s, we should consider marking it as a slow test.
```cpp
// bmm_test: operator benchmark under
// benchmarks/operator_benchmarks/pt/bmm_test.py. Ran this benchmark for
// various matrix sizes on Samsung S8U.
enable_multithreaded_bmm =
```
You want to use `&=` here, no?
Oh yes, thanks for the catch.
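The bug the review caught is easy to reproduce in isolation: when an enable flag is built up from several gating conditions, a plain `=` silently discards the earlier ones, while `&=` keeps them all. A minimal sketch (hypothetical condition names, not the PR's actual ones):

```cpp
// Plain assignment: the second condition overwrites the first.
bool enable_with_assign(bool mobile_build, bool inference_mode) {
    bool enable = mobile_build;
    enable = inference_mode;   // bug: mobile_build is thrown away
    return enable;
}

// Compound assignment: both conditions must hold for the flag to stay set.
bool enable_with_and_assign(bool mobile_build, bool inference_mode) {
    bool enable = mobile_build;
    enable &= inference_mode;  // correct: logical AND of all gates
    return enable;
}
```

With `mobile_build == false` the buggy version still returns `true` whenever `inference_mode` is `true`, which is exactly the kind of silent over-enabling the reviewer flagged.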
test/test_linalg.py (outdated)
```python
@onlyCPU
@torch.inference_mode()
def test_bmm_multithreaded(self, device):
```
This file is actually only about torch.linalg.* functions. Could you move this test to test_torch.py?
Why does it make sense for it to be in test_torch?
Everything that has no better place goes there. In particular, tests for operators that don't fit any better category go there.
```cpp
bool enable_multithreaded_bmm{false};
#ifdef C10_MOBILE
/*
 * Inference mode multithreading is done because various thread local
```
nit: what you mean here is "We only do multithreading when Inference Mode is enabled", right? That could be clarified a bit, I think.
Not sure what you mean. It already says why it is enabled in inference mode. Or do you mean to capture the fact that having the autograd engine running multithreaded can be an issue?
It is more that "Inference mode multithreading is done because" is not clear to me at first read. I would replace it with "We only do multithreading when Inference Mode is enabled because".
```cpp
    && self_or_result.is_contiguous()) {
  at::native::_baddbmm_mkl_(self_or_result, batch1, batch2, beta, alpha);
} else { // split along batch dimension
  bool enable_multithreaded_bmm{false};
```
Could you move this into the #else of the macro and make it const?
That would allow the compiler to completely prune the code in the conditions below: https://godbolt.org/z/8894Mdxd1
Oh interesting, I did not realize that without const the compiler would not const-propagate. Seems strange, but sure, I can make the change.
If you do "-Os" then it simplifies a lot. Nonetheless, I will make the change.
It does with optimization enabled, but not for debug builds from what I see.
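The suggestion above can be sketched in a few lines: if the flag is a `const` defined inside the `#ifdef`/`#else`, the condition is a compile-time constant on every build, so the compiler can prune the dead branch entirely even in debug builds. This is an illustrative stand-alone version, not the PR's actual code:

```cpp
// C10_MOBILE is the real mobile-build macro; here it simply selects
// which constant value the flag gets at compile time.
#ifdef C10_MOBILE
const bool enable_multithreaded_bmm = true;
#else
const bool enable_multithreaded_bmm = false;
#endif

// Because the flag is const, the untaken branch below is provably dead
// and can be removed by the compiler without relying on optimization
// level to const-propagate a mutable local.
int bmm_path() {
    if (enable_multithreaded_bmm) {
        return 1;  // multithreaded path
    }
    return 0;      // serial path
}
```

On a non-mobile build (no `C10_MOBILE`), `bmm_path()` always takes the serial branch, and the multithreaded branch never reaches the binary.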
albanD
left a comment
LGTM
We might want to add the slowtest on top of that if it takes too long to run though.
Pull Request resolved: #59596 Parallelize batch matmul across batch dim. This was found to improve perf for some usecases on mobile. ghstack-source-id: 130966058 Differential Revision: [D26833417](https://our.internmc.facebook.com/intern/diff/D26833417/)
Pull Request resolved: #59596 Parallelize batch matmul across batch dim. This was found to improve perf for some usecases on mobile. ghstack-source-id: 130989569 Differential Revision: [D26833417](https://our.internmc.facebook.com/intern/diff/D26833417/)
This pull request has been merged in 4f79270.
Stack from ghstack:
Parallelize batch matmul across batch dim. This was found to improve perf for
some usecases on mobile.
Benchmarking results:
BMM benchmark via operator benchmark on Samsung S8 US phone:
It seems that even at smaller sizes we are almost on par or a little better; only small batch size (2) seems to be worse.
Differential Revision: D26833417