[AMP] Support XLA:TPU #96370
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/96370
Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures as of commit 366470b.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
```cpp
TORCH_LIBRARY_IMPL(_, AutocastXLA, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}
```
@cowanmeg all of these dispatcher registrations probably don't have to live in core - can we move them into the pytorch/xla repo?
Moved into pytorch/xla. Note: I moved the CastPolicy enum into autocast_mode.h so it could be included.
```cpp
// Naughtily, AutocastCUDA is also being used for XLA. In the terminal state,
// it probably should get its own Autocast key
AutocastXLA,
// AutocastXLA is only being used for TPUs. XLA GPUs continue to use AutocastCUDA.
```
@cowanmeg can you describe how this works a bit more - what's the UX here? Is the user expected to use `torch.cuda.autocast()` when using XLA with GPUs, and `torch.xla.autocast()` when using TPUs?
Correct. Updated the summary for clarity.
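For reference, a minimal sketch of the two entry points (assumes an XLA-capable environment with torch_xla installed; `xm.xla_device()` is the standard torch_xla device helper, and the Linear model here is just a placeholder):

```python
import torch
import torch_xla.core.xla_model as xm  # standard torch_xla device helper

device = xm.xla_device()  # TPU or GPU, depending on the environment
model = torch.nn.Linear(4, 4).to(device)
inp = torch.randn(2, 4).to(device)

# XLA:GPU — keep using the CUDA autocast context (casts eligible ops to float16):
with torch.cuda.amp.autocast():
    out = model(inp)

# XLA:TPU — use the new 'xla' device type (casts eligible ops to bfloat16):
with torch.amp.autocast('xla'):
    out = model(inp)
```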
```cpp
struct AutocastContext {
  bool gpu_enabled = false;
  bool cpu_enabled = false;
  bool xla_enabled = false;
```
cc @davidberard98 - do you mind reviewing the JIT changes in this file? I'm not too familiar with them.
looks good other than the two other comments (static runtime & bc-breaking)
Decided to take the JIT changes out since they're not used often in pytorch/xla.
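For context, a minimal sketch of the TorchScript autocast path those changes touched (hedged: `torch._C._jit_set_autocast_mode` is the existing private toggle for the JIT autocast pass; the dropped changes would have threaded an `xla_enabled` flag through this context):

```python
import torch

# Enable the TorchScript autocast pass (existing private toggle).
torch._C._jit_set_autocast_mode(True)

@torch.jit.script
def scripted_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # The JIT pass rewrites this region using the AutocastContext flags
    # (gpu_enabled/cpu_enabled today; the dropped changes added xla_enabled).
    with torch.cuda.amp.autocast(enabled=True):
        return torch.mm(a, b)
```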
```cpp
const auto cpu_enabled = p_node->Input(2).toBool();
const auto cuda_dtype = p_node->Input(3).toScalarType();
const auto cpu_dtype = p_node->Input(4).toScalarType();
const auto xla_enabled = p_node->Input(3).toBool();
```
@tenpercent can you take a look at this? should we just leave static runtime out?
```diff
   variants: function

-- func: _autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)
+- func: _autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, bool xla_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype, ScalarType xla_dtype) -> Tensor(a)
```
I think this is BC-breaking? Not very familiar with how to do this, but I think we'd need to add upgraders for JIT, right?
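For reference, a sketch of the pre-change 5-argument op as I read it (the BC concern: serialized TorchScript models bake in this schema, so widening it to 7 arguments would need a JIT upgrader to rewrite old call sites):

```python
import torch

x = torch.randn(4, 4)  # float32 on CPU

# Old 5-arg schema: (cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype).
# With cpu_enabled=True, a float32 CPU tensor is downcast to cpu_dtype.
y = x._autocast_to_reduced_precision(False, True, torch.float16, torch.bfloat16)
print(y.dtype)  # torch.bfloat16
```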
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again. You can rebase and merge by leaving the following comment on this PR:

Details for Dev Infra team: raised by workflow job.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 3 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
I think this failing inductor test is unrelated:
@cowanmeg hmm, I don't think I see that failure in CI on the main branch (https://hud.pytorch.org/), and it's a bit hard to tell immediately whether it's flaky/unrelated, since that test exercises E2E logic and appears to be running with autocast enabled. Can you try rebasing?
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 3 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
I think the BC lint failure is a cancellation?
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
With pytorch/xla#5148, pytorch/xla#4740.

With these changes:

- XLA:GPU users should use `torch.cuda.amp.autocast()` for AMP with float16
- XLA:TPU users should use `torch.amp.autocast('xla')` for AMP with bfloat16

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5