
Support TPU v2 and v3 on new PyTorch/XLA TPU runtime #1385

Merged
sgugger merged 6 commits into huggingface:main from will-cromar:wcromar/pjrt on May 9, 2023

Conversation

@will-cromar (Contributor) commented May 3, 2023:

Depends on #1393

I've been working on migrating PyTorch/XLA from our legacy XRT runtime to PJRT. We have detailed documentation on the differences and changes here: https://github.com/pytorch/xla/blob/master/docs/pjrt.md

Due to TPU design constraints, PJRT must use multithreading on TPU v2 and v3 (docs).

  • This means we cannot use global state for AcceleratorState, since two replicas would end up sharing the same accelerator.device and accelerator.process_index. I implemented a descriptor that wraps the global _shared_state dicts in state.py and optionally makes it thread-local when using XLA.
  • Use numpy Generators in the accelerate test script instead of the global numpy.random.seed. This makes the test thread-safe for TPU v2 and v3. Numpy recommends using Generators as a best practice in their docs: https://numpy.org/doc/stable/reference/random/generated/numpy.random.seed.html
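The first bullet can be sketched as follows. This is an illustrative stand-in for the approach described, not the exact code in state.py; the class and attribute names are assumptions.

```python
import threading

class ThreadLocalSharedDict(threading.local):
    """Descriptor whose backing dict is recreated per thread.

    Because it subclasses threading.local, each thread that touches
    _storage gets its own fresh dict.
    """
    def __init__(self):
        self._storage = {}

    def __get__(self, obj, objtype=None):
        return self._storage

    def __set__(self, obj, value):
        self._storage = value

class State:
    # Borg-style shared state, but isolated per thread when each TPU
    # replica runs as a thread (PJRT on TPU v2/v3).
    _shared_state = ThreadLocalSharedDict()

    def __init__(self, device=None):
        self.__dict__ = self._shared_state
        if device is not None:
            self.device = device

results = {}

def replica(name, device):
    State(device)                   # first instance in this thread sets the device
    results[name] = State().device  # later instances see the same thread-local state

t1 = threading.Thread(target=replica, args=("replica0", "xla:0"))
t2 = threading.Thread(target=replica, args=("replica1", "xla:1"))
t1.start(); t2.start(); t1.join(); t2.join()

# Each thread kept its own device; with a plain global dict the second
# thread would have overwritten the first one's state.
print(results)
```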

Tested:

  • accelerate test on TPU v2-8 with XRT and PJRT

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@JackCaoG commented May 3, 2023:

@muellerzr We are trying to enable Accelerate with the new TPU VM and the PJRT runtime. Would someone on your team be able to review this change? We are working on deprecating the old XRT runtime and TPU node architecture, so it is important for us to keep HF support going forward.

@sgugger (Collaborator) left a comment:

Thanks for opening this PR! If possible, let's keep it focused on TPUs only and leave the unrelated changes for other PRs, so it's easier to inspect potential regressions with git blame in the future.

Note that we use four spaces as indentation and not two.

From what I understand in your comments, a user will need PyTorch 2.0+ to use TPUs after this is merged? If so, let's raise a proper error telling the user to upgrade when we detect an earlier version. What do you think?

Comment thread src/accelerate/state.py Outdated
class ThreadLocalSharedDict(GlobalSharedDict, threading.local):
    pass

SharedDict = GlobalSharedDict if not is_tpu_available(check_device=False) else ThreadLocalSharedDict
Collaborator:

Let's not break the existing behavior and keep using a regular dict when not on TPU.

will-cromar (Contributor, PR author):

I wrote it this way because descriptors have slightly different semantics than just a dict on its own.

PartialState(...)._shared_state and PartialState._shared_state (instance vs class) give the same value: the underlying dict. Likewise, PartialState(...)._shared_state = {...} overrides the dict inside the descriptor as you would expect. However, PartialState._shared_state = {} actually replaces the descriptor object with a dict instead, which is then no longer thread-local. This is why I switched over to clear() below instead of assignment.

This is an obscure wart in Python, and I was planning to put a note in the comments when I took another cleanup pass. You can hack around this by using metaclasses to set the descriptor on both the class and instances of the class separately as suggested here: https://stackoverflow.com/a/51278141

I had this metaclass in an earlier version of this commit, but I decided that my change was getting too complicated for a simple use case like this. Using a dict for the global case will work as long as you call clear to empty it, but future uses of class._shared_state = {} will behave differently for the thread-local descriptor (i.e., it will break the multithreaded case).

Let me know how you would like to handle it.
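The wart described above can be reproduced with a minimal stand-in; the class names here are illustrative, not the exact accelerate code.

```python
class SharedDictDescriptor:
    """Minimal data descriptor wrapping a dict, mimicking the pattern above."""
    def __init__(self):
        self._storage = {}

    def __get__(self, obj, objtype=None):
        return self._storage

    def __set__(self, obj, value):
        self._storage = value

class PartialState:
    _shared_state = SharedDictDescriptor()

s = PartialState()
s._shared_state["device"] = "xla:0"
# Instance and class access both go through __get__ and hit the same dict:
assert s._shared_state is PartialState._shared_state

# Instance assignment is routed through __set__, so the descriptor survives:
s._shared_state = {"device": "xla:1"}
assert isinstance(PartialState.__dict__["_shared_state"], SharedDictDescriptor)

# Class assignment, however, REPLACES the descriptor with a plain dict,
# silently dropping any thread-local behavior:
PartialState._shared_state = {}
assert not isinstance(PartialState.__dict__["_shared_state"], SharedDictDescriptor)

# The safe reset that keeps the descriptor intact would instead be:
# PartialState._shared_state.clear()
```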

Collaborator:

Leaving a dict and calling clear seems easier to me, as long as it works. We don't have to worry about the state being multithreaded usually (as it's just launched on different processes) so I believe this is the first time it's coming up.

will-cromar (Contributor, PR author):

Done.

Comment thread src/accelerate/state.py (×3)
Comment thread src/accelerate/test_utils/scripts/test_script.py Outdated
- set_seed(42)
+ generator.manual_seed(42)
- train_set = RegressionDataset(length=length)
+ train_set = RegressionDataset(length=length, seed=42)
Collaborator:

This shouldn't be necessary as everything was seeded before. If set_seed missed something, the fix should probably be in set_seed.

will-cromar (Contributor, PR author):

See #1385 (comment)

Each thread needs to make a Generator with a common seed, rather than using the global random state.

Comment on lines -24 to +27
- if seed is not None:
-     np.random.seed(seed)
+ rng = np.random.default_rng(seed)
  self.length = length
- self.x = np.random.normal(size=(length,)).astype(np.float32)
- self.y = a * self.x + b + np.random.normal(scale=0.1, size=(length,)).astype(np.float32)
+ self.x = rng.normal(size=(length,)).astype(np.float32)
+ self.y = a * self.x + b + rng.normal(scale=0.1, size=(length,)).astype(np.float32)
Collaborator:

This should be left as is.

will-cromar (Contributor, PR author):

This test actually can't use numpy's singleton RNG state with multithreading. Otherwise, the two threads will set the global seed at about the same time, then start concurrently pulling from that RNG. This will lead to inconsistent results between threads.

The only way to guarantee consistent results is to introduce a lock such that one thread updates the global seed and calls np.random.normal, then the next thread resets the seed and calls np.random.normal after the first is completely done. I used the Generator instead since it's simpler than introducing a lock.
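The Generator approach can be sketched as follows: each thread builds a private Generator from the common seed and never touches the global np.random state, so the draws are deterministic regardless of thread scheduling.

```python
import threading
import numpy as np

def make_data(results, idx, seed=42, length=4):
    # Private, per-thread RNG seeded identically in every "replica";
    # no shared global state is read or written.
    rng = np.random.default_rng(seed)
    results[idx] = rng.normal(size=length)

results = [None, None]
threads = [threading.Thread(target=make_data, args=(results, i)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Both threads produced identical draws; with np.random.seed followed by
# np.random.normal, this guarantee would require a lock around the whole
# seed-then-draw sequence.
assert np.array_equal(results[0], results[1])
```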

@will-cromar (Contributor, PR author) left a comment:

Thanks for taking a look!

I can break this down into 3 reviews:

  1. Suppress FileNotFound when deleting the test file. This is trivial and not directly related to XLA.
  2. Changes to get Accelerate working on TPU v4 and XLA:GPU (which don't use multithreading). These are straightforward refactors of XLA-specific code.
  3. Changes to support TPU v2 and v3 with multithreading. These change some common code paths and are more likely to need a rollback.

Unless I missed something, I don't think I added a dependency on anything introduced in PT/XLA 2.0. You can switch to PJRT by setting PJRT_DEVICE=TPU in the environment, which has no effect on versions from before PJRT was added. I did add a call to xm.collective_broadcast, which is in the 1.13 release branch.

Let me know if this works for you. I can re-request your review and take this PR out of draft when I'm done shuffling commits and fixing code style issues.

@will-cromar changed the title from "Fixes and updates for new PyTorch/XLA TPU runtime" to "Support TPU v2 and v3 on new PyTorch/XLA TPU runtime" on May 5, 2023
@will-cromar (Contributor, PR author):

(1) is already fixed on main (thanks!), so dropping that commit here.

Sent the commits to make TPU v4 work (2) in #1393. We can review the last set of changes (for TPU v2 and v3) in this PR since we've already started the conversation.

I don't think GitHub has a nice way for me to make this PR diff against #1393, since my changes are coming from a fork. The relevant changes here are in training.py, test_script.py, and state.py.

@will-cromar requested a review from sgugger on May 5, 2023 at 21:24
@will-cromar (Contributor, PR author):

Rebased after #1393 and re-tested with XRT and PJRT on TPU v2-8.

@will-cromar marked this pull request as ready for review on May 8, 2023 at 19:29
Comment on lines 23 to +27
  def __init__(self, a=2, b=3, length=64, seed=None):
- if seed is not None:
-     np.random.seed(seed)
+ rng = np.random.default_rng(seed)
  self.length = length
- self.x = np.random.normal(size=(length,)).astype(np.float32)
- self.y = a * self.x + b + np.random.normal(scale=0.1, size=(length,)).astype(np.float32)
+ self.x = rng.normal(size=(length,)).astype(np.float32)
+ self.y = a * self.x + b + rng.normal(scale=0.1, size=(length,)).astype(np.float32)
@muellerzr (Contributor) commented May 9, 2023:

Alternative suggestion here: what if we modified set_seed to take a default_rng bool flag, and when True it would call np.random.default_rng instead and return the resulting object? (The default of False would thus return None.)

Then, RegressionDataset checks if seed is an int or a generator, and then modifies accordingly towards what we want to do.

E.g. something like:

rng = np.random
if seed is not None:
    if isinstance(seed, int):
        np.random.seed(seed)
    else:
        rng = seed

wdyt @sgugger? This way we maintain the original behavior, and allow for this when needed (aka, only TPUs)
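A runnable sketch of the RegressionDataset side of this suggestion; the class name mirrors the test script, but the exact signature and behavior here are assumptions, not the code that was merged.

```python
import numpy as np

class RegressionDataset:
    # Sketch: accept either an int seed (legacy, process-global behavior)
    # or a np.random.Generator (the thread-safe TPU v2/v3 path).
    def __init__(self, a=2, b=3, length=64, seed=None):
        rng = np.random                # module-level fallback (original behavior)
        if seed is not None:
            if isinstance(seed, int):
                np.random.seed(seed)   # keep the old global-seed semantics
            else:
                rng = seed             # caller-provided Generator
        self.length = length
        self.x = rng.normal(size=(length,)).astype(np.float32)
        self.y = a * self.x + b + rng.normal(scale=0.1, size=(length,)).astype(np.float32)

# Two datasets built from equally seeded Generators match exactly, without
# touching the global RNG state at all.
d1 = RegressionDataset(length=8, seed=np.random.default_rng(0))
d2 = RegressionDataset(length=8, seed=np.random.default_rng(0))
assert np.array_equal(d1.x, d2.x) and np.array_equal(d1.y, d2.y)
```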

Collaborator:

That doesn't quite work, as set_seed sets multiple seeds. Why would it return a NumPy RNG and not a torch RNG?

@sgugger (Collaborator) left a comment:

Thanks for iterating on this, LGTM!

@sgugger sgugger merged commit d95d68e into huggingface:main May 9, 2023