Support TPU v2 and v3 on new PyTorch/XLA TPU runtime #1385
sgugger merged 6 commits into huggingface:main
Conversation
@muellerzr We are trying to enable Accelerate with the new TPU VM and the PJRT runtime. Would someone on your team be able to review this change? We are working on deprecating the old XRT runtime and TPU node architecture, so it is important for us to keep the HF support going forward.
sgugger left a comment
Thanks for opening this PR! If possible, let's have it focused on TPUs only and leave the changes not related to that for other PRs, so it's easier to inspect potential regressions with git blame in the future.
Note that we use four spaces as indentation and not two.
From what I understand from your comments, a user will need to have PyTorch 2.0+ to use TPUs after this is merged? If so, let's raise a proper error when we detect an earlier version, telling the user to upgrade. What do you think?
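A sketch of the kind of guard being requested (placement, names, and the version threshold are assumptions, not part of the diff):

```python
import torch
from packaging import version

from accelerate.utils import is_tpu_available

# Hypothetical guard: fail fast on TPU if torch is too old, as suggested above.
if is_tpu_available() and version.parse(torch.__version__) < version.parse("2.0.0"):
    raise RuntimeError(
        "Using TPUs with accelerate requires PyTorch 2.0 or later, but found "
        f"{torch.__version__}. Please upgrade torch and torch_xla."
    )
```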
```python
class ThreadLocalSharedDict(GlobalSharedDict, threading.local):
    pass


SharedDict = GlobalSharedDict if not is_tpu_available(check_device=False) else ThreadLocalSharedDict
```
Let's not break the existing behavior and keep using a regular dict when not on TPU.
I wrote it this way because descriptors have slightly different semantics than just a dict on its own.
`PartialState(...)._shared_state` and `PartialState._shared_state` (instance vs. class access) give the same value: the underlying dict. Likewise, `PartialState(...)._shared_state = {...}` overrides the dict inside the descriptor as you would expect. However, `PartialState._shared_state = {}` actually replaces the descriptor object with a plain dict, which is then no longer thread-local. This is why I switched over to `clear()` below instead of assignment.
This is an obscure wart in Python, and I was planning to put a note in the comments when I took another cleanup pass. You can hack around it by using metaclasses to set the descriptor on both the class and instances of the class separately, as suggested here: https://stackoverflow.com/a/51278141
I had this metaclass in an earlier version of this commit, but I decided that my change was getting too complicated for a simple use case like this. Using a dict for the global case will work as long as you call `clear()` to empty it, but future uses of `class._shared_state = {}` will behave differently for the thread-local descriptor (i.e. they will break the multithreaded case).
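For illustration, a minimal standalone sketch of these semantics (simplified from the actual diff; `threading.local` re-runs `__init__` in each thread that touches the instance, which is what makes the storage per-thread):

```python
import threading

class ThreadLocalSharedDict(threading.local):
    """Descriptor holding a dict; each thread gets its own _storage."""
    def __init__(self):
        self._storage = {}

    def __get__(self, obj, objtype=None):
        return self._storage

    def __set__(self, obj, value):
        self._storage = value

class PartialState:
    _shared_state = ThreadLocalSharedDict()

PartialState()._shared_state["device"] = "xla:0"
print(PartialState._shared_state)  # {'device': 'xla:0'}: class access also goes through __get__

PartialState()._shared_state = {}  # instance assignment goes through __set__, descriptor survives
print(type(PartialState.__dict__["_shared_state"]).__name__)  # ThreadLocalSharedDict

PartialState._shared_state = {}    # class assignment REPLACES the descriptor with a plain dict
print(type(PartialState.__dict__["_shared_state"]).__name__)  # dict: no longer thread-local
```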
Let me know how you would like to handle it.
Leaving a dict and calling `clear()` seems easier to me, as long as it works. We usually don't have to worry about the state being multithreaded (as it's just launched on different processes), so I believe this is the first time it's coming up.
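Concretely (continuing the illustrative sketch above), the distinction is:

```python
# Works for both a plain dict and the descriptor: mutates in place.
PartialState._shared_state.clear()

# Only safe for a plain dict: on the descriptor version this silently
# swaps in a regular dict and drops thread-locality.
PartialState._shared_state = {}
```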
```diff
  set_seed(42)
  generator.manual_seed(42)
- train_set = RegressionDataset(length=length)
+ train_set = RegressionDataset(length=length, seed=42)
```
This shouldn't be necessary as everything was seeded before. If `set_seed` missed something, the fix should probably be in `set_seed`.
See #1385 (comment)
Each thread needs to make a Generator with a common seed, rather than using the global random state.
```diff
- if seed is not None:
-     np.random.seed(seed)
+ rng = np.random.default_rng(seed)
  self.length = length
- self.x = np.random.normal(size=(length,)).astype(np.float32)
- self.y = a * self.x + b + np.random.normal(scale=0.1, size=(length,)).astype(np.float32)
+ self.x = rng.normal(size=(length,)).astype(np.float32)
+ self.y = a * self.x + b + rng.normal(scale=0.1, size=(length,)).astype(np.float32)
```
This should be left as is.
This test actually can't use numpy's singleton RNG state with multithreading: the two threads would set the global seed at about the same time, then start concurrently pulling from that RNG, leading to inconsistent results between threads.
The only way to guarantee consistent results with the global state is to introduce a lock, such that one thread sets the global seed and calls `np.random.normal`, then the next thread resets the seed and calls `np.random.normal` after the first is completely done. I used a `Generator` instead since it's simpler than introducing a lock.
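A quick illustration of the thread-safe pattern (standalone sketch, not taken from the PR):

```python
import threading

import numpy as np

results = {}

def make_data(thread_id, seed=42, length=4):
    # Each thread constructs its own Generator from the same seed, so the
    # draws are identical per thread no matter how the threads interleave.
    rng = np.random.default_rng(seed)
    results[thread_id] = rng.normal(size=(length,))

threads = [threading.Thread(target=make_data, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert np.array_equal(results[0], results[1])  # both replicas built the same dataset
```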
will-cromar left a comment
Thanks for taking a look!
I can break this down into 3 reviews:

1. Suppress `FileNotFound` when deleting the test file. This is trivial and not directly related to XLA.
2. Changes to get Accelerate working on TPU v4 and XLA:GPU (which don't use multithreading). These are straightforward refactors of XLA-specific code.
3. Changes to support TPU v2 and v3 with multithreading. These change some common code paths and are more likely to need a rollback.
Unless I missed something, I don't think I added a dependency on anything introduced in PT/XLA 2.0. You can switch to PJRT by setting `PJRT_DEVICE=TPU` in the environment, which has no effect on versions from before PJRT was added. I did add a call to `xm.collective_broadcast`, which is in the 1.13 release branch.
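For reference, a minimal way to opt in (sketch; assumes a TPU VM with a PJRT-capable `torch_xla` installed):

```python
import os

# Select the PJRT runtime; releases that predate PJRT simply ignore this variable.
os.environ["PJRT_DEVICE"] = "TPU"

import torch_xla.core.xla_model as xm

print(xm.xla_device())  # e.g. xla:0
```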
Let me know if this works for you. I can re-request your review and take this PR out of draft when I'm done shuffling commits and fixing code style issues.
(1) is already fixed. Sent the commits to make TPU v4 work (2) in #1393. We can review the last set of changes (for TPU v2 and v3) in this PR since we've already started the conversation. I don't think GitHub has a nice way for me to make this PR diff against #1393, since my changes are coming from a fork. The relevant changes here are in

Rebased after #1393 and re-tested with XRT and PJRT on TPU v2-8.
```diff
  def __init__(self, a=2, b=3, length=64, seed=None):
-     if seed is not None:
-         np.random.seed(seed)
+     rng = np.random.default_rng(seed)
      self.length = length
-     self.x = np.random.normal(size=(length,)).astype(np.float32)
-     self.y = a * self.x + b + np.random.normal(scale=0.1, size=(length,)).astype(np.float32)
+     self.x = rng.normal(size=(length,)).astype(np.float32)
+     self.y = a * self.x + b + rng.normal(scale=0.1, size=(length,)).astype(np.float32)
```
Alternative suggestion here: what if we modified `set_seed` to give it a `default_rng` bool flag, and when True it will call `np.random.default_rng` instead and return the object (default False would thus return None)?
Then `RegressionDataset` checks if `seed` is an int or a generator, and modifies its behavior accordingly. E.g. something like:

```python
rng = np.random
if seed is not None:
    if isinstance(seed, int):
        np.random.seed(seed)
    else:
        rng = seed
```

wdyt @sgugger? This way we maintain the original behavior, and allow for this when needed (aka, only on TPUs).
That doesn't quite work, as `set_seed` sets multiple seeds. Why would it return a NumPy RNG and not a torch RNG?
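For context, `set_seed` roughly does the following (simplified sketch, not the exact accelerate source):

```python
import random

import numpy as np
import torch

def set_seed(seed: int):
    # Seeds Python's random module, NumPy's global RNG, and torch on
    # CPU and all CUDA devices; the real helper also seeds XLA devices.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
```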
sgugger left a comment
Thanks for iterating on this, LGTM!
Depends on #1393
I've been working on migrating PyTorch/XLA from our legacy XRT runtime to PJRT. We have detailed documentation on the differences and changes here: https://github.com/pytorch/xla/blob/master/docs/pjrt.md
Due to TPU design constraints, PJRT must use multithreading on TPU v2 and v3 (docs).
In this PR:

- Make the `_shared_state` dicts thread-local for `AcceleratorState`, since two replicas would otherwise end up sharing the same `accelerator.device` and `accelerator.process_index`. I implemented a descriptor that wraps the global `_shared_state` dicts in `state.py` and optionally makes it thread-local when using XLA.
- Use `Generator`s in the `accelerate test` script instead of the global `numpy.random.seed`. This makes the test thread-safe for TPU v2 and v3. Numpy recommends using `Generator`s as a best practice in their docs: https://numpy.org/doc/stable/reference/random/generated/numpy.random.seed.html

Tested:
- `accelerate test` on TPU v2-8 with XRT and PJRT