
[PyTorch ] Thread parallel bmm across batch dim#59596

Closed
kimishpatel wants to merge 6 commits into gh/kimishpatel/69/base from gh/kimishpatel/69/head

Conversation

@kimishpatel
Contributor

@kimishpatel kimishpatel commented Jun 7, 2021

Stack from ghstack:

Parallelize batch matmul across the batch dimension. This was found to improve perf for
some use cases on mobile.
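The idea can be sketched outside of ATen as follows. This is a minimal standalone illustration (plain std::thread over flat row-major buffers, with hypothetical names), not the actual implementation, which parallelizes calls to addmm_impl_cpu_ per batch slice via at::parallel_for:

```cpp
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Naive batched matmul: out[b] = A[b] (MxK) * B[b] (KxN), parallelized
// across the batch dimension; each worker handles a strided set of batches.
void bmm_batch_parallel(const std::vector<std::vector<float>>& A,
                        const std::vector<std::vector<float>>& B,
                        std::vector<std::vector<float>>& out,
                        std::size_t M, std::size_t K, std::size_t N,
                        std::size_t num_threads) {
  const std::size_t batch = A.size();
  std::vector<std::thread> workers;
  for (std::size_t t = 0; t < num_threads; ++t) {
    workers.emplace_back([&, t] {
      // Worker t handles batches t, t + num_threads, t + 2*num_threads, ...
      for (std::size_t b = t; b < batch; b += num_threads) {
        for (std::size_t i = 0; i < M; ++i) {
          for (std::size_t j = 0; j < N; ++j) {
            float acc = 0.f;
            for (std::size_t k = 0; k < K; ++k) {
              acc += A[b][i * K + k] * B[b][k * N + j];
            }
            out[b][i * N + j] = acc;
          }
        }
      }
    });
  }
  for (auto& w : workers) w.join();
}
```

Each batch writes to a disjoint output slice, so no synchronization is needed beyond the final join, which is what makes splitting on the batch dimension attractive.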

Benchmarking results:
BMM benchmark via operator benchmark on Samsung S8 US phone:

B | M | N | K |   | No Threading | Threads=4
-- | -- | -- | -- | -- | -- | --
2 | 8 | 256 | 16 | 8448 | 0.062 | 0.095
2 | 8 | 256 | 32 | 16896 | 0.077 | 0.099
2 | 8 | 16 | 16 | 768 | 0.048 | 0.07
2 | 8 | 16 | 32 | 1536 | 0.048 | 0.075
2 | 256 | 256 | 16 | 16384 | 0.476 | 0.27
2 | 256 | 256 | 32 | 32768 | 0.79 | 0.477
2 | 256 | 16 | 16 | 8704 | 0.08 | 0.101
2 | 256 | 16 | 32 | 17408 | 0.11 | 0.134
100 | 8 | 256 | 16 | 422400 | 2.269 | 0.6945
100 | 8 | 256 | 32 | 844800 | 3.111 | 0.92
100 | 8 | 16 | 16 | 38400 | 0.936 | 0.425
100 | 8 | 16 | 32 | 76800 | 0.987 | 0.45
100 | 256 | 256 | 16 | 819200 | 27.601 | 10.7545
100 | 256 | 256 | 32 | 1638400 | 43.088 | 13.756
100 | 256 | 16 | 16 | 435200 | 2.734 | 0.802
100 | 256 | 16 | 32 | 870400 | 3.9125 | 1.174
  |   |   |   |   |   |  
4 | 32 | 32 | 16 | 4096 | 0.082 | 0.08
4 | 32 | 32 | 32 | 8192 | 0.091 | 0.082
4 | 32 | 64 | 16 | 6144 | 0.092 | 0.101
4 | 32 | 64 | 32 | 12288 | 0.114 | 0.093
4 | 64 | 32 | 16 | 6144 | 0.0985 | 0.0845
4 | 64 | 32 | 32 | 12288 | 0.116 | 0.101
4 | 64 | 64 | 16 | 8192 | 0.119 | 0.093
4 | 64 | 64 | 32 | 16384 | 0.161 | 0.127
8 | 32 | 32 | 16 | 8192 | 0.129 | 0.1
8 | 32 | 32 | 32 | 16384 | 0.153 | 0.103
8 | 32 | 64 | 16 | 12288 | 0.155 | 0.117
8 | 32 | 64 | 32 | 24576 | 0.199 | 0.1465
8 | 64 | 32 | 16 | 12288 | 0.159 | 0.116
8 | 64 | 32 | 32 | 24576 | 0.204 | 0.138
8 | 64 | 64 | 16 | 16384 | 0.209 | 0.156
8 | 64 | 64 | 32 | 32768 | 0.292 | 0.256

It seems that even at smaller sizes we are almost on par or slightly better. Only the smallest batch size (2) seems to be worse.

B | M | N | K |   | No Threading | Threads=4 | Speedup
-- | -- | -- | -- | -- | -- | -- | --
4 | 4 | 16 | 16 | 1280 | 0.068 | 0.065 | 1.04615385
4 | 4 | 32 | 32 | 4608 | 0.068 | 0.07 | 0.97142857
4 | 4 | 32 | 64 | 9216 | 0.072 | 0.072 | 1
4 | 4 | 64 | 32 | 8704 | 0.073 | 0.076 | 0.96052632
4 | 4 | 64 | 64 | 17408 | 0.083 | 0.08 | 1.0375
4 | 8 | 32 | 32 | 5120 | 0.083 | 0.072 | 1.15277778
4 | 8 | 32 | 64 | 10240 | 0.079 | 0.079 | 1
4 | 8 | 64 | 32 | 9216 | 0.091 | 0.075 | 1.21333333
4 | 8 | 64 | 64 | 18432 | 0.14 | 0.075 | 1.86666667

Differential Revision: D26833417

@facebook-github-bot
Contributor

facebook-github-bot commented Jun 7, 2021

💊 CI failures summary and remediations

As of commit 53d36ed (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

kimishpatel added a commit that referenced this pull request Jun 7, 2021
Parallelize batch matmul across batch dim. This was found to improve perf for
some usecases on mobile.

Differential Revision: [D26833417](https://our.internmc.facebook.com/intern/diff/D26833417/)

ghstack-source-id: 130770723
Pull Request resolved: #59596
@kimishpatel kimishpatel requested review from ailzhang and ezyang June 7, 2021 21:59
Contributor

@ailzhang ailzhang left a comment


Stamping since this PR guarantees correctness, although we should keep in mind that at::parallel_for in its current shape doesn't support preserving TLS state, so this hack should be discouraged in new call sites.
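The TLS concern can be illustrated with a small standalone sketch. tls_inference_mode and run_parallel_sum below are hypothetical stand-ins, not PyTorch APIs; the point is only that thread_local state is not inherited by new threads, so a parallel region must capture it by value and re-establish it in each worker:

```cpp
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical stand-in for a dispatcher flag such as InferenceMode.
thread_local bool tls_inference_mode = false;

int run_parallel_sum(const std::vector<int>& data, int num_threads) {
  // Capture the launching thread's TLS state by value...
  const bool inference = tls_inference_mode;
  std::vector<int> partial(num_threads, 0);
  std::vector<std::thread> workers;
  for (int t = 0; t < num_threads; ++t) {
    workers.emplace_back([&, t, inference] {
      // ...and re-set it inside the worker before doing any work;
      // a freshly spawned thread starts with the default value (false).
      tls_inference_mode = inference;
      for (std::size_t i = t; i < data.size(); i += num_threads) {
        partial[t] += tls_inference_mode ? data[i] : 0;
      }
    });
  }
  for (auto& w : workers) w.join();
  int total = 0;
  for (int p : partial) total += p;
  return total;
}
```

Without the explicit capture-and-restore, every worker would observe the default-initialized flag regardless of what the caller set.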

auto r = self_or_result.select(0, b);
addmm_impl_cpu_(r, r, batch1.select(0, b), batch2.select(0, b), 0, 1);
/*
* Inference mode multithreading is done because various thread local
Contributor

Would you mind also opening an issue and documenting what the new API would do there? :D

Contributor

Also, I'd appreciate a big warning here: if someone wants to add another instance of this, they should do it properly by using the new API. :D

@albanD
Collaborator

albanD commented Jun 8, 2021

The BLAS library that is called inside addmm_impl_cpu_ is already thread-parallel on CPU. Is it possible that this would lead to oversubscription of threads on machines with a multithreaded BLAS (and thus a big slowdown)?

Also, is this issue due to the fact that the BLAS used on mobile is not multithreaded? If that's the case, wouldn't we prefer to add this to that BLAS only and not at such a high level?
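The oversubscription worry can be demonstrated in isolation. total_workers is a hypothetical toy, assuming an outer batch-parallel loop whose body itself spawns threads (as a multithreaded BLAS call would): the thread counts multiply.

```cpp
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

// An outer loop of `outer` threads whose body spawns `inner` threads
// of its own creates outer * inner workers in total -- the
// oversubscription scenario when nesting a parallel loop over a
// multithreaded BLAS.
int total_workers(int outer, int inner) {
  std::atomic<int> count{0};
  std::vector<std::thread> outer_threads;
  for (int i = 0; i < outer; ++i) {
    outer_threads.emplace_back([&] {
      std::vector<std::thread> inner_threads;
      for (int j = 0; j < inner; ++j) {
        // Each inner thread stands in for one BLAS worker.
        inner_threads.emplace_back([&] { count.fetch_add(1); });
      }
      for (auto& t : inner_threads) t.join();
    });
  }
  for (auto& t : outer_threads) t.join();
  return count.load();
}
```

With, say, 4 outer workers over a 4-thread BLAS, 16 compute threads compete for the cores, which is why one of the two levels of parallelism is usually disabled.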

@kimishpatel
Contributor Author

The BLAS library that is called inside addmm_impl_cpu_ is already thread-parallel on CPU. Is it possible that this would lead to oversubscription of threads on machines with a multithreaded BLAS (and thus a big slowdown)?

@albanD is this for the OSS build? If so, can you point me to the cmake config? We can disable this for the server side if you want. Also, we most likely use the MKL path. On the other hand, having BLAS run its own threadpool alongside PyTorch's own threadpool will already be detrimental to performance.

Also, is this issue due to the fact that the BLAS used on mobile is not multithreaded? If that's the case, wouldn't we prefer to add this to that BLAS only and not at such a high level?

On the mobile side, yes, we don't use BLAS's multithreading for the same reason: too many threads can cause a perf regression as threads may compete for scheduling.

@albanD
Collaborator

albanD commented Jun 8, 2021

is this for the OSS build?

For OSS at least, yes; I don't know which BLAS lib is used internally.

If so, can you point me to the cmake config?

I am not sure what you're looking for. The user controls that parallelism with the OMP_NUM_THREADS env variable or https://pytorch.org/docs/stable/generated/torch.set_num_threads.html?highlight=set_num_threads#torch.set_num_threads.
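For illustration, a minimal sketch of how such an env variable typically feeds a library's thread count (resolve_num_threads is hypothetical, not a PyTorch or OpenMP API; it just mirrors the common convention that an env override beats the hardware default):

```cpp
#include <cassert>
#include <cstdlib>
#include <thread>

// Resolve a thread count: a positive integer in the named env variable
// (e.g. something like OMP_NUM_THREADS) overrides the hardware default.
int resolve_num_threads(const char* env_name) {
  if (const char* v = std::getenv(env_name)) {
    int n = std::atoi(v);
    if (n > 0) return n;
  }
  unsigned hw = std::thread::hardware_concurrency();
  return hw > 0 ? static_cast<int>(hw) : 1;
}
```

torch.set_num_threads exposes the same knob programmatically at the Python level.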

Also we most likely use MKL path.

That depends on what you compiled with.
But even when MKL is used, the code block above will only be used in a very narrow case, and the code you are changing will be called otherwise (calling again into the non-batched version of MKL).

On the other hand, having BLAS run its own threadpool alongside PyTorch's own threadpool will already be detrimental to performance.

That's a known issue. But in practice the PyTorch threadpool is rarely used, so the BLAS threadpool is the only one doing actual work.

On the mobile side, yes, we don't use BLAS's multithreading for the same reason: too many threads can cause a perf regression as threads may compete for scheduling.

Would you be able to make this change inside the mobile BLAS binding directly, to avoid having divergent call paths for mobile/non-mobile in this higher-level function?

@kimishpatel
Contributor Author

Would you be able to make this change inside the mobile BLAS binding directly, to avoid having divergent call paths for mobile/non-mobile in this higher-level function?

@albanD I don't follow this. In the case of mobile, 1) we don't want a multithreaded BLAS lib and 2) we want to use the threadpool that is used by other ops, specifically convs/linear layers.
On mobile, most of the parallelization across threads happens via pthreadpool, which PyTorch owns.

@albanD
Collaborator

albanD commented Jun 8, 2021

Oh, sorry.
Looking at the code, the addmm_impl_cpu_ function that you call ends up calling at::native::cpublas::gemm, which calls into dgemm_ from whatever BLAS library was available at compile time (or a manual implementation if nothing is provided).

I was wondering if there isn't a better place to put this code in this stack. In particular, if you do it lower down (where there are no more Tensor ops), you don't need all the special logic to handle the fact that parallel_for does not propagate TLS.

@kimishpatel
Contributor Author

Oh, sorry.
Looking at the code, the addmm_impl_cpu_ function that you call ends up calling at::native::cpublas::gemm, which calls into dgemm_ from whatever BLAS library was available at compile time (or a manual implementation if nothing is provided).

I was wondering if there isn't a better place to put this code in this stack. In particular, if you do it lower down (where there are no more Tensor ops), you don't need all the special logic to handle the fact that parallel_for does not propagate TLS.

I tried an implementation that does not require propagation of TLS in D28873227, but it excludes other things that can be parallelized, e.g. the select op. That cuts the benefit in half.

Also, from the discussion it seems that we don't want parallelization on non-mobile builds.

@albanD
Collaborator

albanD commented Jun 8, 2021

Also, from the discussion it seems that we don't want parallelization on non-mobile builds.

I don't think we do, but that would still simplify the mobile-only code :)

@kimishpatel
Contributor Author

kimishpatel commented Jun 8, 2021

Also, from the discussion it seems that we don't want parallelization on non-mobile builds.

I don't think we do, but that would still simplify the mobile-only code :)

So I can make this mobile-only. But what would still simplify the mobile-only code? As I said, doing this at a lower level is not going to be as much of a perf win.

kimishpatel added a commit that referenced this pull request Jun 8, 2021
Pull Request resolved: #59596


Parallelize batch matmul across batch dim. This was found to improve perf for
some usecases on mobile.
ghstack-source-id: 130899667

Differential Revision: [D26833417](https://our.internmc.facebook.com/intern/diff/D26833417/)
kimishpatel added a commit that referenced this pull request Jun 9, 2021
Pull Request resolved: #59596


Parallelize batch matmul across batch dim. This was found to improve perf for
some usecases on mobile.
ghstack-source-id: 130946745

Differential Revision: [D26833417](https://our.internmc.facebook.com/intern/diff/D26833417/)
Collaborator

@albanD albanD left a comment


Thanks for the update.
I think we can remove a bit of code duplication, but the overall idea (except the &=) looks OK to me.


@onlyCPU
@torch.inference_mode()
def test_bmm_multithreaded(self, device):
Collaborator

Out of curiosity, how long does this test take to run?

Contributor Author

I didn't check that. I can report back once I do.

Collaborator

If it takes more than 15s, we should consider marking it as a slow test.

// bmm_test: operator benchmark under
// benchmarks/operator_benchmarks/pt/bmm_test.py.
// Ran this benchmark for various matrix sizes on Samsung S8U.
enable_multithreaded_bmm =
Collaborator

You want to use &= here, no?

Contributor Author

Oh yes. Thanks for the catch.
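For the record, the behavioral difference between = and &= when gating a feature on several conditions (the helper names below are hypothetical, just to isolate the bug class):

```cpp
#include <cassert>

// With plain `=` the second condition silently overwrites the first.
bool gate_with_assign(bool cond_a, bool cond_b) {
  bool enabled = cond_a;
  enabled = cond_b;   // bug: cond_a is discarded
  return enabled;
}

// With `&=` the conditions accumulate: the flag stays true only if
// every condition held.
bool gate_with_and_assign(bool cond_a, bool cond_b) {
  bool enabled = cond_a;
  enabled &= cond_b;
  return enabled;
}
```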


@onlyCPU
@torch.inference_mode()
def test_bmm_multithreaded(self, device):
Collaborator

This file is actually only about torch.linalg.* functions. Could you move this test to test_torch.py?

Contributor Author

Why does it make sense for it to be in test_torch?

Collaborator

Everything that has no better place goes there. In particular, tests for operators that don't fit in any better category go there.

bool enable_multithreaded_bmm{false};
#ifdef C10_MOBILE
/*
* Inference mode multithreading is done because various thread local
Collaborator

nit: what you mean here is "We only do multithreading when inference mode is enabled", right? That could be clarified a bit, I think.

Contributor Author

Not sure what you mean. It already clarifies why it is enabled in inference mode. Or do you mean to capture the fact that running the autograd engine multithreaded can be an issue?

Collaborator

It is more that "Inference mode multithreading is done because" is not clear to me at first read. I would replace it with "We only do multithreading when Inference Mode is enabled because".

&& self_or_result.is_contiguous()) {
at::native::_baddbmm_mkl_(self_or_result, batch1, batch2, beta, alpha);
} else { // split along batch dimension
bool enable_multithreaded_bmm{false};
Collaborator

Could you move this into the #else of the macro and make it const?
That would allow the compiler to completely prune the code in the conditions below: https://godbolt.org/z/8894Mdxd1
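A standalone sketch of the suggested pattern (bmm_dispatch is a hypothetical stand-in; the return values just mark which path ran). With the flag declared const on the non-mobile path, an optimizing compiler can prove the branch dead and prune it entirely:

```cpp
#include <cassert>

int bmm_dispatch(int batch) {
#ifdef C10_MOBILE
  // Mobile: runtime heuristic decides whether to split across the batch.
  bool enable_multithreaded_bmm = batch >= 4;
#else
  // Non-mobile: const false lets the compiler fold the branch away.
  const bool enable_multithreaded_bmm = false;
  (void)batch;
#endif
  if (enable_multithreaded_bmm) {
    return 2;  // stand-in for the batch-parallel path
  }
  return 1;    // stand-in for the serial path
}
```

Without const, the compiler must keep the branch in unoptimized builds, since it cannot assume the variable is never reassigned.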

Contributor Author

Oh, interesting. I did not realize that without const the compiler would not 'const propagate'. Seems strange, but sure, I can make the change.

Contributor Author

If you do "-Os" then it simplifies a lot. Nonetheless I will make the change.

Collaborator

It does with optimizations enabled, but not for debug builds, from what I see.

Collaborator

@albanD albanD left a comment


LGTM

We might want to add the slowTest marker on top of that if it takes too long to run, though.

kimishpatel added a commit that referenced this pull request Jun 9, 2021
Pull Request resolved: #59596


Parallelize batch matmul across batch dim. This was found to improve perf for
some usecases on mobile.
ghstack-source-id: 130966058

Differential Revision: [D26833417](https://our.internmc.facebook.com/intern/diff/D26833417/)
kimishpatel added a commit that referenced this pull request Jun 9, 2021
Pull Request resolved: #59596


Parallelize batch matmul across batch dim. This was found to improve perf for
some usecases on mobile.
ghstack-source-id: 130989569

Differential Revision: [D26833417](https://our.internmc.facebook.com/intern/diff/D26833417/)
@facebook-github-bot
Contributor

This pull request has been merged in 4f79270.
