Conversation
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/172945.
Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit 8dd24fd with merge base 24e0e50. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)

Unknown label

@pytorchbot label "ciflow/linux-aarch64"
@pytorchbot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.

Rebase failed due to a failed command. Raised by https://github.com/pytorch/pytorch/actions/runs/21246174215
Force-pushed c7cd14c to 9bd68e7
Force-pushed 9bd68e7 to 95f34c7
@pytorchbot label "ciflow/linux-aarch64"
Great work thank you!
I added a few minor comments.
Could you please also make it explicit in the PR description that you are updating the OpenBLAS version, that the new version contains BGEMM kernels, and how these differ from SBGEMM.
Could you also attach the benchmark script you ran and the speedups achieved with this PR?
cmake/Modules/FindBLAS.cmake (outdated)
ENDIF(BLAS_HAS_SBGEMM)
set(CMAKE_REQUIRED_LIBRARIES)
NIT: why does this need to be changed?
bool tf32_usable = std::is_same_v<scalar_t, float> && use_mkldnn_tf32_matmul();
if ( !(bf16_usable || fp16_usable || bf32_usable || tf32_usable) ||
if (bf16_usable) {
  // New BF16-only heuristic
NIT: let's have a better comment, maybe something along the lines of: "for these cases oneDNN is better than OpenBLAS".
@@ -20,7 +19,5 @@ CFLAGS=-O3
 BUILD_BFLOAT16=1
 "

-make -j8 ${OPENBLAS_BUILD_FLAGS} -C $OPENBLAS_CHECKOUT_DIR
-sudo make install -C $OPENBLAS_CHECKOUT_DIR
-
-rm -rf $OPENBLAS_CHECKOUT_DIR
+make -j8 ${OPENBLAS_BUILD_FLAGS} -C ${OPENBLAS_CHECKOUT_DIR}
+make -j8 ${OPENBLAS_BUILD_FLAGS} install -C ${OPENBLAS_CHECKOUT_DIR}
Apart from the OpenBLAS version update, why are we modifying this?
aten/src/ATen/native/CPUBlas.cpp (outdated)
}
#endif
#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SBGEMM)
#if AT_BUILD_WITH_BLAS() && (defined(BLAS_HAS_SBGEMM) || defined(BLAS_HAS_BGEMM))
Is the || defined(BLAS_HAS_BGEMM) redundant here?
Will you ever have BLAS_HAS_BGEMM without BLAS_HAS_SBGEMM?
Thanks for the suggestion. I kept both the BLAS_HAS_SBGEMM and BLAS_HAS_BGEMM checks just to stay safe across OpenBLAS versions, since some older builds may only expose SBGEMM while newer ones define BGEMM.
I'm happy to simplify it if we think SBGEMM will always be there going forward; do you think it's safe to rely on SBGEMM only in future versions?
aten/src/ATen/native/CPUBlas.cpp (outdated)
c[j * ldc_ + i] = c10::convert<at::BFloat16>(float_v[j * m_ + i]);
}
}
#endif //
NIT: // defined(BLAS_HAS_BGEMM)
Force-pushed 95f34c7 to 3e723ba
@pytorchbot label "ciflow/linux-aarch64"

The failing CI seems unrelated, but the rebase didn't fix it.
Force-pushed 3e723ba to ea21fe0
Force-pushed 6e20e67 to 50da97f
@pytorchbot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.

Successfully rebased
Force-pushed 50da97f to 1d11997
Force-pushed 1d11997 to 5e37303
@pytorchbot label "ciflow/linux-aarch64"
OpenBLAS v0.3.31 adds support for BGEMM on SVE128 and SVE256 machines, plus general optimizations for SBGEMM/BGEMM (OpenMathLib/OpenBLAS#5419, OpenMathLib/OpenBLAS#5399, among other things). OpenBLAS v0.3.32 accelerates SBGEMM/BGEMM on SVE128 machines by ~20% (OpenMathLib/OpenBLAS#5667). This accelerates SDPA, and will be capitalized on further by #172945 to accelerate linear, mm, bmm, etc.

PS: BGEMM means bf16 x bf16 -> bf16, and SBGEMM means bf16 x bf16 -> fp32.

ghstack-source-id: cf38a01
Pull-Request: #177012
@pytorchbot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.

Successfully rebased
Force-pushed 5e37303 to 8dd24fd
Hi @Anallear - nice speedups for llama!
By the way, it seems the new version of OpenBLAS, v0.3.32, will be released by the end of the week: OpenMathLib/OpenBLAS#5682. I raised a separate PR, #177012, to update OpenBLAS to that version.
if (bf16_usable) {
  // BF16 heuristic: use BGEMM for GEMV-like or small shapes,
  // otherwise prefer oneDNN for larger workloads.
  if ((m == 1 || n == 1) || (m * n * k <= 786432)) {
    return false;
  }
Changing the rule here might be troublesome for different platforms.
If you care about performance on the Arm platform, I suggest changing this condition (whether to use oneDNN or not) only for Arm.
This PR introduces:
• BGEMM backend integration for BF16 GEMM
• A data-driven decision-tree heuristic for selecting BGEMM vs oneDNN
• Significant CPU inference improvements
• An OpenBLAS update
Benchmark Results
• Self CPU total
  Short prompt: 486.337s → 97.390s (4.99x faster)
  Long prompt (repeated 512x): 951.762s → 486.337s (1.96x faster)
• aten::mm self CPU
  Short prompt: 329.574s → 69.056s (4.77x faster)
  Long prompt: 771.995s → 329.574s (2.34x faster)
The majority of gains come from improved BF16 GEMM selection.
Methodology
Benchmark script
Benchmark.py
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @aditew01