
[AMP] Support XLA:TPU #96370

Closed

cowanmeg wants to merge 32 commits into pytorch:main from cowanmeg:amp-tpu

Conversation

@cowanmeg
Contributor

@cowanmeg cowanmeg commented Mar 9, 2023

Companion PRs: pytorch/xla#5148, pytorch/xla#4740

With these changes:
- XLA:GPU users should use torch.cuda.amp.autocast() for AMP with float16
- XLA:TPU users should use torch.amp.autocast('xla') for AMP with bfloat16
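The two bullets above can be summarized as a lookup table. A minimal, illustrative sketch (the table and helper below are not a torch API; they only restate the UX described in this summary):

```python
# Which autocast entry point and dtype each XLA backend uses after this PR,
# per the PR summary above. Purely illustrative; not part of torch.
AMP_ENTRY_POINTS = {
    "xla:gpu": ("torch.cuda.amp.autocast()", "float16"),
    "xla:tpu": ("torch.amp.autocast('xla')", "bfloat16"),
}

def amp_recipe(backend: str) -> str:
    """Return a one-line reminder of how to enable AMP for a backend."""
    ctx, dtype = AMP_ENTRY_POINTS[backend]
    return f"use {ctx} to autocast eligible ops to {dtype}"
```

Note that XLA:GPU deliberately reuses the existing CUDA entry point, while TPUs get the new 'xla' device type.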

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5

@pytorch-bot

pytorch-bot Bot commented Mar 9, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/96370

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 366470b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla Bot commented Mar 15, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

@albanD albanD requested a review from bdhirsh March 17, 2023 21:38
@soulitzer soulitzer removed their request for review March 22, 2023 20:50
Comment thread aten/src/ATen/autocast_mode.cpp Outdated

TORCH_LIBRARY_IMPL(_, AutocastXLA, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
Collaborator

@cowanmeg all of these dispatcher registrations probably don't have to live in core - can we move them into the pytorch/xla repo?

Contributor Author

Moved into pytorch/xla. Note: I moved the CastPolicy enum into autocast_mode.h so it could be included.
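For readers unfamiliar with the `makeFallthrough()` fallback being registered in the excerpt above: a toy, illustrative Python model of dispatch-key fallthrough (the real c10 dispatcher is in C++; names below are invented for illustration):

```python
# Toy dispatcher model. Registering a fallthrough fallback for a key, as
# TORCH_LIBRARY_IMPL(_, AutocastXLA, m) does with makeFallthrough(), means:
# ops with no explicit kernel for that key skip it and dispatch to the next
# active key, instead of raising a "missing kernel" error.
FALLTHROUGH = object()

def dispatch(kernels, fallbacks, active_keys):
    """Pick the kernel for the highest-priority active dispatch key."""
    for key in active_keys:  # highest priority first
        if key in kernels:
            return kernels[key]
        if fallbacks.get(key) is FALLTHROUGH:
            continue  # fallthrough: behave as if this key weren't active
        raise RuntimeError(f"no kernel registered for {key}")
    raise RuntimeError("no kernel found for any active key")

# An op with an autocast wrapper runs it; an op without one falls through
# straight to the backend kernel.
wrapped   = {"AutocastXLA": "cast_then_run_mm", "XLA": "xla_mm"}
unwrapped = {"XLA": "xla_relu"}
fallbacks = {"AutocastXLA": FALLTHROUGH}
```

This is why the fallthrough registration is needed at all: without it, every op lacking an explicit AutocastXLA kernel would error when the key is active.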

Comment thread c10/core/DispatchKey.h Outdated
// AutocastXLA is only being used for TPUs. XLA GPUs continue to use AutocastCUDA.
AutocastXLA,
Collaborator

@cowanmeg can you describe how this works a bit more - what's the UX here? Is the user expected to use `torch.cuda.amp.autocast()` when using XLA with GPUs, and `torch.amp.autocast('xla')` when using TPUs?

Contributor Author

Correct. Updated the summary for clarity.

Comment thread third_party/kineto
Comment thread torch/__init__.py Outdated
Comment thread torch/csrc/jit/passes/autocast.cpp Outdated
struct AutocastContext {
bool gpu_enabled = false;
bool cpu_enabled = false;
bool xla_enabled = false;
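The struct above tracks one enabled flag per device type, extended in this PR with `xla_enabled`. A simplified, illustrative Python model of how such flags behave across nested autocast scopes (not torch internals):

```python
from contextlib import contextmanager
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class AutocastContext:
    # Mirrors the JIT pass struct: one flag per device type.
    gpu_enabled: bool = False
    cpu_enabled: bool = False
    xla_enabled: bool = False

_stack = [AutocastContext()]  # innermost context is last

@contextmanager
def autocast(device_type: str):
    """Enable autocast for one device type inside a nested scope."""
    ctx = replace(_stack[-1], **{f"{device_type}_enabled": True})
    _stack.append(ctx)
    try:
        yield ctx
    finally:
        _stack.pop()

def current() -> AutocastContext:
    return _stack[-1]
```

The key property the model shows: entering a scope for one device type leaves the other flags untouched, and exiting restores the previous context.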
Collaborator

cc @davidberard98 - do you mind reviewing the JIT changes in this file? I'm not too familiar with them.

Contributor

looks good other than the two other comments (static runtime & bc-breaking)

Contributor Author

Decided to take the JIT changes out, since JIT isn't used often in pytorch/xla.

const auto cpu_enabled = p_node->Input(2).toBool();
const auto xla_enabled = p_node->Input(3).toBool();
const auto cuda_dtype = p_node->Input(4).toScalarType();
const auto cpu_dtype = p_node->Input(5).toScalarType();
Contributor

@tenpercent can you take a look at this? Should we just leave static runtime out?

variants: function

(before) - func: _autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)
(after)  - func: _autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, bool xla_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype, ScalarType xla_dtype) -> Tensor(a)
Contributor

I think this is BC-breaking? I'm not very familiar with the process, but I think we'd need to add upgraders for JIT, right?
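Conceptually, a JIT upgrader for the widened schema would rewrite old 5-argument calls into the new 7-argument form, defaulting the XLA fields. A hypothetical sketch (the function name, argument tuples, and defaults below are invented for illustration; real upgraders operate on serialized graphs, not Python tuples):

```python
def upgrade_autocast_to_reduced_precision(old_args):
    """Map the old schema
        (self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype)
    to the new schema
        (self, cuda_enabled, cpu_enabled, xla_enabled,
         cuda_dtype, cpu_dtype, xla_dtype)."""
    self_, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype = old_args
    xla_enabled = False      # old programs never enabled XLA autocast
    xla_dtype = "bfloat16"   # assumed default; unused while xla_enabled is False
    return (self_, cuda_enabled, cpu_enabled, xla_enabled,
            cuda_dtype, cpu_dtype, xla_dtype)
```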

@cowanmeg cowanmeg requested a review from bdhirsh April 13, 2023 16:38
@cowanmeg
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again. You can rebase and merge by leaving the following comment on this PR:
@pytorchbot merge -r
Or just rebase by leaving a @pytorchbot rebase comment.

Details for Dev Infra team: raised by workflow job.

@cowanmeg
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 3 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

@cowanmeg
Contributor Author

I think this failing inductor test is unrelated:

2023-06-20T21:08:10.1240937Z [2023-06-20 21:08:10,123] torch._dynamo.utils: [ERROR] RMSE (res-fp64): 0.00213, (ref-fp64): 0.00064 and shape=torch.Size([256])
2023-06-20T21:08:10.1245080Z [2023-06-20 21:08:10,123] torch._dynamo.utils: [ERROR] Accuracy failed for key name backbone.fpn.layer_blocks.2.0.bias.grad
2023-06-20T21:08:10.1296323Z fail_accuracy

@bdhirsh
Collaborator

bdhirsh commented Jun 21, 2023

@cowanmeg hmm, I don't think I see that failure in CI on the main branch (https://hud.pytorch.org/), and it's hard to tell immediately whether it's flaky/unrelated, since that test exercises E2E logic and appears to run with autocast enabled. Can you try rebasing?

@cowanmeg
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 3 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

@cowanmeg
Contributor Author

I think the BC lint failure is a cancellation?

@kit1980
Contributor

kit1980 commented Jun 23, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.


Labels

ciflow/inductor, ciflow/trunk, Merged, module: amp (automated mixed precision), autocast, open source, release notes: jit, triaged

8 participants