[DataPipes] Add group support to the sharding_filter #88424
VitalyFedyunin wants to merge 7 commits into gh/VitalyFedyunin/116/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88424
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 4a1181d. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ejguan left a comment:

LGTM with a few nit comments.
    def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter=None):
        self.source_datapipe = source_datapipe
        self.sharding_group_filter = sharding_group_filter
Do we need an extra API to set sharding_group_filter?
Based on the implementation, it seems sharding_group_filter is an integer. Could we change it to a set or list to support multiple filters?

I have no use-cases for it, but it will be trivial to change later if we need to.
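To illustrate what the suggested change would look like, here is a minimal, hypothetical sketch (the class and attribute names are stand-ins, not the PR's actual code) of a filter that accepts a single group id, a list, or a set, by normalizing the input into a set:

```python
class GroupFilterSketch:
    """Illustrative sketch: sharding_group_filter accepting int, list, or set."""

    def __init__(self, groups, sharding_group_filter=None):
        # `groups` maps group id -> payload; names here are illustrative.
        self.groups = groups
        if sharding_group_filter is None:
            self.filter = None
        elif isinstance(sharding_group_filter, (set, list, tuple)):
            # Normalize a collection of group ids into a set.
            self.filter = set(sharding_group_filter)
        else:
            # Normalize a single scalar id into a one-element set.
            self.filter = {sharding_group_filter}

    def selected(self):
        # Keep every group when no filter is set, otherwise only the matches.
        return [self.groups[k] for k in sorted(self.groups)
                if self.filter is None or k in self.filter]
```

With this shape, the existing single-integer call sites keep working while multi-group filtering becomes possible without a new API.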
    if self.sharding_group_filter is None:
        sorted_sharding_groups.append(self.groups[key])
    else:
        if key == self.sharding_group_filter:
            sorted_sharding_groups.append(self.groups[key])
nit:

    if self.sharding_group_filter is None or key == self.sharding_group_filter:
        sorted_sharding_groups.append(self.groups[key])

    def apply_sharding(self, num_of_instances, instance_id):
        self.num_of_instances = num_of_instances
        self.instance_id = instance_id

    def apply_sharding(self, num_of_instances, instance_id, sharding_group=SHARDING_PRIORITIES.DEFAULT):
Super nit: Could we add a validation that instance_id < num_of_instances?
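A minimal sketch of the suggested validation (the function body and error message are illustrative, not the PR's exact code): reject an `instance_id` that falls outside the valid range before recording the sharding settings.

```python
def apply_sharding(num_of_instances, instance_id):
    # Validate that instance_id addresses one of the declared instances.
    if not (0 <= instance_id < num_of_instances):
        raise ValueError(
            f"instance_id {instance_id} must be in [0, {num_of_instances})")
    return num_of_instances, instance_id
```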
    with self.assertRaises(Exception):
        dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT)
        dp.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING)
Could we add a separate self.assertRaises context for the second error? As written, the second call never runs if the first one raises.
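The suggestion can be sketched with a minimal unittest example (the DummyPipe and its behavior are illustrative stand-ins, not the PR's test fixture): each call that is expected to raise gets its own assertRaises context, so both calls are actually executed and checked.

```python
import unittest


class DummyPipe:
    """Illustrative pipe whose apply_sharding always raises, standing in
    for a pipe that has already been sharded."""

    def apply_sharding(self, num_of_instances, instance_id, sharding_group=0):
        raise RuntimeError("sharding already applied")


class ShardingTest(unittest.TestCase):
    def test_apply_sharding_raises(self):
        dp = DummyPipe()
        with self.assertRaises(RuntimeError):
            dp.apply_sharding(2, 1, sharding_group=0)
        # A second, separate context: without it, this call would be
        # skipped whenever the first call raised as expected.
        with self.assertRaises(RuntimeError):
            dp.apply_sharding(5, 3, sharding_group=1)
```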
@VitalyFedyunin has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Differential Revision: [D41006747](https://our.internmc.facebook.com/intern/diff/D41006747)
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Summary: After pytorch/pytorch#88424 landed, we are able to invoke `apply_sharding` by sharding level (distributed or multiprocessing). This gives `ReadingService` fine-grained control over sharding:

- For `DistributedReadingService`, we only set sharding at the distributed level.
- For `PrototypeMPReadingService`, we set distributed sharding in the main process and mp sharding in the worker processes. Previously, we set sharding in each worker process based on both distributed and mp information.
- `worker_init_fn` doesn't need `DistInfo` anymore, as the `DataPipe` has already been distributed-sharded in the main process.
- Combine `DistInfo` and `ExtraInfo` for `worker_reset_fn` to synchronize the distributed seeds across distributed workers and set worker-local seeds based on both distributed and mp information.

Pull Request resolved: #916
Reviewed By: mingyuzh
Differential Revision: D41776719
Pulled By: ejguan
fbshipit-source-id: 6042da09f5e83019d536696237028ea20e67d110
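The level-based scheme described in the summary can be sketched with a self-contained toy (the enum values and `MiniShardingFilter` are illustrative stand-ins, not torchdata's implementation): sharding is recorded once per group, and the groups are applied successively in priority order, so the main process can apply the distributed level and each worker can later apply the mp level.

```python
from enum import IntEnum


class SHARDING_PRIORITIES(IntEnum):
    # Illustrative values; only the relative ordering matters here.
    DEFAULT = 1
    DISTRIBUTED = 2
    MULTIPROCESSING = 3


class MiniShardingFilter:
    """Toy sharding filter: keeps elements that land on this instance
    at every registered sharding level."""

    def __init__(self, source):
        self.source = source
        self.groups = {}  # sharding_group -> (num_of_instances, instance_id)

    def apply_sharding(self, num_of_instances, instance_id,
                       sharding_group=SHARDING_PRIORITIES.DEFAULT):
        if not (0 <= instance_id < num_of_instances):
            raise ValueError("instance_id must be in [0, num_of_instances)")
        self.groups[sharding_group] = (num_of_instances, instance_id)

    def __iter__(self):
        for i, item in enumerate(self.source):
            keep, idx = True, i
            # Apply each level in priority order; the index is divided
            # down so the next level shards the surviving stream.
            for group in sorted(self.groups):
                n, rank = self.groups[group]
                if idx % n != rank:
                    keep = False
                    break
                idx //= n
            if keep:
                yield item
```

For example, with 12 elements, 2 distributed ranks, and 3 workers per rank, each (rank, worker) pair sees a disjoint 2-element slice, and the union of all six slices covers the full input.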
Differential Revision: [D41006747](https://our.internmc.facebook.com/intern/diff/D41006747)
Pull Request resolved: pytorch#88424
Approved by: https://github.com/ejguan
    raise RuntimeError('This implementation of sharding can be only applied once per instance of DataPipeline.',
                       'Already applied to', already_applied_to, 'while trying to apply to', pipe)
    pipe.apply_sharding(num_of_instances, instance_id)
    pipe.apply_sharding(num_of_instances, instance_id, sharding_group=sharding_group)
noob question: do is_shardable and apply_sharding only exist in ShardingFilterIterDataPipe?
Also, if there is no ShardingFilterIterDataPipe, it looks like no sharding will happen. Shall we error in that case? :)
cc @ejguan
Stack from ghstack (oldest at bottom):
Differential Revision: D41006747