Support TPU v4 with new PyTorch/XLA TPU runtime #1393
Merged
sgugger merged 8 commits into huggingface:main on May 8, 2023
Conversation
The documentation is not available anymore as the PR was closed or merged.
sgugger (Collaborator) approved these changes on May 6, 2023:

Thanks again!

Thanks both, this is super exciting!
I've been working on migrating PyTorch/XLA from our legacy XRT runtime to PJRT. We have detailed documentation on the differences and changes here: https://github.com/pytorch/xla/blob/master/docs/pjrt.md
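As a quick illustration (a sketch based on the pjrt.md guide linked above, not part of this PR's diff; `train.py` is a placeholder script name), the runtime is selected through environment variables:

```shell
# Run under the new PJRT runtime (per the pjrt.md migration guide):
PJRT_DEVICE=TPU python3 train.py

# The legacy XRT runtime is used when the XRT configuration variables
# (e.g. XRT_TPU_CONFIG) are set instead of PJRT_DEVICE.
```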
This PR collects all of the fixes that I've made so far to support XRT and PJRT interchangeably through Accelerate. In order of significance:
- Update `synchronize_rng_types` to use `xm.collective_broadcast` to broadcast the RNG state tensor. Collective operations in general should be called from the main thread of each process to avoid unpredictable behavior in XLA. But our implementation of `MpDeviceLoader` calls `accelerate.DataLoaderShard`'s `__iter__` (which synchronizes the RNG) in all of the preloading threads. To ensure that the RNG is synchronized once from the main thread, call it in the `MpDeviceLoaderWrapper`'s `__iter__` instead.
- Call `xm.mark_step` to finish any remaining steps before checkpointing. `MpDeviceLoader` is responsible for calling `xm.mark_step` at the beginning of each new step and at the end of the dataset iterator. If you checkpoint in the middle of iteration, replica 0 won't reach the `mark_step` at the beginning of the next iteration before it tries to checkpoint. To avoid making the user call `mark_step` themselves, call it for them on replica 0 before checkpointing in `save_state`.
- Use `all_gather` in `gather` instead of `mesh_reduce` + `torch.cat`. With PJRT, `rendezvous` and `mesh_reduce` both use XLA collective ops to broadcast pickled data (docs) and move it back to the CPU. Batching all of the recursive `all_gather` calls together and calling `xm.mark_step` once reduces the number of transfers between host and device.
- Don't set `XLA_USE_BF16` when `mixed_precision` is not set. This isn't really mixed precision so much as a mechanism to silently convert all `torch.float`s to BF16 on the TPU. Real mixed precision support in PT/XLA is still a WIP, and we can update back here when it's stable.
- Add `_mp_fn` to the test script.

Tested:
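The batched-gather point can be sketched in plain Python. This is hypothetical illustration, not Accelerate's actual implementation: lists stand in for XLA tensors, repeating a leaf `world_size` times stands in for the cross-replica concatenation `xm.all_gather` would perform, and all helper names are made up. The structural idea is what matters: flatten the nested input once, issue a single batched collective, then rebuild the structure, instead of paying one host↔device transfer per leaf.

```python
def flatten(obj, leaves):
    """Collect every leaf (a flat list of numbers here) into `leaves` and
    replace it with an index placeholder so the structure can be rebuilt."""
    if isinstance(obj, dict):
        return {k: flatten(v, leaves) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)) and obj and isinstance(obj[0], (int, float)):
        leaves.append(list(obj))
        return ("leaf", len(leaves) - 1)
    return [flatten(v, leaves) for v in obj]

def unflatten(obj, gathered):
    """Inverse of flatten: swap each placeholder for its gathered result."""
    if isinstance(obj, tuple) and obj[0] == "leaf":
        return gathered[obj[1]]
    if isinstance(obj, dict):
        return {k: unflatten(v, gathered) for k, v in obj.items()}
    return [unflatten(v, gathered) for v in obj]

def batched_gather(nested, world_size):
    leaves = []
    spec = flatten(nested, leaves)
    # One "collective" over all leaves at once; a real implementation would
    # call xm.all_gather here and then xm.mark_step exactly once.
    gathered = [leaf * world_size for leaf in leaves]
    return unflatten(spec, gathered)
```

For example, `batched_gather({"loss": [0.5], "logits": [[1, 2], [3, 4]]}, world_size=2)` rebuilds the dict with every leaf "gathered" across both simulated replicas, after a single batched pass over the three leaves.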
- `accelerate test` on v4-8 with XRT and PJRT
- `diffusers` Stable Diffusion fine-tuning example on TPU v4-8

Accelerate will not work on TPU v2 and v3 with this PR, because both of them use multithreading due to TPU design constraints. I'll follow up with the remaining fixes for TPU v2 and v3 in #1385.
cc @sgugger @JackCaoG