Allow running with bfloat16 on XLA:GPU autocast #5598
Conversation
```python
# XLA:GPU with bfloat16 should run on `xla` backend
# unless torch.autocast is compiled with cuda.
backend = 'xla'
self._cuda_bfloat16 = True
```
Does PyTorch not support bfloat16 when running on the xla backend?
I'm still confused about what happens when dtype is set to bfloat16:
When torch.cuda.is_available() returns False, torch_xla uses the XLA:GPU backend. However, in this patch it seems to fall back to using the GPU backend.
Similarly, when torch.cuda.is_available() returns True, torch_xla uses the GPU backend, but the dtype is still forced to float16 instead of bfloat16.
Hi @baoleai, good catch -- I meant to do what the error message says. We don't need to force torch.float16 on the cuda backend.
Thanks, will update the PR.
Ok, let me go over this -- let me know if I am missing anything.
- "When torch.cuda.is_available() returns False, torch_xla uses the XLA:GPU backend. However, in this patch it seems to fall back to using the GPU backend."

When cuda is not available (False), we go into this code path:
```python
if xr.is_bf16_supported() and not torch.cuda.is_available():
  # XLA:GPU with bfloat16 should run on `xla` backend
  # unless torch.autocast is compiled with cuda.
  backend = 'xla'
  self._cuda_bfloat16 = True
```
We use the xla backend for autocast.
- "Similarly, when torch.cuda.is_available() returns True, torch_xla uses the GPU backend, but the dtype is still forced to float16 instead of bfloat16."

This is the part I need to address: instead of `else`, we want `elif {cuda is not available}`, to implement the intent laid out in the error message.
So something like this (commit 50c1d2f1c598c2ec4ada8f7177861cc05056ff36):
```python
if dtype is None:
  dtype = torch.float16
elif dtype == torch.bfloat16 and not torch.cuda.is_available():
  if xr.is_bf16_supported():
    # XLA:GPU with bfloat16 should run on `xla` backend
    # unless torch.autocast is compiled with cuda.
    backend = 'xla'
    self._cuda_bfloat16 = True
  else:
    # This has been the default behavior for unsupported bfloat16 dtype.
    dtype = torch.float16
    error_message = "In XLA:GPU autocast, but bfloat16 is not supported on this HW.\n"
    error_message += "Using the default cuda autocast dtype float16."
```
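As a quick sanity check of the fallback branch, something like this should tell you which path a given machine takes (a sketch only; it assumes `xr` is the `torch_xla.runtime` alias used in the snippet above):

```python
import torch
import torch_xla.runtime as xr  # assumed alias for the `xr` in the snippet

# On HW without bfloat16 support this prints False, and the autocast
# wrapper above falls back to float16 with the quoted warning message.
print(xr.is_bf16_supported())
print(torch.cuda.is_available())
```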
```python
self._xla_device = xm.xla_device_hw(device)
if self._xla_device == 'GPU':
  backend = 'cuda'
  self._cuda_bfloat16 = False
```
I am very confused about what this _cuda_bfloat16 actually means. Do you mind adding a comment above line 28 to explain all possible combinations and the expected behaviors?
+1, I realized that this variable should be called _xla_bfloat16. Let me add a brief comment, too.
Something like this:

```python
self._xla_bfloat16 = False  # True if xla backend with bfloat16 dtype.
if dtype is None:
  dtype = torch.float16
elif dtype == torch.bfloat16 and not torch.cuda.is_available():
  if xr.is_bf16_supported():
    # XLA:GPU with bfloat16 should run on `xla` backend
    # unless torch.autocast is compiled with cuda.
    backend = 'xla'
    self._xla_bfloat16 = True
```
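For the "all possible combinations" question, here is a sketch of the summary comment that could sit above this block (my reading of the thread; the wording that actually lands may differ):

```python
# Decision table for torch_xla.amp.autocast on XLA:GPU (sketch):
#
#   dtype      cuda available  bf16 supported  -> backend  dtype     _xla_bfloat16
#   None       any             any             -> 'cuda'   float16   False
#   bfloat16   False           True            -> 'xla'    bfloat16  True
#   bfloat16   False           False           -> 'cuda'   float16   False (+ warning)
#   bfloat16   True            any             -> 'cuda'   bfloat16  False
#
# Other explicit dtypes follow the default cuda autocast path unchanged.
```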
```
commit dc9224336af6cf6313e60d197dab47323e8e509d (HEAD -> spmd_amp_gpu, origin/spmd_amp_gpu)
Author: Yeounoh Chung <yeounoh@google.com>
Date:   Mon Sep 18 12:02:15 2023 -0700

    Rename _cuda_bfloat16 to _xla_bfloat16 since it is set when xla backend is used for bfloat16.
```
Another confusing question: why is there special treatment only for the case where torch.cuda.is_available() is False and dtype=bfloat16? Is it because XLA:GPU itself doesn't support bfloat16? When torch.cuda.is_available() is True, there is no special treatment for bfloat16, yet the following is still used:

```python
torch.set_autocast_xla_enabled(self.prev)
torch.set_autocast_xla_dtype(self.prev_dtype)
```
Hi @baoleai, XLA:GPU supports bfloat16, but we were using the cuda backend for autocast when torch.cuda.is_available() was False. Instead, we want to use the xla backend.
When torch.cuda.is_available() is True, we use the cuda backend for autocast, and since we are entering/exiting the autocast context via torch_xla.amp.autocast, we still need to set:

```python
torch.set_autocast_xla_enabled(self.prev)
torch.set_autocast_xla_dtype(self.prev_dtype)
```
```python
if self._xla_bfloat16:
  torch.set_autocast_enabled(self._enabled)
```
If _xla_bfloat16, shouldn't we use set_autocast_xla_enabled instead?
It will be set by torch.autocast; we need to set both, since we are wrapping and calling torch.autocast. So the xla autocast is enabled whenever one uses torch_xla.amp.autocast, and the torch autocast is also enabled because we call torch.autocast with either the cuda or the xla backend.
Can we leave a comment here? It is really confusing to see that, when we are using xla_bf16, we need to set the upstream autocast to enabled.
```python
if self._xla_bfloat16:
  torch.set_autocast_enabled(self.prev)
```
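To make that less confusing, here is a hedged sketch of the full exit path, stitched together from the snippets quoted in this thread (the actual method in the PR may differ in details such as branch ordering):

```python
def __exit__(self, exc_type, exc_val, exc_tb):
  # torch.autocast.__exit__ restores whichever backend ('cuda' or 'xla')
  # this context was entered with.
  retval = super().__exit__(exc_type, exc_val, exc_tb)
  if self._xla_bfloat16:
    # The xla-backend path also toggled the generic autocast flag on
    # entry, so restore it to the value saved before entering.
    torch.set_autocast_enabled(self.prev)
  else:
    # On the cuda-backend path we still entered via torch_xla.amp.autocast,
    # so the XLA autocast state has to be restored explicitly as well.
    torch.set_autocast_xla_enabled(self.prev)
    torch.set_autocast_xla_dtype(self.prev_dtype)
  return retval
```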
The GPU tests have already passed. Merging after adding the comments.
This is to follow up on #5570, which enabled xla autocast with the bfloat16 type on XLA:GPU but restricted/cast bfloat16 to float16 and float32. Supporting bfloat16 on eligible HW platforms should bring perf improvements for XLA:GPU. This is tested with the same set of ops that torch.autocast('cuda') uses for f16 autocasting.
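A minimal usage sketch of what this enables (a sketch only; it assumes an XLA:GPU device where torch.cuda.is_available() is False, and the torch_xla.amp.autocast signature implied by the snippets in this thread):

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

device = xm.xla_device()
a = torch.randn(8, 8, device=device)
b = torch.randn(8, 8, device=device)

# With this PR, bfloat16 goes through the `xla` autocast backend instead
# of being downcast to float16.
with autocast(device, dtype=torch.bfloat16):
  c = a @ b

print(c.dtype)  # expected: torch.bfloat16
```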