
Enables bfloat16 x [float16, complex64, complex128] type promotion#43324

Closed
mruberry wants to merge 3 commits into master from bfloat16_float16

Conversation

mruberry (Collaborator) commented Aug 20, 2020

Implements bfloat16 type promotion consistent with JAX (see https://jax.readthedocs.io/en/latest/type_promotion.html), addressing issue #43049.

  • bfloat16 x float16 -> float32
  • bfloat16 x complex64 -> complex64
  • bfloat16 x complex128 -> complex128

Existing tests, after updates, are sufficient to validate the new behavior.
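The three promotion rules above can be checked directly from Python; a minimal sketch, assuming a PyTorch build that includes this change (the tensor values are illustrative):

```python
import torch

# Sketch of the promotion rules this PR adds; values are illustrative,
# only the resulting dtypes matter.
bf = torch.tensor([1.0], dtype=torch.bfloat16)

print((bf + torch.tensor([1.0], dtype=torch.float16)).dtype)     # torch.float32
print((bf + torch.tensor([1.0], dtype=torch.complex64)).dtype)   # torch.complex64
print((bf + torch.tensor([1.0], dtype=torch.complex128)).dtype)  # torch.complex128

# The same table is visible without materializing tensors:
print(torch.promote_types(torch.bfloat16, torch.float16))        # torch.float32
```

`torch.promote_types` queries the promotion table directly, so it is a convenient way to spot-check the rules without allocating tensors.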

cc @xuhdev

@mruberry mruberry requested a review from gchanan August 20, 2020 10:55
self.assertEqual((bf + scalar).dtype, torch.bfloat16)
self.assertEqual((scalar + bf).dtype, torch.bfloat16)
self.assertEqual(scalar + bf, bf + scalar)
with self.assertRaises(RuntimeError):
Contributor

can you add a complex scalar to the for loop above? That seems to capture the intent of this test.

mruberry (Collaborator, Author)

Good idea. I also simplified the loop.

self.assertEqual(bf + t, t + bf)
if dtype in (torch.float16, torch.float32, torch.float64, torch.cfloat, torch.cdouble):
# Handles bfloat16 x float16 -> float32 promotion
expected_dtype = dtype if dtype != torch.half else torch.float32
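The quoted check can be exercised on its own; a hedged standalone sketch of the same loop (the tensor values are illustrative, not taken from the PR's test file):

```python
import torch

# Standalone sketch mirroring the quoted test logic: a bfloat16 tensor
# added to tensors of several dtypes, checking the promoted result dtype.
bf = torch.tensor([1.0], dtype=torch.bfloat16)
for dtype in (torch.float16, torch.float32, torch.float64,
              torch.cfloat, torch.cdouble):
    t = torch.tensor([1.0], dtype=dtype)
    assert torch.equal(bf + t, t + bf)  # addition commutes after promotion
    # bfloat16 x float16 promotes to float32; every other dtype here
    # is wider than bfloat16 and wins outright.
    expected_dtype = dtype if dtype != torch.half else torch.float32
    assert (bf + t).dtype == expected_dtype
```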
Contributor

can you rationalize?

mruberry (Collaborator, Author)

Rationalized.

facebook-github-bot (Contributor) left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


dr-ci Bot commented Aug 21, 2020

💊 CI failures summary and remediations

As of commit f4922f4 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_macos_10_13_py3_test (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Aug 21 00:36:07 [E request_callback_no_python.cpp:618] Received error while processing request type 2: RuntimeError: Can not pickle torch.futures.Future
Aug 21 00:36:07 At: 
Aug 21 00:36:07   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(93): serialize 
Aug 21 00:36:07   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(145): serialize 
Aug 21 00:36:07 ok (1.340s) 
Aug 21 00:36:09   test_return_future_remote (__main__.ProcessGroupRpcTestWithSpawn) ... ok (1.283s) 
Aug 21 00:36:10   test_return_local_rrefs (__main__.ProcessGroupRpcTestWithSpawn) ... ok (1.365s) 
Aug 21 00:36:11   test_rpc_return_rref (__main__.ProcessGroupRpcTestWithSpawn) ... ok (1.289s) 
Aug 21 00:36:19   test_rpc_timeouts (__main__.ProcessGroupRpcTestWithSpawn) ... ok (7.873s) 

This comment was automatically generated by Dr. CI.


facebook-github-bot (Contributor)

@mruberry merged this pull request in 3aec118.

@mruberry mruberry deleted the bfloat16_float16 branch August 21, 2020 19:37
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
…ytorch#43324)

Summary:
Implements bfloat16 type promotion consistent with JAX (see https://jax.readthedocs.io/en/latest/type_promotion.html), addressing issue pytorch#43049.

- bfloat16 x float16 -> float32
- bfloat16 x complex64 -> complex64
- bfloat16 x complex128 -> complex128

Existing tests, after updates, are sufficient to validate the new behavior.

cc xuhdev

Pull Request resolved: pytorch#43324

Reviewed By: albanD

Differential Revision: D23259823

Pulled By: mruberry

fbshipit-source-id: ca9c2c7d0325faced1f884f3c37edf8fa8c8b089
