Enable BFloat16 support for gemms on arch other than Ampere #50442
Conversation
💊 CI failures summary and remediations

As of commit c33a608 (more details on the Dr. CI page):

🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
This should be ready. Test failures are unrelated.
```cpp
} else {
  TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU");
}
TORCH_CUDABLAS_CHECK(cublasGemmEx(
```
Setting and resetting the cuBLAS math mode is not required if you specify `CUBLAS_GEMM_DFALT_TENSOR_OP`?
According to https://docs.nvidia.com/cuda/cublas/index.html#cublasmath_t:

> `CUBLAS_DEFAULT_MATH`: This is the default and highest-performance mode that uses compute and intermediate storage precisions with at least the same number of mantissa and exponent bits as requested. Tensor Cores will be used whenever possible.
>
> `CUBLAS_TENSOR_OP_MATH`: This mode is deprecated and will be removed in a future release. Allows the library to use Tensor Core operations whenever possible. For single precision GEMM routines cuBLAS will use the CUBLAS_COMPUTE_32F_FAST_16F compute type.
```python
b1 = torch.randn(num_batches, M, N, device=device).to(dtype)
b2 = torch.randn(num_batches, N, O, device=device).to(dtype)
self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.bmm(b1, b2))
if not is_cuda_bfloat:
```
`is_supported=False, is_cuda_bfloat=False` is an impossible situation?
Some ops are supported on SM52 and some are not. I don't think it's worth the maintenance effort to keep an explicit list of which op is supported on which SM, so what I implemented here is:

- SM >= 53 ---> supported
- SM < 53 ---> undefined behavior
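The heuristic above can be sketched as a small helper. This is an illustrative stand-in (the function name is hypothetical; the actual check lives in PyTorch's C++ dispatch), taking a `(major, minor)` compute-capability pair such as the one `torch.cuda.get_device_capability()` returns:

```python
def cuda_bf16_gemm_supported(major, minor):
    """Hypothetical helper mirroring the heuristic discussed above.

    SM >= 5.3 is treated as supporting BFloat16 gemm; anything older
    is left as undefined behavior rather than tracked op-by-op.
    """
    return (major, minor) >= (5, 3)

print(cuda_bf16_gemm_supported(8, 0))  # Ampere -> True
print(cuda_bf16_gemm_supported(5, 2))  # SM52   -> False
```

Tuple comparison gives the right lexicographic ordering here, so no per-op table is needed.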
Codecov Report
```
@@            Coverage Diff             @@
##           master   #50442      +/-   ##
==========================================
- Coverage   81.00%   81.00%   -0.01%
==========================================
  Files        1916     1916
  Lines      209481   209484       +3
==========================================
+ Hits       169690   169692       +2
- Misses      39791    39792       +1
```
mruberry left a comment

Cool! Thanks @zasdfgbnm!
Would you just rebase this? Sorry, PyTorch is especially popular these days.
@mruberry rebased
facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Internal builds are failing with: We typically use LooseVersion for version comparisons. See for an example.
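For context, the `LooseVersion` idiom mentioned above (`distutils.version.LooseVersion`) compares dotted version strings component-wise rather than as plain strings. A minimal numeric stand-in (the helper name is hypothetical, not the internal code) shows why plain string comparison is wrong:

```python
def loose_parts(version):
    """Minimal sketch of LooseVersion-style comparison: split a dotted
    version string into integer components so "11.10" sorts after "11.2"."""
    return tuple(int(p) for p in version.split("."))

print("11.10" > "11.2")                            # False: char-by-char string compare
print(loose_parts("11.10") > loose_parts("11.2"))  # True: numeric component compare
```

This is why version strings such as CUDA releases should never be compared with plain `>` on `str`.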
@mruberry fixed
facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Fixes #{issue number}

Pull Request resolved: pytorch#50442
Reviewed By: bdhirsh
Differential Revision: D26044981
Pulled By: mruberry
fbshipit-source-id: 65c42f2c1de8d24e4852a1b5bd8f4b1735b2230e