Support skip_first_batches for XLA #2966
Conversation
muellerzr
left a comment
Thanks! Overall this looks fine, just one suggestion :)
src/accelerate/data_loader.py
Outdated
```python
        Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
        """
        is_xla_dataloader = False
        if is_torch_xla_available() and isinstance(dataloader, MpDeviceLoaderWrapper):
```
muellerzr
At this point I believe we can use PartialState().distributed_type == DistributedType.XLA
yitongh
OK, I changed it to AcceleratorState because that class has already been imported, and it aligns with how the prepare_data_loader method is used. I believe the distributed_type of these two states is shared.
muellerzr
PartialState is safer and better for these types of situations.
yitongh
Thank you for the correction. I have changed it to PartialState. I'm not very familiar with these states :)
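The reviewer's suggested check can be sketched as follows. Note that `DistributedType` and `PartialState` below are minimal hypothetical stand-ins for accelerate's real classes (the real `PartialState` detects the launch environment itself), so this illustrates the pattern rather than the library's actual code:

```python
from enum import Enum

# Hypothetical stand-ins for accelerate's DistributedType and PartialState,
# used only to illustrate the suggested check.
class DistributedType(Enum):
    NO = "NO"
    XLA = "XLA"

class PartialState:
    # The real PartialState detects the launch environment; this stub
    # just stores a distributed_type for demonstration.
    def __init__(self, distributed_type=DistributedType.NO):
        self.distributed_type = distributed_type

def running_under_xla(state):
    # Reviewer's suggestion: test the distributed type directly rather
    # than isinstance(dataloader, MpDeviceLoaderWrapper).
    return state.distributed_type == DistributedType.XLA

print(running_under_xla(PartialState(DistributedType.XLA)))  # True
print(running_under_xla(PartialState()))  # False
```

The advantage of checking the state rather than the dataloader type is that it does not require importing `MpDeviceLoaderWrapper` (and hence torch_xla) on non-XLA setups.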
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi @muellerzr, can you merge this PR? The test failure seems unrelated to this PR.
@yitongh yes indeed!
What does this PR do?
At present, using the resume_from_checkpoint feature in the Transformers Trainer results in an error, because `skip_first_batches` does not support XLA's `MpDeviceLoaderWrapper`. This PR adds support for it.

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
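For reference, the semantics that `skip_first_batches` provides when resuming can be sketched in plain Python. This is a conceptual illustration only; the real helper in accelerate also handles samplers, batch samplers, and device-loader wrappers such as the XLA one this PR adds:

```python
import itertools

def skip_first_batches_sketch(batches, num_batches):
    # Conceptually, resuming from a checkpoint replays the dataloader
    # but drops the first num_batches already-trained batches.
    return itertools.islice(iter(batches), num_batches, None)

# Resume mid-epoch: skip the 3 batches already consumed before the checkpoint.
print(list(skip_first_batches_sketch(range(10), 3)))  # [3, 4, 5, 6, 7, 8, 9]
```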
Who can review?
@muellerzr