Enable xla:gpu autocast for bfloat16 if not restricted #5570

Merged: will-cromar merged 11 commits into master from spmd_amp_gpu on Sep 15, 2023

Conversation

@yeounoh (Contributor) commented Sep 14, 2023

torch_xla.amp.autocast was not using the XLA:GPU backend but the torch.cuda (GPU) backend. We should use XLA devices for XLA tensors, including when running torch_xla.amp.autocast.

Clarification 1:

AMP executes on XLA:GPU, unless the dtype is bfloat16, in which case it falls back to cuda and requires USE_CUDA=1. We should address it.

pytorch/pytorch#109302, which updates the dispatcher comment, is not needed. This also addresses #5497.
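
For context, a minimal usage sketch of the behavior this PR targets. It assumes a working torch_xla installation with an XLA device available; the model and shapes are placeholders.

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

device = xm.xla_device()
model = torch.nn.Linear(8, 8).to(device)
data = torch.randn(4, 8, device=device)

# With this change, bfloat16 autocast on XLA:GPU should dispatch to the
# `xla` backend instead of requiring the torch.cuda backend (USE_CUDA=1).
with autocast(device, dtype=torch.bfloat16):
    out = model(data)
xm.mark_step()  # materialize the lazy computation
```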

@yeounoh yeounoh self-assigned this Sep 14, 2023
@yeounoh yeounoh changed the title from "Enable autocast for XLA:GPU" to "Enable xla autocast for XLA:GPU" on Sep 14, 2023
@JackCaoG (Collaborator):

There is some history behind this. We used the upstream autocast for cuda because we wanted to reuse the GPU AMP rules. Are you saying AMP currently does not execute on XLA:GPU when enabled?

@JackCaoG (Collaborator):

I am pretty sure that, at least in 2.0, our autocast actually worked with XLA:GPU and executed on XLA:GPU.

@yeounoh (Contributor, Author) commented Sep 14, 2023

I will investigate why our tests didn't catch this issue and modify or add proper tests before merging.

@yeounoh yeounoh marked this pull request as draft September 14, 2023 19:26
@yeounoh yeounoh removed the request for review from cowanmeg September 14, 2023 19:27
@yeounoh (Contributor, Author) commented Sep 14, 2023

> I am pretty sure that, at least in 2.0, our autocast actually worked with XLA:GPU and executed on XLA:GPU.

Yea, it did. I think this issue surfaced because of bfloat16; we were hard-coding the dtype to float16 for GPU back then.

@yeounoh (Contributor, Author) commented Sep 14, 2023

> There is some history behind this. We used the upstream autocast for cuda because we wanted to reuse the GPU AMP rules. Are you saying AMP currently does not execute on XLA:GPU when enabled?

Got it. AMP executes on XLA:GPU, unless the dtype is bfloat16, in which case it falls back to cuda and requires USE_CUDA=1. We should address it.

@baoleai (Contributor) commented Sep 15, 2023

> Got it. AMP executes on XLA:GPU, unless the dtype is bfloat16, in which case it falls back to cuda and requires USE_CUDA=1. We should address it.

How is this issue resolved?

@yeounoh (Contributor, Author) commented Sep 15, 2023

> > Got it. AMP executes on XLA:GPU, unless the dtype is bfloat16, in which case it falls back to cuda and requires USE_CUDA=1. We should address it.
>
> How is this issue resolved?

Hi @baoleai, it was resolved by making changes to our approach. Will let you know when the PR is ready.

@yeounoh yeounoh force-pushed the spmd_amp_gpu branch 4 times, most recently from 17d614e to 0a0365b on September 15, 2023 05:14
@yeounoh yeounoh force-pushed the spmd_amp_gpu branch 8 times, most recently from 9cf9a21 to 98dec3d on September 15, 2023 07:36
@yeounoh yeounoh marked this pull request as ready for review September 15, 2023 07:40
@yeounoh yeounoh changed the title from "Enable xla autocast for XLA:GPU" to "Enable xla:gpu autocast for bfloat16 if not restricted" on Sep 15, 2023
@yeounoh (Contributor, Author) commented Sep 15, 2023

The GPU test CI run timed out at the end, so I am increasing the timeout. I was able to pass all the GPU tests locally on my GPU/TPU boxes. Will wait for the new CI run.

@ManfeiBai (Collaborator):

> The GPU test CI run timed out at the end, so I am increasing the timeout. I was able to pass all the GPU tests locally on my GPU/TPU boxes. Will wait for the new CI run.

Hi, did the GPU CI run time out on a specific test? I didn't find the timeout error in the last two commits' logs.

The pin update PR hit a similar issue and fixed it by increasing the timeout; do we want to try that here too? 55a0823

Comment thread on torch_xla/runtime.py:

```python
  """Returns whether torch.bfloat16 is supported on this environment.
  """
  try:
    torch.tensor([1.], dtype=torch.bfloat16, device=xm.xla_device())
```
A Collaborator replied:

lol this is smart
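
For reference, a self-contained sketch of the probe being praised here: try to materialize a bfloat16 tensor on the XLA device and treat any failure as lack of support. The function name below is a placeholder for illustration, not necessarily the one used in torch_xla/runtime.py.

```python
import torch
import torch_xla.core.xla_model as xm

def bf16_supported() -> bool:  # hypothetical name, for illustration
    """Returns whether torch.bfloat16 is supported in this environment."""
    try:
        # Creating a bf16 tensor on the device fails where bf16 is unsupported.
        torch.tensor([1.], dtype=torch.bfloat16, device=xm.xla_device())
        return True
    except Exception:
        return False

print("bfloat16 supported:", bf16_supported())
```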

@will-cromar will-cromar merged commit 64bbf15 into master Sep 15, 2023
will-cromar pushed a commit that referenced this pull request Sep 15, 2023
* Enable autocast for XLA:GPU

* linter fix

* XLA autocast test for GPU and TPU

* linter fix

* Ensure that xla autocast is properly enabled for GPU and does not crash when torch cuda is not available.

* linter fix

* Add tests

* Support bf16

* linter fix

* exclude unsupported test cases

* increase GPU test timeout to 300
@yeounoh (Contributor, Author) commented Sep 15, 2023

> > The GPU test CI run timed out at the end, so I am increasing the timeout. I was able to pass all the GPU tests locally on my GPU/TPU boxes. Will wait for the new CI run.
>
> Hi, did the GPU CI run time out on a specific test? I didn't find the timeout error in the last two commits' logs.
>
> The pin update PR hit a similar issue and fixed it by increasing the timeout; do we want to try that here too? 55a0823

I increased the timeout and was able to pass as well. I am going to address the GPU test time sometime soon.

JackCaoG pushed a commit that referenced this pull request Sep 15, 2023
* Enable autocast for XLA:GPU

* linter fix

* XLA autocast test for GPU and TPU

* linter fix

* Ensure that xla autocast is properly enabled for GPU and does not crash when torch cuda is not available.

* linter fix

* Add tests

* Support bf16

* linter fix

* exclude unsupported test cases

* increase GPU test timeout to 300

Co-authored-by: Yeounoh Chung <yeounoh@google.com>
Comment thread on the autocast backend selection:

```python
    # XLA:GPU with bfloat16 should run on `xla` backend
    # unless torch.autocast is compiled with cuda.
    backend = 'xla'
else:
```
@baoleai (Contributor) commented Sep 15, 2023

Why is bfloat16 not supported now? Previously torch_xla was able to support bfloat16 with torch.cuda.is_available() == True (there would be bfloat16 in the HLO). Do you mean we have to compile PyTorch with USE_CUDA=0 after this commit? In most scenarios our PyTorch needs to support both XLA and GPUs simultaneously, i.e. with USE_CUDA=1.

@yeounoh (Contributor, Author) commented Sep 16, 2023

Notice the if condition there: if torch is already compiled with USE_CUDA=1 and torch.cuda.is_available() == True, then it runs with the cuda backend and supports bfloat16 as far as the cuda backend allows (note that for some platforms this wouldn't be the case).

For some additional context on the XLA:GPU torch.cuda.is_available() == False case: we guide users to compile torch with USE_CUDA=0 to run ops with XLA only, not the torch cuda backend -- see the GPU guide. But yea, I see that people would actually want to set both USE_CUDA and XLA_CUDA and run things simultaneously -- which should be fine?

For the XLA:GPU xla backend, #5598 should also provide bfloat16 support. A sketch of the selection logic follows below.
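
A minimal sketch of the backend selection described above, assuming the names from the review excerpt; this is illustrative, not the exact source:

```python
import torch

def pick_autocast_backend(dtype: torch.dtype) -> str:  # hypothetical helper
    """Choose which backend handles an autocast region on XLA:GPU."""
    if dtype == torch.bfloat16 and not torch.cuda.is_available():
        # XLA:GPU with bfloat16 should run on the `xla` backend
        # unless torch is compiled with USE_CUDA=1 and cuda is available.
        return 'xla'
    # Otherwise reuse the upstream cuda AMP rules.
    return 'cuda'
```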

@baoleai (Contributor) commented Sep 18, 2023

> then it runs with the cuda backend and supports bfloat16 as far as the cuda backend allows (note that for some platforms this wouldn't be the case).

But at line 41, the type is forced to torch.float16, even if I set dtype=torch.bfloat16.

will-cromar added a commit that referenced this pull request Sep 18, 2023
* Enable autocast for XLA:GPU

* linter fix

* XLA autocast test for GPU and TPU

* linter fix

* Ensure that xla autocast is properly enabled for GPU and does not crash when torch cuda is not available.

* linter fix

* Add tests

* Support bf16

* linter fix

* exclude unsupported test cases

* increase GPU test timeout to 300

Co-authored-by: Yeounoh Chung <yeounoh@google.com>
will-cromar added a commit that referenced this pull request Sep 19, 2023
* Handle dynamo function without input (#5565) (#5577)

* Make cpu tensor on XLA dynamo backend a warning instead of error (#5549) (#5576)

* [author: jluntamazon] Adding more explicit HLO lowering control by exposing LoweringContext… (#5431) (#5580)

* Adding more explicit HLO lowering control by exposing LoweringContext (and utilities) to python for Neuron

* fixing linter issues

* fixing spacing

* apply comments and fix compilation errors

* add test for new apis

* fix linter

* update test

* update test

* modify test

* reverse back to GetIrValue()

* update test inputs with random numbers

* skip unittest because it only fails in CI

---------

Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-3-186.us-west-2.compute.internal>
Co-authored-by: seanlatias <seanlatias@gmail.com>

* fixing num_local_processes typo (#5573) (#5579)

Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com>

* Move where clear pending IR is called to avoid crash (#5552) (#5582)

* Move where clear pending IR is called to avoid crash

* fix CI

* fix CI and add some debugging messages

* Fix release branch and tag patterns for GitHub Actions (#5587) (#5590)

* Improve bernoulli rng-bit-generation memory footprint (#5581) (#5589)

* Allow downcasting RngUniform generation for Bernoulli

Co-authored-by: Yeounoh Chung <yeounoh@google.com>

* Enable xla:gpu autocast for bfloat16 if not restricted (#5570) (#5591)

* Enable autocast for XLA:GPU

* linter fix

* XLA autocast test for GPU and TPU

* linter fix

* Ensure that xla autocast is properly enabled for GPU and does not crash when torch cuda is not available.

* linter fix

* Add tests

* Support bf16

* linter fix

* exclude unsupported test cases

* increase GPU test timeout to 300

Co-authored-by: Yeounoh Chung <yeounoh@google.com>

* Cherry-pick: Don't trigger CI build on release tag push (#5595)

Copy of #5594 on release branch

* formatting

---------

Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Co-authored-by: Wonjoo Lee <wonjoo@google.com>
Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-3-186.us-west-2.compute.internal>
Co-authored-by: seanlatias <seanlatias@gmail.com>
Co-authored-by: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Co-authored-by: Yeounoh Chung <yeounoh@google.com>