
Support TPU v4 with new PyTorch/XLA TPU runtime #1393

Merged

sgugger merged 8 commits into huggingface:main from will-cromar:wcromar/pjrt-v4 on May 8, 2023

Conversation

@will-cromar
Contributor

I've been working on migrating PyTorch/XLA from our legacy XRT runtime to PJRT. We have detailed documentation on the differences and changes here: https://github.com/pytorch/xla/blob/master/docs/pjrt.md
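
For readers new to the runtime switch, selecting PJRT versus the legacy XRT runtime happens through environment variables before torch_xla initializes. A minimal sketch (not part of this PR's diff; it assumes the PJRT_DEVICE variable described in the pjrt.md doc linked above):

```python
# Minimal sketch of runtime selection, assuming the PJRT_DEVICE variable
# documented in pjrt.md. The legacy XRT runtime was configured through
# XRT_TPU_CONFIG instead. Not part of this PR's diff.
import os

os.environ["PJRT_DEVICE"] = "TPU"  # use the new PJRT runtime on a TPU VM

import torch_xla.core.xla_model as xm

device = xm.xla_device()
print(device)  # e.g. xla:0
```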

This PR collects all of the fixes that I've made so far to support XRT and PJRT interchangeably through Accelerate. In order of significance:

  1. Update the implementation of synchronize_rng_types to use xm.collective_broadcast to broadcast the RNG state tensor. Collective operations should in general be called from the main thread of each process to avoid unpredictable behavior in XLA, but our MpDeviceLoader calls accelerate.DataLoaderShard's __iter__ (which synchronizes the RNG) from all of the preloading threads. To ensure the RNG is synchronized exactly once from the main thread, synchronize it in MpDeviceLoaderWrapper's __iter__ instead.
  2. In general, you should call xm.mark_step to finish any remaining steps before checkpointing. MpDeviceLoader is responsible for calling xm.mark_step at the beginning of each new step and at the end of the dataset iterator, so if you checkpoint in the middle of iteration, replica 0 won't reach the mark_step at the beginning of the next iteration before it tries to checkpoint. To avoid making the user call mark_step themselves, call it for them on replica 0 before checkpointing in save_state (see the sketch after this list).
  3. Use XLA's all_gather in gather instead of mesh_reduce + torch.cat (also shown in the sketch after this list). With PJRT, rendezvous and mesh_reduce both use XLA collective ops to broadcast pickled data (docs) and move it back to the CPU. Batching all of the recursive all_gather calls together and calling xm.mark_step once reduces the number of transfers between host and device.
  4. Don't override XLA_USE_BF16 when mixed_precision is not set. This isn't really mixed precision so much as a mechanism to silently convert all torch.float tensors to BF16 on the TPU. Real mixed-precision support in PT/XLA is still a WIP, and we can circle back here when it's stable.
  5. Remove obsolete/unused _mp_fn in test script.
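
To make items 2 and 3 concrete, here is a minimal sketch of the two patterns. It is not the PR's actual diff; gather_for_tpu and save_state_tpu are simplified stand-ins for the corresponding Accelerate helpers:

```python
import torch
import torch_xla.core.xla_model as xm


def gather_for_tpu(tensor: torch.Tensor) -> torch.Tensor:
    # Item 3: a single on-device collective instead of mesh_reduce + torch.cat.
    # xm.all_gather concatenates the tensor from every replica along dim 0
    # without pickling data and round-tripping it through the host.
    return xm.all_gather(tensor, dim=0)


def save_state_tpu(state: dict, path: str) -> None:
    # Item 2: flush pending lazy XLA operations so replica 0 is not asked to
    # serialize tensors whose values have not been materialized yet.
    xm.mark_step()
    if xm.get_ordinal() == 0:  # replica 0 writes the checkpoint
        cpu_state = {k: v.cpu() if torch.is_tensor(v) else v for k, v in state.items()}
        torch.save(cpu_state, path)
```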

Tested:

Accelerate will not work on TPU v2 and v3 with this PR, because both of them use multithreading due to TPU design constraints. I'll follow up with the remaining fixes for TPU v2 and v3 in #1385.

cc @sgugger @JackCaoG

@will-cromar marked this pull request as ready for review on May 5, 2023 at 21:15
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented May 6, 2023

The documentation is not available anymore as the PR was closed or merged.

Collaborator

@sgugger left a comment


Thanks for adding support for TPU v4! All the changes LGTM! Can you just run make style on your branch so that the quality check passes?

@will-cromar requested a review from sgugger on May 8, 2023 at 17:05
@sgugger merged commit 145fca5 into huggingface:main on May 8, 2023
@sgugger
Collaborator

sgugger commented May 8, 2023

Thanks again!

@JackCaoG

JackCaoG commented May 8, 2023

Thanks both, this is super exciting!

