
Migrate addmm, addbmm and THBlas_gemm to ATen #40927

Closed

peterbell10 wants to merge 5 commits into pytorch:master from peterbell10:addmm-aten

Conversation

@peterbell10
Collaborator

Closes #24679, closes #24678

`addbmm` depends on `addmm`, so it needed to be ported at the same time. I also removed `THTensor_(baddbmm)`, which I noticed had already been ported and so was just dead code.

After having already written this code, I had to fix merge conflicts with #40354, which revealed there was already an established place for CPU BLAS routines in ATen. However, the version there doesn't make use of ATen's AVX dispatching, so I thought I'd wait for comment before migrating this into that style.

@dr-ci

dr-ci Bot commented Jul 2, 2020

💊 CI failures summary and remediations

As of commit 0f8a3f0 (more details on the Dr. CI page):



❄️ 1 failure tentatively classified as flaky, but reruns have not yet been triggered to confirm:

See CircleCI build binary_linux_libtorch_3_7m_cpu_gcc5_4_cxx11-abi_shared-with-deps_build (1/1)

Step: "Install unbuffer and ts" (full log | diagnosis details | 🔁 rerun) ❄️

E: Failed to fetch https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/Packages Writing more data than expected (1196016 > 1190062)
0% [Working] 0% [8 InRelease gpgv 247 kB] [Waiting for headers]                                                    Ign:9 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64  Packages 
 0% [8 InRelease gpgv 247 kB] [Waiting for headers]                                                    Get:9 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64  Packages [276 kB] 
 0% [8 InRelease gpgv 247 kB] [Waiting for headers] [9 Packages 0 B/298 kB 0%]                                                                               Ign:9 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64  Packages 
                                                                              0% [8 InRelease gpgv 247 kB] [Waiting for headers]                                                    Hit:10 http://archive.ubuntu.com/ubuntu xenial-updates InRelease 
0% [Waiting for headers] 0% [10 InRelease gpgv 109 kB] [Waiting for headers]                                                     Get:9 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64  Packages [1,190 kB] 
 0% [10 InRelease gpgv 109 kB] [Waiting for headers] [9 Packages 16.4 kB/1,262 k                                                                                 Err:9 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64  Packages 
  Writing more data than expected (1196016 > 1190062) 
                                                    0% [Waiting for headers]                          Hit:11 http://archive.ubuntu.com/ubuntu xenial-backports InRelease 
                              20% [Working]               Fetched 836 B in 0s (1,354 B/s) 
Reading package lists... 99%  Reading package lists... Done  
E: Failed to fetch https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/Packages  Writing more data than expected (1196016 > 1190062) 
E: Some index files failed to download. They have been ignored, or old ones used instead. 

Extra GitHub checks: 1 failed


This comment was automatically generated by Dr. CI. Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

This comment has been revised 15 times.

@vadimkantorov
Contributor

vadimkantorov commented Jul 3, 2020

Does addmm support 3d (4d, 5d, ...) input tensors? The use case is #39661

@peterbell10
Collaborator Author

addmm will raise if the inputs are not 2d tensors, the same as the TH version. Isn't 3d (batched) gemm what torch.baddbmm does already?
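
For reference, a minimal illustration of the distinction being discussed (plain torch; `addmm` is strictly 2d, `baddbmm` is its batched 3d counterpart):

```python
import torch

# addmm: beta*input + alpha*(mat1 @ mat2), all operands strictly 2d
inp = torch.randn(3, 5)
mat1, mat2 = torch.randn(3, 4), torch.randn(4, 5)
out = torch.addmm(inp, mat1, mat2)      # OK

# baddbmm: the batched variant, all operands strictly 3d
binp = torch.randn(10, 3, 5)
b1, b2 = torch.randn(10, 3, 4), torch.randn(10, 4, 5)
bout = torch.baddbmm(binp, b1, b2)      # OK

# addmm rejects batched inputs, matching the old TH behaviour
try:
    torch.addmm(binp, b1, b2)
except RuntimeError:
    print("addmm raised on 3d operands, as expected")
```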

@vadimkantorov
Contributor

vadimkantorov commented Jul 4, 2020

I guess addbmm requires both the input and the matrix to have the same batch dimension, which breaks that use case. In that use case, the input tensor should have an arbitrary batch dimension, and the transform should be just a matrix.

@vadimkantorov
Contributor

Sorry, I got confused. I see that you are referring to baddbmm. I need to check if broadcasting works well for this case.

But for >3d-shaped tensors, F.linear would still have to reshape the input and output. Basically, we don't currently have a simple [B1xB2x...xC] @ [CxZ] + [Z] fused matmul+bias method. matmul works with arbitrary shapes but doesn't fuse the bias; baddbmm fuses the bias but only works for 3d tensors.
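
To make the gap concrete, here is the reshape workaround being alluded to: flatten the leading batch dims so a single 2d addmm can fuse the bias, then restore the shape (a sketch; `fused_linear_via_addmm` is a hypothetical helper name):

```python
import torch

def fused_linear_via_addmm(x, weight, bias):
    """x: [B1, B2, ..., C], weight: [C, Z], bias: [Z] -> [B1, B2, ..., Z]."""
    lead = x.shape[:-1]
    # Collapse all leading batch dims into one so the 2d fused
    # matmul+bias (addmm) applies, then restore the original shape.
    out = torch.addmm(bias, x.reshape(-1, x.shape[-1]), weight)
    return out.reshape(*lead, weight.shape[-1])

x = torch.randn(2, 3, 8)
w, b = torch.randn(8, 4), torch.randn(4)
assert torch.allclose(fused_linear_via_addmm(x, w, b), x @ w + b)
```

The extra reshapes are exactly what a native [B1xB2x...xC] @ [CxZ] + [Z] method would avoid.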

@ezyang
Contributor

ezyang commented Jul 6, 2020

I suggest we leave this enhancement for a follow up PR.

@ngimel ngimel self-requested a review July 7, 2020 00:27
@ngimel ngimel added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jul 7, 2020
Comment thread test/test_torch.py

 # TODO: update this once torch.addmm is supported for complex
-if dtype.is_complex:
+if dtype.is_complex and device != 'cpu':

@@ -0,0 +1,250 @@
#include <ATen/native/CPUBlas.h>

Cross-reference with aten/src/TH/generic/THBlas.cpp

@ezyang
Contributor

ezyang commented Jul 7, 2020

It looks like we got a little extra feature bonus, which is that complex gemm now works on CPU
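
Assuming a build that includes this PR, that bonus is visible directly from Python:

```python
import torch

a = torch.randn(3, 4, dtype=torch.complex64)
b = torch.randn(4, 5, dtype=torch.complex64)
c = torch.randn(3, 5, dtype=torch.complex64)

# Complex matmul on CPU now dispatches through the new ATen gemm
# (cgemm/zgemm when a BLAS is available, the generic fallback otherwise).
out = torch.addmm(c, a, b)
print(out.dtype, out.shape)  # torch.complex64 torch.Size([3, 5])
```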

namespace cpublas {
namespace {

void normalize_last_dims(

New helper function refactored out of THBlas_(gemm) and related functions
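
The helper itself is C++; as a rough Python paraphrase of the idea (my reading of the refactor, not the literal code): when m, n or k is 1, the corresponding leading-dimension argument never participates in addressing, so it can be rewritten to satisfy BLAS's lda >= max(1, rows) argument checks:

```python
def normalize_last_dims(transa, transb, m, n, k, lda, ldb, ldc):
    # Hypothetical paraphrase: if a matrix has a size-1 dimension, only one
    # row/column is ever addressed, so its leading dimension is
    # unconstrained by the data layout and can be fixed up for BLAS.
    if n == 1:
        ldc = m
    if transa:
        if m == 1:
            lda = k
    elif k == 1:
        lda = m
    if transb:
        if k == 1:
            ldb = n
    elif n == 1:
        ldb = k
    return lda, ldb, ldc
```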

}
}

bool use_blas_gemm(

switch (trans) {
case Transpose: return 't';
case NoTranspose: return 'n';
// case ConjTranspose: return 'c';

cc @anjali411 we're probably going to want to expose this at some point


yeah good point! noted


DEFINE_DISPATCH(gemm_stub);

void gemm(

@@ -0,0 +1,197 @@
#include <ATen/Dispatch.h>
#include <ATen/native/CPUBlas.h>

Are you sure you actually want to vectorize these fallbacks? They're so naive I'm not sure they're worth the binary size to compile them with AVX/etc.


Do you know if we have any test coverage for this code?

@peterbell10 (Collaborator, Author) commented Jul 7, 2020

BFloat16, Half and ints < 64 all unconditionally exercise this code path. Only BFloat16 and Half seem to actually have test coverage though.

> Are you sure you actually want to vectorize these fallbacks? They're so naive I'm not sure they're worth the binary size to compile them with AVX/etc.

There's certainly a lot of room for improvement but at the very least, I think a.T @ b and a @ b.T should vectorize very well and be reasonably efficient.
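
Those code paths can be exercised from Python with any dtype that has no CPU BLAS backend, for example:

```python
import torch

# Integer and reduced-precision dtypes bypass BLAS on CPU and run the
# handwritten (vectorized) fallback kernels being discussed here.
for dtype in (torch.int32, torch.bfloat16):
    a = (torch.rand(64, 64) * 4).to(dtype)
    b = (torch.rand(64, 64) * 4).to(dtype)
    print(dtype, torch.mm(a, b).shape)

# The layouts called out above: transposed operands still hit the fallback,
# but with access patterns that vectorize well.
x = torch.rand(64, 64).to(torch.bfloat16)
print((x.T @ x).shape, (x @ x.T).shape)
```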



template <typename scalar_t>
void gemm_notrans_(

namespace {

template <typename scalar_t>
void scale_(int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t lda) {

The renaming of beta to alpha here is confusing (no action necessary)

// c *= beta
scale_(m, n, beta, c, ldc);

// c += alpha * (a @ b.T)

These comments are great, thanks!
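
The contract being annotated is the standard BLAS gemm one, C = beta*C + alpha*op(A)*op(B); a quick reference check against the Python-level addmm:

```python
import torch

alpha, beta = 0.5, 2.0
a, b = torch.randn(3, 4), torch.randn(5, 4)   # b will be used transposed
c = torch.randn(3, 5)

# Spelled out exactly as in the comments:
#   c *= beta
#   c += alpha * (a @ b.T)
expected = beta * c + alpha * (a @ b.T)

out = torch.addmm(c, a, b.T, beta=beta, alpha=alpha)
assert torch.allclose(out, expected)
```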

Comment thread aten/src/ATen/native/LinearAlgebra.cpp Outdated
Tensor &result, const Tensor &self, Tensor m1, Tensor m2, Scalar beta, Scalar alpha) {
TORCH_CHECK(self.dim() == 2, "input must be a matrix");
TORCH_CHECK(m1.dim() == 2, "m1 must be a matrix");
TORCH_CHECK(m2.dim() == 2, "m2 must be a matrix");

This is a slight pessimization over the old error-checking code, which would tell you what the dimensions of input/m1/m2 were. (Btw, we should use names consistent with the Python documentation, which are input, mat1 and mat2.)

Comment thread aten/src/ATen/native/LinearAlgebra.cpp Outdated
TORCH_CHECK(
self.size(0) == m1.size(0) && self.size(1) == m2.size(1),
"input shape is incompatible with matrix multiplication (",
m1.size(0), "x", m1.size(1), " and ", m2.size(0), "x", m2.size(1), ")");

This is bad. You need to report self size.
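
For context, the user-visible failure this check guards (shapes here are arbitrary; a good message should report the self/input shape alongside mat1 and mat2):

```python
import torch

inp = torch.randn(2, 7)                 # wrong on purpose: should be 3x5
m1, m2 = torch.randn(3, 4), torch.randn(4, 5)

try:
    torch.addmm(inp, m1, m2)
except RuntimeError as e:
    # Without the self size in the message, the user can't tell which
    # of the three operands is the mismatched one.
    print(e)
```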

@facebook-github-bot
Contributor

@ezyang merged this pull request in 6725c03.


@ezyang
Contributor

ezyang commented Jul 9, 2020

Diff appears to have regressed performance in prod, unlanding. I'm asking the reporter for more information.

@ezyang
Contributor

ezyang commented Jul 9, 2020

@peterbell10 In the meantime, if you could run some before-and-after benchmarks on these functions, that may also help pin down the regression.

@peterbell10
Collaborator Author

I've run the operator benchmarks for addmm and there is no obvious regression. Perhaps a microsecond here and there; it's hard to discern from the noise floor. However, that benchmark doesn't really cover all the edge cases, so I could be missing it.

Some information that would be useful if possible:

  • How big of a slowdown are we talking? Maybe BLAS isn't getting called (a disaster), or is it just a small added overhead that's unacceptable for such a common operator?
  • Are they using addmm, addbmm or matmul?
  • Which dtype? Is it one of the BLAS-accelerated ones or not?
  • What shape are their tensors, and are they contiguous? (A probe sketch covering these axes follows.)
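
A probe along those axes might look like the following (shapes and dtypes here are illustrative guesses, not the prod workload):

```python
import timeit
import torch

def probe(c, a, b, label, n=1000):
    t = timeit.timeit(lambda: torch.addmm(c, a, b), number=n) / n
    print(f"{label:45s} {t * 1e6:8.2f} us")

for dtype in (torch.float32, torch.float64, torch.bfloat16):
    for fortran in (False, True):
        a = torch.randn(256, 512).to(dtype)
        b = torch.randn(512, 512).to(dtype)
        if fortran:
            # Column-major (Fortran-contiguous) layouts take a different
            # transpose path into gemm.
            a, b = a.T.contiguous().T, b.T.contiguous().T
        c = torch.randn(256, 512).to(dtype)
        probe(c, a, b, f"{dtype}, fortran={fortran}")
```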

@ngimel
Collaborator

ngimel commented Jul 9, 2020

@peterbell10 we've recently merged a new benchmarking utility that lets you generate random inputs for your functions, for better performance coverage. PR #38338 comes with examples of benchmarking "before" and "after" builds by creating separate environments, and other comprehensive examples; please take a look.

@ezyang
Contributor

ezyang commented Jul 9, 2020

> Maybe BLAS isn't getting called (disaster), or is it just a small overhead introduced that's unacceptable for such a common operator.

Their profile is a little hard to read, but it looks like a 3x end-to-end slowdown. Should be really obvious.

@ezyang
Contributor

ezyang commented Jul 9, 2020

Here are the top five mm/matmul shape sizes on the relevant benchmark:

 mm | [[256,2048],[2048,512]]
 mm | [[256,512],[512,2048]]
 mm | [[256,512],[512,512]]
 mm | [[1,540],[540,1024]]
 matmul | [[256,1,512],[512,2048]]
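
Those mm shapes can be swept directly with the benchmarking utility mentioned above (shown with the current torch.utils.benchmark module path; at the time of #38338 it lived under a private name):

```python
import torch
import torch.utils.benchmark as benchmark

shapes = [
    ((256, 2048), (2048, 512)),
    ((256, 512), (512, 2048)),
    ((256, 512), (512, 512)),
    ((1, 540), (540, 1024)),
]

results = []
for sa, sb in shapes:
    a, b = torch.randn(sa), torch.randn(sb)
    timer = benchmark.Timer(
        stmt="torch.mm(a, b)",
        globals={"torch": torch, "a": a, "b": b},
        label="mm",
        sub_label=f"{sa} x {sb}",
    )
    results.append(timer.blocked_autorange())

benchmark.Compare(results).print()
```

Running the same script in a "before" environment and an "after" environment, as #38338's examples do, gives the comparison directly.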

@peterbell10
Collaborator Author

peterbell10 commented Jul 9, 2020

I've now tried fuzzing the tensor shapes, as well as using the exact shapes given. Neither shows any performance regression on my system. I even tried all combinations of C- and Fortran-contiguous inputs, none of which made any significant difference; certainly not a 3x performance drop. I also tried different dtypes, and the only thing I see is a slight performance improvement in the non-BLAS cases.

@ngimel
Collaborator

ngimel commented Jul 9, 2020

Yeah, the performance regressions we are seeing on the op bench internally (e.g. for 128x128x128) are disastrous; it looks like something goes wrong and we don't use BLAS.

@ezyang
Contributor

ezyang commented Jul 10, 2020

@ngimel identified the problem as fbcode-specific, and has relanded the diff.

facebook-github-bot pushed a commit that referenced this pull request Jul 10, 2020
Summary:
Resubmit #40927
Closes #24679, closes #24678

`addbmm` depends on `addmm` so needed to be ported at the same time. I also removed `THTensor_(baddbmm)` which I noticed had already been ported so was just dead code.

After having already written this code, I had to fix merge conflicts with #40354 which revealed there was already an established place for cpu blas routines in ATen. However, the version there doesn't make use of ATen's AVX dispatching so thought I'd wait for comment before migrating this into that style.

Pull Request resolved: #40927

Reviewed By: ezyang

Differential Revision: D22468490

Pulled By: ngimel

fbshipit-source-id: f8a22be3216f67629420939455e31a88af20201d
@ngimel
Collaborator

ngimel commented Jul 13, 2020

@peterbell10, our internal runs of the op bench flagged regressions on 1x1x1 matmuls (linear_N1_IN1_OUT1_cpu_Eager and matmul_M1_N1_K1_trans_aTrue_trans_bFalse_cpu_Eager), which means overhead increased after the migration to ATen. Can you please check whether you can reproduce these regressions?
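
A minimal sketch for reproducing the 1x1x1 case; at this size the kernel is trivial, so per-call time is almost pure dispatch and size-introspection overhead (absolute numbers are machine-dependent):

```python
import timeit
import torch

a, b, c = torch.randn(1, 1), torch.randn(1, 1), torch.randn(1, 1)
n = 100_000

for label, fn in [
    ("addmm  1x1x1", lambda: torch.addmm(c, a, b)),
    ("matmul 1x1x1", lambda: torch.matmul(a, b)),
]:
    t = timeit.timeit(fn, number=n) / n
    print(f"{label}: {t * 1e6:.2f} us per call")
```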

facebook-github-bot pushed a commit that referenced this pull request Jul 14, 2020
Summary:
Fixes the overhead reported by ngimel in #40927 (comment)

As it turns out, `Tensor.size(n)` has more overhead than `Tensor.sizes()[n]`. Since addmm does a lot of introspection of the input matrix sizes and strides, this added up to a noticeable (~1 us) constant time overhead.

With this change, a 1x1 matmul takes 2.85 us on my machine compared to 2.90 us on pytorch 1.5.

Pull Request resolved: #41374

Reviewed By: ailzhang

Differential Revision: D22519924

Pulled By: ngimel

fbshipit-source-id: b29504bee7de79ce42e5e50f91523dde42b073b7
facebook-github-bot pushed a commit that referenced this pull request Aug 7, 2020
Summary:
I noticed that `TensorIteratorDynamicCasting.h` defines a helper meta-function `CPPTypeToScalarType` which does exactly the same thing as the `c10::CppTypeToScalarType` meta-function I added in gh-40927. No need for two identical definitions.

Pull Request resolved: #42640

Reviewed By: malfet

Differential Revision: D22969708

Pulled By: ezyang

fbshipit-source-id: 8303c7f4a75ae248f393a4811ae9d2bcacab44ff

Labels

Merged · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues:

  • Migrate addmm and addmm_ from the TH to Aten (CPU)
  • Migrate addbmm and addbmm_ from the TH to Aten (CPU)

7 participants