
Enable BFloat support for gemms on arch other than ampere #50442

Closed

zasdfgbnm wants to merge 20 commits into master from ci-all/matmul-bf16-non-ampere

Conversation

@zasdfgbnm
Collaborator

Fixes #{issue number}
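
As a quick illustration of what this change enables (a hedged sketch; the tensor shapes and the assumption of a pre-Ampere GPU with compute capability >= 5.3, e.g. Volta or Turing, are mine, not from the PR):

    import torch

    # bfloat16 inputs on a CUDA device; before this PR, pre-Ampere GPUs raised
    # "BFloat16 gemm in CUDA requires Ampere or later GPU" for this matmul.
    a = torch.randn(64, 128, device="cuda").to(torch.bfloat16)
    b = torch.randn(128, 32, device="cuda").to(torch.bfloat16)
    c = a @ b  # now dispatches to the cuBLAS gemm path instead of erroring out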

@zasdfgbnm zasdfgbnm changed the title Enable BFloat support for gemms on arch other than ampere to [WIP] Enable BFloat support for gemms on arch other than ampere Jan 12, 2021
@facebook-github-bot
Contributor

facebook-github-bot commented Jan 12, 2021

💊 CI failures summary and remediations

As of commit c33a608 (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-CircleCI failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_windows_vs2019_py36_cuda11.1_test2 (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

AssertionError: "Simulate error" does not match "grad can be implicitly created only for scalar outputs"

Traceback (most recent call last):
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_device_type.py", line 290, in instantiated_test
    result = test_fn(self, *args)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_device_type.py", line 687, in only_fn
    return fn(slf, device, *args, **kwargs)
  File "test_autograd.py", line 6652, in test_reentrant_parent_error_on_cpu
    self._test_reentrant_parent_error_on_cpu(device)
  File "test_autograd.py", line 6638, in _test_reentrant_parent_error_on_cpu
    torch.autograd.backward([t5.sum(), t7.sum()])
AssertionError: "Simulate error" does not match "grad can be implicitly created only for scalar outputs"

----------------------------------------------------------------------
Ran 2794 tests in 2826.455s

FAILED (failures=1, skipped=23, expected failures=1)

Generating XML reports...
Generated XML report: test-reports\python-unittest\TEST-TestAutograd-20210125173827.xml
Generated XML report: test-reports\python-unittest\TEST-TestAutogradComplex-20210125173827.xml
Generated XML report: test-reports\python-unittest\TEST-TestAutogradDeviceTypeCPU-20210125173827.xml

This comment was automatically generated by Dr. CI. Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@zasdfgbnm zasdfgbnm changed the title [WIP] Enable BFloat support for gemms on arch other than ampere to Enable BFloat support for gemms on arch other than ampere Jan 14, 2021
@zasdfgbnm zasdfgbnm requested review from mruberry and ngimel January 14, 2021 18:19
@zasdfgbnm
Collaborator Author

This should be ready. Test failures are unrelated.

@mrshenli mrshenli added the module: bfloat16 and triaged labels Jan 15, 2021
} else {
TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU");
}
TORCH_CUDABLAS_CHECK(cublasGemmEx(
Collaborator

setting and resetting cublas MathMode is not required if you specify CUBLAS_GEMM_DFALT_TENSOR_OP?

Collaborator Author

According to https://docs.nvidia.com/cuda/cublas/index.html#cublasmath_t

CUBLAS_DEFAULT_MATH: This is the default and highest-performance mode that uses compute and intermediate storage precisions with at least the same number of mantissa and exponent bits as requested. Tensor Cores will be used whenever possible.
CUBLAS_TENSOR_OP_MATH: This mode is deprecated and will be removed in a future release. Allows the library to use Tensor Core operations whenever possible. For single precision GEMM routines cuBLAS will use the CUBLAS_COMPUTE_32F_FAST_16F compute type.

b1 = torch.randn(num_batches, M, N, device=device).to(dtype)
b2 = torch.randn(num_batches, N, O, device=device).to(dtype)
self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.bmm(b1, b2))
if not is_cuda_bfloat:
Collaborator

is_supported=False, is_cuda_bfloat=False is an impossible situation?

Collaborator Author

@zasdfgbnm zasdfgbnm Jan 19, 2021

Some ops are supported on SM52, and some are not. I don't think it's worth the maintenance effort to spell out exactly which op is supported on which SM. So what I implemented here is:

SM >= 53 ---> supported
SM < 53 ---> undefined behavior
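
A minimal sketch of how such a compute-capability gate can be expressed on the Python side (the helper name below is an assumption for illustration, not necessarily the exact code added to torch/testing/_internal/common_cuda.py in this PR):

    import torch

    # Assumed helper: bf16 gemms are expected to work on SM 5.3 and newer;
    # below that, this PR deliberately leaves the behavior undefined.
    def cuda_bf16_gemm_supported(device=None):
        if not torch.cuda.is_available():
            return False
        # get_device_capability returns a (major, minor) tuple, e.g. (7, 0) for Volta.
        return torch.cuda.get_device_capability(device) >= (5, 3)

Tests can then key off such a flag to decide whether to expect a RuntimeError (matching the "type|Type|not implemented|Ampere" pattern quoted above) or a successful bf16 gemm.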

@codecov

codecov Bot commented Jan 20, 2021

Codecov Report

Merging #50442 (79a68e3) into master (8e9ed27) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #50442      +/-   ##
==========================================
- Coverage   81.00%   81.00%   -0.01%     
==========================================
  Files        1916     1916              
  Lines      209481   209484       +3     
==========================================
+ Hits       169690   169692       +2     
- Misses      39791    39792       +1     

@zasdfgbnm
Collaborator Author

@ngimel @mruberry I think I have resolved all review comments, and all tests pass.

Collaborator

@mruberry mruberry left a comment

Cool! Thanks @zasdfgbnm!

Would you just rebase this? Sorry PyTorch is especially popular these days.

@zasdfgbnm
Collaborator Author

@mruberry rebased

Contributor

@facebook-github-bot facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mruberry
Collaborator

Internal builds are failing with:

    from torch.testing._internal.common_cuda import _get_torch_cuda_version
  File "/data/sandcastle/boxes/eden-trunk-hg-fbcode-fbsource/fbcode/buck-out/dev/gen/caffe2/caffe2/fb/high_perf_models/pytorch/torchscript/test/test_ir_bench#binary,link-tree/torch/testing/_internal/common_cuda.py", line 18, in <module>
    CUDA11OrLater = torch.version.cuda and float(torch.version.cuda) >= 11
ValueError: could not convert string to float: '9.2.0'

We typically use LooseVersion for version comparisons. See

active_if=LooseVersion(scipy.__version__) < "1.4.0"),

for an example.
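
A hedged sketch of the kind of fix being suggested (the exact line used in the PR may differ):

    from distutils.version import LooseVersion

    import torch

    # Compare CUDA versions as version strings rather than floats, so values
    # like '9.2.0' (which float() cannot parse) and '11.1' both work.
    CUDA11OrLater = torch.version.cuda is not None and LooseVersion(torch.version.cuda) >= LooseVersion("11")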

@zasdfgbnm
Collaborator Author

@mruberry fixed

Contributor

@facebook-github-bot facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@mruberry merged this pull request in b822aba.

@zasdfgbnm zasdfgbnm deleted the ci-all/matmul-bf16-non-ampere branch January 26, 2021 21:50
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Enable BFloat support for gemms on arch other than ampere (#50442)

Summary:
Fixes #{issue number}

Pull Request resolved: pytorch#50442

Reviewed By: bdhirsh

Differential Revision: D26044981

Pulled By: mruberry

fbshipit-source-id: 65c42f2c1de8d24e4852a1b5bd8f4b1735b2230e

Labels

cla signed, Merged, module: bfloat16, open source, triaged


6 participants