Migrate addmm, addbmm and THBlas_gemm to ATen #40927
peterbell10 wants to merge 5 commits into pytorch:master
Conversation
💊 CI failures summary and remediations
As of commit 0f8a3f0 (more details on the Dr. CI page):
❄️ 1 failure tentatively classified as flaky, but reruns have not yet been triggered to confirm:
Sorry, I got confused. I see that you are referring to F.linear. But for >3d shaped tensors, F.linear would still have to reshape input and output. Basically, we don't have for now a simple [B1xB2x...xC] @ [CxZ] + [Z] fused matmul+bias method. matmul works with arbitrary shapes, but doesn't fuse with bias; baddbmm fuses with bias, but only works for 3d tensors.
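For illustration, a minimal ATen C++ sketch of the reshape workaround described above (the function name and shapes are mine, assuming the input has at least one batch dimension; this is not code from the PR):

```cpp
#include <ATen/ATen.h>

// Emulate a fused [B1 x B2 x ... x C] @ [C x Z] + [Z] by flattening the
// batch dimensions, calling the 2d fused addmm, and restoring the shape.
at::Tensor linear_via_addmm(
    const at::Tensor& input,   // [B1, B2, ..., C]
    const at::Tensor& weight,  // [C, Z]
    const at::Tensor& bias) {  // [Z]
  auto out_sizes = input.sizes().vec();
  const int64_t c = out_sizes.back();
  // Collapse all leading (batch) dimensions into a single one.
  at::Tensor flat = input.reshape({-1, c});
  // addmm fuses the bias add with the matrix multiply, but is 2d-only,
  // hence the reshapes around it.
  at::Tensor out = at::addmm(bias, flat, weight);
  out_sizes.back() = weight.size(1);
  return out.reshape(out_sizes);
}
```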
I suggest we leave this enhancement for a follow-up PR.
# TODO: update this once torch.addmm is supported for complex
- if dtype.is_complex:
+ if dtype.is_complex and device != 'cpu':
@@ -0,0 +1,250 @@
#include <ATen/native/CPUBlas.h>
Cross-reference with aten/src/TH/generic/THBlas.cpp
It looks like we got a little extra feature bonus, which is that complex gemm now works on CPU.
namespace cpublas {
namespace {
void normalize_last_dims(
New helper function refactored out of THBlas_(gemm) and related functions
}
}
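For readers cross-referencing with THBlas: a hedged sketch of what a stride-normalizing helper like this plausibly does (reconstructed for illustration, not copied from the PR). When a dimension equals 1, the corresponding leading dimension is never used to step through memory, so it can be rewritten to the smallest value BLAS will accept:

```cpp
#include <cstdint>

enum TransposeType { NoTranspose, Transpose, ConjTranspose };

// Sketch: if m, n or k is 1, the associated leading dimension cannot affect
// the result, but BLAS still validates it; normalizing it here lets more
// callers take the fast BLAS path.
void normalize_last_dims_sketch(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    int64_t* lda, int64_t* ldb, int64_t* ldc) {
  if (n == 1) {
    *ldc = m;  // c is m x n column-major; with one column, ldc = m suffices
  }
  if (transa != NoTranspose) {
    if (m == 1) *lda = k;  // a stored k x m, read transposed
  } else if (k == 1) {
    *lda = m;
  }
  if (transb != NoTranspose) {
    if (k == 1) *ldb = n;
  } else if (n == 1) {
    *ldb = k;
  }
}
```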
bool use_blas_gemm(
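By way of context, this function gates whether an input can be handed to an external BLAS at all. A minimal sketch of the checks such a gate plausibly performs (an assumption, not the PR's exact logic): reference BLAS interfaces take 32-bit integers, and each leading dimension must be at least the row count of the stored matrix:

```cpp
#include <algorithm>
#include <cstdint>
#include <limits>

// Sketch, assuming the no-transpose layout: a is m x k, b is k x n,
// c is m x n, all column-major.
bool use_blas_gemm_sketch(
    int64_t m, int64_t n, int64_t k,
    int64_t lda, int64_t ldb, int64_t ldc) {
  auto fits = [](int64_t v) {
    // Reference BLAS takes 'int' arguments; larger values force the fallback.
    return v >= 0 && v <= std::numeric_limits<int>::max();
  };
  return fits(m) && fits(n) && fits(k) &&
      fits(lda) && fits(ldb) && fits(ldc) &&
      lda >= std::max<int64_t>(1, m) &&
      ldb >= std::max<int64_t>(1, k) &&
      ldc >= std::max<int64_t>(1, m);
}
```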
switch (trans) {
  case Transpose: return 't';
  case NoTranspose: return 'n';
  // case ConjTranspose: return 'c';
cc @anjali411 we're probably going to want to expose this at some point
DEFINE_DISPATCH(gemm_stub);

void gemm(
@@ -0,0 +1,197 @@
#include <ATen/Dispatch.h>
#include <ATen/native/CPUBlas.h>
Are you sure you actually want to vectorize these fallbacks? They're so naive I'm not sure they're worth the binary size to compile them with AVX/etc.
Do you know if we have any test coverage for this code?
BFloat16, Half and ints < 64 all unconditionally exercise this code path. Only BFloat16 and Half seem to actually have test coverage though.
Are you sure you actually want to vectorize these fallbacks? They're so naive I'm not sure they're worth the binary size to compile them with AVX/etc.
There's certainly a lot of room for improvement, but at the very least I think a.T @ b and a @ b.T should vectorize very well and be reasonably efficient.
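To see why: in the a.T @ b case both operands advance with unit stride in the reduction loop, so even a naive triple loop is a clean dot product that auto-vectorizes. A hedged illustration (column-major, simplified; a real kernel would also special-case beta == 0 to avoid reading uninitialized output):

```cpp
#include <cstdint>

// c = beta * c + alpha * (a.T @ b), column-major storage.
// a is stored k x m (read transposed), b is k x n, c is m x n.
template <typename scalar_t>
void gemm_transa_sketch(
    int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t* a, int64_t lda,
    const scalar_t* b, int64_t ldb,
    scalar_t beta,
    scalar_t* c, int64_t ldc) {
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      scalar_t dot = 0;
      // Both a and b are walked contiguously here, which compilers
      // vectorize well under AVX/etc.
      for (int64_t l = 0; l < k; ++l) {
        dot += a[l + i * lda] * b[l + j * ldb];
      }
      c[i + j * ldc] = beta * c[i + j * ldc] + alpha * dot;
    }
  }
}
```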
template <typename scalar_t>
void gemm_notrans_(
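For contrast with the transposed case, a hedged sketch of what a no-transpose body looks like (a reconstruction, not the PR's code): the inner loop streams down a column of c and a column of a, so the update is an axpy-style, unit-stride operation:

```cpp
#include <cstdint>

// c = beta * c + alpha * (a @ b), column-major storage.
// a is m x k, b is k x n, c is m x n.
template <typename scalar_t>
void gemm_notrans_sketch(
    int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t* a, int64_t lda,
    const scalar_t* b, int64_t ldb,
    scalar_t beta,
    scalar_t* c, int64_t ldc) {
  // c *= beta. NOTE: the real code routes this through a scale_ helper
  // that special-cases beta == 0; this sketch assumes c is initialized.
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      c[i + j * ldc] *= beta;
    }
  }
  // c += alpha * (a @ b)
  for (int64_t l = 0; l < k; ++l) {
    for (int64_t j = 0; j < n; ++j) {
      const scalar_t val = alpha * b[l + j * ldb];
      for (int64_t i = 0; i < m; ++i) {
        c[i + j * ldc] += a[i + l * lda] * val;  // unit stride in i
      }
    }
  }
}
```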
namespace {
template <typename scalar_t>
void scale_(int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t lda) {
The renaming of beta to alpha here is confusing (no action necessary)
// c *= beta
scale_(m, n, beta, c, ldc);
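As context for the naming comment above: the helper takes alpha because it just scales a matrix in place, while every call site passes beta. A plausible body (assumed; the zero branch matters because 0 * NaN == NaN, so uninitialized output must be overwritten, not multiplied):

```cpp
#include <cstdint>

template <typename scalar_t>
void scale_sketch(
    int64_t m, int64_t n, scalar_t alpha, scalar_t* a, int64_t lda) {
  if (alpha == scalar_t(1)) {
    return;  // nothing to do
  }
  const bool zero_fill = (alpha == scalar_t(0));
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      // Assign rather than multiply when alpha == 0 so garbage values
      // (including NaN) in the output buffer cannot leak through.
      a[i + j * lda] = zero_fill ? scalar_t(0) : alpha * a[i + j * lda];
    }
  }
}
```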
// c += alpha * (a @ b.T)
These comments are great, thanks!
    Tensor &result, const Tensor &self, Tensor m1, Tensor m2, Scalar beta, Scalar alpha) {
  TORCH_CHECK(self.dim() == 2, "input must be a matrix");
  TORCH_CHECK(m1.dim() == 2, "m1 must be a matrix");
  TORCH_CHECK(m2.dim() == 2, "m2 must be a matrix");
This is a slight pessimization over the old error checking code, which would tell you what the dimensions of input/m1/m2 were (btw, we should use names consistent with the Python documentation, which are input, mat1 and mat2).
  TORCH_CHECK(
      self.size(0) == m1.size(0) && self.size(1) == m2.size(1),
      "input shape is incompatible with matrix multiplication (",
      m1.size(0), "x", m1.size(1), " and ", m2.size(0), "x", m2.size(1), ")");
This is bad. You need to report self size.
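One way to address both review points, as an illustration only (not the fix that actually landed): report every operand's shape and use the documented argument names:

```cpp
#include <ATen/ATen.h>

// Hypothetical checks; 'mat1'/'mat2' follow the names in the Python docs.
void check_addmm_args(
    const at::Tensor& self, const at::Tensor& mat1, const at::Tensor& mat2) {
  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
  TORCH_CHECK(
      self.size(0) == mat1.size(0) && self.size(1) == mat2.size(1),
      "input shape ", self.size(0), "x", self.size(1),
      " is incompatible with mat1 @ mat2 (",
      mat1.size(0), "x", mat1.size(1), " and ",
      mat2.size(0), "x", mat2.size(1), ")");
}
```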
Diff appears to have regressed performance in prod, unlanding. I'm asking the reporter for more information.
@peterbell10 In the meantime, if you could run some before-and-after benchmarks on these functions, that may also be helpful in pinning down the regression.
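A minimal before-and-after harness in ATen C++, in the spirit of what's being asked (the shape and iteration counts are arbitrary choices, not numbers from this thread; build once against each commit and compare the output):

```cpp
#include <ATen/ATen.h>
#include <chrono>
#include <cstdio>

int main() {
  at::Tensor a = at::randn({128, 128});
  at::Tensor b = at::randn({128, 128});
  at::Tensor c = at::randn({128, 128});
  for (int i = 0; i < 100; ++i) {
    at::addmm(c, a, b);  // warm-up
  }
  constexpr int kIters = 1000;
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < kIters; ++i) {
    at::addmm(c, a, b);
  }
  auto end = std::chrono::steady_clock::now();
  std::chrono::duration<double, std::micro> elapsed = end - start;
  std::printf("addmm 128x128x128: %.3f us/iter\n", elapsed.count() / kIters);
  return 0;
}
```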
I've run the operator benchmarks for the affected functions. Some information that would be useful, if possible:
@peterbell10 we've recently merged a new benchmarking utility that would allow you to generate random inputs for your functions to get better coverage of performance. PR #38338 comes with examples of benchmarking "before" and "after" builds by creating separate environments, among other comprehensive examples; please take a look.
Their profile is a little hard to read, but it looks like a 3x end-to-end slowdown. Should be really obvious.
Here are the top five mm/matmul shape sizes on the relevant benchmark:
I've now tried fuzzing the tensor shapes, as well as using the exact shapes given. Neither shows any performance regression on my system. I even tried all the combinations of C- and Fortran-contiguous inputs, neither of which made any significant difference; certainly not a 3x performance drop. I also tried different dtypes, and the only thing I see is a slight performance improvement for the non-BLAS cases.
Yeah, the performance regressions we are seeing on op bench internally (e.g. for 128x128x128) are disastrous; it looks like something goes wrong and we don't use BLAS.
@ngimel identified this as an fbcode-specific problem, and has relanded the diff.
Summary: Resubmit #40927. Closes #24679, closes #24678.
`addbmm` depends on `addmm` so needed to be ported at the same time. I also removed `THTensor_(baddbmm)` which I noticed had already been ported, so was just dead code.
After having already written this code, I had to fix merge conflicts with #40354, which revealed there was already an established place for cpu blas routines in ATen. However, the version there doesn't make use of ATen's AVX dispatching, so I thought I'd wait for comment before migrating this into that style.
Pull Request resolved: #40927
Reviewed By: ezyang
Differential Revision: D22468490
Pulled By: ngimel
fbshipit-source-id: f8a22be3216f67629420939455e31a88af20201d
@peterbell10, our internal runs of op bench flagged regressions on 1x1x1 matmuls (linear_N1_IN1_OUT1_cpu_Eager and matmul_M1_N1_K1_trans_aTrue_trans_bFalse_cpu_Eager), which means that overhead increased after the migration to ATen. Can you please check if you can reproduce these regressions?
Summary: Fixes the overhead reported by ngimel in #40927 (comment).
As it turns out, `Tensor.size(n)` has more overhead than `Tensor.sizes()[n]`. Since addmm does a lot of introspection of the input matrix sizes and strides, this added up to a noticeable (~1 us) constant-time overhead. With this change, a 1x1 matmul takes 2.85 us on my machine, compared to 2.90 us on pytorch 1.5.
Pull Request resolved: #41374
Reviewed By: ailzhang
Differential Revision: D22519924
Pulled By: ngimel
fbshipit-source-id: b29504bee7de79ce42e5e50f91523dde42b073b7
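To illustrate the pattern behind the fix (a sketch with my own variable names): Tensor::size(n) does per-call work such as dimension wrapping on every query, while Tensor::sizes() hands back an IntArrayRef once, after which indexing is a plain array load:

```cpp
#include <ATen/ATen.h>

void shape_introspection(const at::Tensor& self) {
  // Slower in a hot path: each call re-enters size() with its own checks.
  int64_t m_slow = self.size(0);
  int64_t n_slow = self.size(1);

  // Cheaper: fetch the whole shape once, then index it directly.
  at::IntArrayRef sizes = self.sizes();
  int64_t m_fast = sizes[0];
  int64_t n_fast = sizes[1];

  (void)m_slow; (void)n_slow; (void)m_fast; (void)n_fast;  // unused warnings
}
```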
Summary: I noticed that `TensorIteratorDynamicCasting.h` defines a helper meta-function, `CPPTypeToScalarType`, which does exactly the same thing as the `c10::CppTypeToScalarType` meta-function I added in gh-40927. No need for two identical definitions.
Pull Request resolved: #42640
Reviewed By: malfet
Differential Revision: D22969708
Pulled By: ezyang
fbshipit-source-id: 8303c7f4a75ae248f393a4811ae9d2bcacab44ff
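For readers unfamiliar with the meta-function being deduplicated here: it maps a C++ type to the matching runtime ScalarType tag at compile time. A small usage sketch:

```cpp
#include <c10/core/ScalarType.h>
#include <cstdint>

// Compile-time mapping from C++ types to c10::ScalarType values.
static_assert(
    c10::CppTypeToScalarType<float>::value == c10::ScalarType::Float,
    "float maps to ScalarType::Float");
static_assert(
    c10::CppTypeToScalarType<int64_t>::value == c10::ScalarType::Long,
    "int64_t maps to ScalarType::Long");
```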
Summary: Closes pytorch#24679, closes pytorch#24678.
`addbmm` depends on `addmm` so needed to be ported at the same time. I also removed `THTensor_(baddbmm)` which I noticed had already been ported, so was just dead code.
After having already written this code, I had to fix merge conflicts with pytorch#40354, which revealed there was already an established place for cpu blas routines in ATen. However, the version there doesn't make use of ATen's AVX dispatching, so I thought I'd wait for comment before migrating this into that style.
Pull Request resolved: pytorch#40927
Differential Revision: D22418756
Pulled By: ezyang
fbshipit-source-id: 44e7bb5964263d73ae8cc6adc5f6d4e966476ae6
Closes #24679, closes #24678
`addbmm` depends on `addmm` so needed to be ported at the same time. I also removed `THTensor_(baddbmm)`, which I noticed had already been ported, so it was just dead code. After having already written this code, I had to fix merge conflicts with #40354, which revealed there was already an established place for CPU BLAS routines in ATen. However, the version there doesn't make use of ATen's AVX dispatching, so I thought I'd wait for comment before migrating this into that style.