Make DistributedSampler stateful #1315
Conversation
AI Store test can be safely ignored for now
andrewkho left a comment
Looks pretty good, but would like to simplify the code a bit and move the tests around as well
```python
ls[i].append(next(its[i]))
self.assertEqual(ls[0], ls[1])
```

```python
def test_initialization_StatefulDistributedSampler(self):
```
Let's move all of these tests out to a new file called test_sampler.py. You can update https://github.com/pytorch/data/blob/main/.github/workflows/stateful_dataloader_ci.yml to call it in an additional step
```python
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

dataset = self.dataset
sampler = StatefulDistributedSampler(dataset, num_replicas=10, rank=0, shuffle=False, seed=42, drop_last=False)
```
For testing state_dict, let's have most of the tests set up with passing sampler + dataset to StatefulDataLoader so we can test that it works end-to-end
You might need to use a dummy collate function to easily inspect elements; check the test_state_dict.py file for examples
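A minimal sketch of the dummy collate the reviewer suggests (the name `identity` and the `StatefulDataLoader` call shown in the comment are assumptions for illustration, not code from this PR):

```python
def identity(batch):
    # Return the samples untouched instead of stacking them into tensors,
    # so test assertions can compare raw dataset elements directly.
    return batch

# Assumed usage, mirroring the reviewer's suggestion:
# dl = StatefulDataLoader(dataset, batch_size=2, sampler=sampler, collate_fn=identity)
# Each batch would then be a plain list of dataset elements.
```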
```python
self.next_yielded = None

def __iter__(self):
```
Is it possible to fork the DistributedSampler.__iter__ code here and just update it, instead of having a separate Iterator class?
```python
if self.sampler.shuffle:
    # deterministically shuffle based on epoch and seed
    g = torch.Generator()
    g.manual_seed(self.sampler.seed + self.sampler.epoch)
    indices = torch.randperm(len(self.sampler.dataset), generator=g).tolist()  # type: ignore[arg-type]
else:
    indices = list(range(len(self.sampler.dataset)))  # type: ignore[arg-type]

if not self.sampler.drop_last:
    # add extra samples to make it evenly divisible
    padding_size = self.sampler.total_size - len(indices)
    if padding_size <= len(indices):
        indices += indices[:padding_size]
    else:
        indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
    # remove tail of data to make it evenly divisible.
    indices = indices[: self.sampler.total_size]
assert len(indices) == self.sampler.total_size

# subsample
indices = indices[self.sampler.rank : self.sampler.total_size : self.sampler.num_replicas]
assert len(indices) == self.sampler.num_samples

self.parent_iterator = iter(indices)
self.indices = list(self.parent_iterator)
self.current_index = 0
```
Is there a way to call the original code instead of forking it here?
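One way to follow this suggestion is to delegate to the parent's `__iter__` and only fast-forward on resume. The sketch below is self-contained, so it uses a minimal stand-in for `torch.utils.data.DistributedSampler` (the stand-in's behavior is simplified; the real class also shuffles and pads):

```python
class DistributedSampler:
    """Simplified stand-in for torch.utils.data.DistributedSampler."""

    def __init__(self, dataset, num_replicas, rank):
        self.dataset, self.num_replicas, self.rank = dataset, num_replicas, rank

    def __iter__(self):
        # Real version shuffles and pads before subsampling by rank.
        return iter(range(len(self.dataset))[self.rank :: self.num_replicas])


class StatefulDistributedSampler(DistributedSampler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.next_yielded = None  # set by load_state_dict on resume

    def __iter__(self):
        # Reuse the parent's index computation instead of forking it,
        # then skip anything that was already yielded before a checkpoint.
        indices = list(super().__iter__())
        if self.next_yielded is not None:
            indices = indices[self.next_yielded :]
            self.next_yielded = None
        return iter(indices)
```

This avoids a separate iterator class entirely: any fix to the upstream index logic is picked up automatically, and the subclass only owns the resume offset.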
```python
def state_dict(self) -> Dict[str, Any]:
    return self.sampler.state_dict()

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    self.sampler.load_state_dict(state_dict)
```
I don't think we need this both here and in the main sampler class; can we consolidate it in just one place?
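One shape the consolidation could take: the sampler alone owns `state_dict()`/`load_state_dict()`, and iteration just reads and writes the sampler's counters. This is a hypothetical sketch (the class and field names are illustrative, not the PR's actual code):

```python
from typing import Any, Dict


class StatefulSampler:
    """Sketch: only the sampler carries (de)serialization logic."""

    def __init__(self, n: int):
        self.n = n
        self.yielded = 0         # advanced as iteration proceeds
        self.next_yielded = None # resume offset set by load_state_dict

    def __iter__(self):
        start = self.next_yielded or 0
        self.next_yielded = None
        for i in range(start, self.n):
            self.yielded = i + 1  # record progress for checkpointing
            yield i

    def state_dict(self) -> Dict[str, Any]:
        return {"yielded": self.yielded}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.next_yielded = state_dict["yielded"]
```

With this layout a wrapping iterator (or dataloader) can simply forward to `self.sampler.state_dict()` without defining its own copy, so there is a single source of truth for the serialized state.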
andrewkho left a comment
Couple of suggestions, but looks great! Very nice test suite.
When you're done making changes, please run the fbcode CI for media_dataloader
Co-authored-by: Andrew Ho <andrewkh@meta.com>
@ramanishsingh has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This pull request was exported from Phabricator. Differential Revision: D61772177
Fixes #1269
Changes
- torchdata/stateful_dataloader/sampler.py: Added new classes StatefulDistributedSampler and _StatefulDistributedSamplerIterator
- test/stateful_dataloader/test_dataloader.py: new tests for StatefulDistributedSampler