
Allow running with bfloat16 on XLA:GPU autocast #5598

Merged: yeounoh merged 19 commits into master from spmd_amp_gpu on Sep 18, 2023
Conversation

@yeounoh (Contributor) commented Sep 16, 2023

This follows up on #5570, which enabled XLA autocast with the bfloat16 type on XLA:GPU but restricted/cast bfloat16 down to float16 and float32. Supporting bfloat16 on eligible HW platforms should bring performance improvements for XLA:GPU. This is tested with the same set of ops that torch.autocast('cuda') uses for float16 autocasting.
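For context, a minimal usage sketch of what this change is meant to enable; the toy model and data below are hypothetical placeholders, not taken from the PR, and this assumes an XLA:GPU setup where torch.cuda.is_available() is False:

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.amp import autocast

    device = xm.xla_device()                      # XLA:GPU device
    model = torch.nn.Linear(8, 8).to(device)     # hypothetical toy model
    data = torch.randn(4, 8, device=device)

    # With bfloat16 supported on the hardware and torch.cuda.is_available() False,
    # autocast should now route through the `xla` backend with bfloat16 instead of
    # being forced down to float16.
    with autocast(device, dtype=torch.bfloat16):
        out = model(data)
    xm.mark_step()                                # materialize the lazy graph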

Comment thread on torch_xla/amp/autocast_mode.py (Outdated)

    # XLA:GPU with bfloat16 should run on `xla` backend
    # unless torch.autocast is compiled with cuda.
    backend = 'xla'
    self._cuda_bfloat16 = True
@baoleai (Contributor) commented Sep 18, 2023

Does PyTorch not support bfloat16 when running on the xla backend?
I'm still confused about what happens when dtype is set to bfloat16:
When torch.cuda.is_available() returns False, torch_xla uses the XLA:GPU backend; however, in this patch it seems to fall back to the GPU backend.
Similarly, when torch.cuda.is_available() returns True, torch_xla uses the GPU backend, but the dtype is still forced to float16 instead of bfloat16.

@yeounoh (Contributor, Author) commented Sep 18, 2023

Hi @baoleai, good catch -- I meant to do what the error message says. We don't need to force torch.float16 on the cuda backend.

Thanks, will update the PR.

@yeounoh (Contributor, Author):

OK, let me go over these -- let me know if I am missing anything.

  1. "When torch.cuda.is_available() returns False, torch_xla uses the XLA:GPU backend. However, in this patch it seems to fall back to use the GPU backend."

When cuda is not available (False), we go into this code path:

    if xr.is_bf16_supported() and not torch.cuda.is_available():
        # XLA:GPU with bfloat16 should run on `xla` backend
        # unless torch.autocast is compiled with cuda.
        backend = 'xla'
        self._cuda_bfloat16 = True

We use the xla backend for autocast.

  2. "Similarly, when torch.cuda.is_available() returns True, torch_xla uses the GPU backend, but the dtype is still forced to float16 instead of bfloat16."

This is the part I need to address: instead of else we want elif {cuda is not available}, to implement the intention laid out in the error message.

@yeounoh (Contributor, Author) commented Sep 18, 2023

So, something like this (commit 50c1d2f1c598c2ec4ada8f7177861cc05056ff36):

    if dtype is None:
        dtype = torch.float16
    elif dtype == torch.bfloat16 and not torch.cuda.is_available():
        if xr.is_bf16_supported():
            # XLA:GPU with bfloat16 should run on `xla` backend
            # unless torch.autocast is compiled with cuda.
            backend = 'xla'
            self._cuda_bfloat16 = True
        else:
            # This has been the default behavior for the unsupported bfloat16 dtype.
            dtype = torch.float16
            error_message = "In XLA:GPU autocast, but bfloat16 is not supported on this HW.\n"
            error_message += ("Using the default cuda autocast dtype float16.")

Comment thread on torch_xla/amp/autocast_mode.py (Outdated)

    self._xla_device = xm.xla_device_hw(device)
    if self._xla_device == 'GPU':
        backend = 'cuda'
        self._cuda_bfloat16 = False
Collaborator:

I am very confused about what this _cuda_bfloat16 actually means. Do you mind adding a comment above line 28 to explain all the possible combinations and the expected behaviors?

@yeounoh (Contributor, Author):

+1, realized that this variable should be called _xla_bfloat16. Let me add a brief comment, too.

@yeounoh (Contributor, Author):

Something like this:

    self._xla_bfloat16 = False  # True if xla backend with bfloat16 dtype.
    if dtype is None:
        dtype = torch.float16
    elif dtype == torch.bfloat16 and not torch.cuda.is_available():
        if xr.is_bf16_supported():
            # XLA:GPU with bfloat16 should run on `xla` backend
            # unless torch.autocast is compiled with cuda.
            backend = 'xla'
            self._xla_bfloat16 = True
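For readers following the thread, here is a hedged restatement of the branches under discussion as a standalone helper. select_autocast_config is a hypothetical name and this is a sketch of the intended logic (including the cuda-available case described later in the thread), not the PR's actual code:

    import torch
    import torch_xla.runtime as xr

    def select_autocast_config(dtype):
        """Sketch: choose (backend, dtype, xla_bfloat16) for XLA:GPU autocast."""
        xla_bfloat16 = False
        backend = 'cuda'              # default backend on XLA:GPU, per the diff above
        if dtype is None:
            dtype = torch.float16     # default cuda autocast dtype
        elif dtype == torch.bfloat16 and not torch.cuda.is_available():
            if xr.is_bf16_supported():
                # bfloat16 without a cuda-enabled torch build: use the xla backend.
                backend = 'xla'
                xla_bfloat16 = True
            else:
                # Hardware cannot do bfloat16: fall back to float16 (with a warning).
                dtype = torch.float16
        # else: torch.cuda.is_available() is True -- plain cuda autocast keeps the
        # requested dtype (bfloat16 is no longer forced down to float16).
        return backend, dtype, xla_bfloat16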

@yeounoh (Contributor, Author):

commit dc9224336af6cf6313e60d197dab47323e8e509d (HEAD -> spmd_amp_gpu, origin/spmd_amp_gpu)
Author: Yeounoh Chung <yeounoh@google.com>
Date:   Mon Sep 18 12:02:15 2023 -0700

    Rename _cuda_bfloat16 to _xla_bfloat16 since it is set when xla backend is used for bfloat16.
 

@baoleai (Contributor):

Another confusing question is why there is special treatment only for the case where torch.cuda.is_available() is False and dtype=bfloat16 -- is it because XLA:GPU itself doesn't support bfloat16? When torch.cuda.is_available() is True there is no special treatment for bfloat16, yet the following is still used:

    torch.set_autocast_xla_enabled(self.prev)
    torch.set_autocast_xla_dtype(self.prev_dtype)

@yeounoh (Contributor, Author):

Hi @baoleai, XLA:GPU supports bfloat16, but we were using the cuda backend for autocast even when torch.cuda.is_available() was false. Instead we want to use the xla backend.

When torch.cuda.is_available() is true, we use the cuda backend for autocast, and since we are entering/exiting the autocast context via torch_xla.amp.autocast, we still need to set:

    torch.set_autocast_xla_enabled(self.prev)
    torch.set_autocast_xla_dtype(self.prev_dtype)

Comment on lines +77 to +80

    if self._xla_bfloat16:
        torch.set_autocast_enabled(self._enabled)
Collaborator:

If _xla_bfloat16, shouldn't we use set_autocast_xla_enabled instead?

@yeounoh (Contributor, Author) commented Sep 18, 2023

It will be set by torch.autocast; we need to set both, since we are wrapping and calling torch.autocast. So xla autocast is enabled whenever one uses torch_xla.amp.autocast, and the generic torch autocast is also enabled because we call torch.autocast with the cuda or xla backend.
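To illustrate the explanation above, a rough sketch of how the enter/exit path could save and restore both flags on the xla-backend bfloat16 path. This is not the PR's exact code, and the is_/get_autocast_xla_* getters are an assumption, chosen to mirror the setters quoted in the thread:

    import torch

    class XlaBf16AutocastSketch:
        """Hedged illustration: manage both the generic and xla autocast state."""

        def __enter__(self):
            # Save previous state so __exit__ can restore it.
            self._prev_enabled = torch.is_autocast_enabled()
            self._prev_xla_enabled = torch.is_autocast_xla_enabled()   # assumed getter
            self._prev_xla_dtype = torch.get_autocast_xla_dtype()      # assumed getter
            # torch_xla.amp.autocast wraps torch.autocast, so on this path both the
            # xla-specific flags and the generic autocast flag are enabled.
            torch.set_autocast_xla_enabled(True)
            torch.set_autocast_xla_dtype(torch.bfloat16)
            torch.set_autocast_enabled(True)
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Restore everything that __enter__ changed.
            torch.set_autocast_enabled(self._prev_enabled)
            torch.set_autocast_xla_enabled(self._prev_xla_enabled)
            torch.set_autocast_xla_dtype(self._prev_xla_dtype)
            return False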

Collaborator:

Can we leave a comment here? It is really confusing to see that, if we are using xla_bf16, we need to set the upstream autocast to enabled.

Comment on lines +88 to +93

    if self._xla_bfloat16:
        torch.set_autocast_enabled(self.prev)
Collaborator:

same question

@yeounoh (Contributor, Author):

ditto

@JackCaoG (Collaborator) left a review comment:

LGTM once the comments are added. Chatted with @yeounoh offline; I am going to approve to unblock.

@yeounoh (Contributor, Author) commented Sep 18, 2023

The GPU tests have already passed. Merging after adding the comments.

@yeounoh merged commit b214c10 into master on Sep 18, 2023
