Add worker_init_fn and worker_reset_fn to ProtoMPRS #907
ejguan wants to merge 4 commits into meta-pytorch:main
Conversation
@ejguan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```python
) -> DataPipe:
    global_worker_id = worker_info.worker_id * dist_info.world_size + dist_info.rank
    total_num_workers = worker_info.num_workers * dist_info.world_size
    torch.utils.data.graph_settings.apply_sharding(datapipe, total_num_workers, global_worker_id)
```
Curious: what does this `apply_sharding` do? ~
apply_sharding will set the sharding_filter DataPipe to yield every n-th element of the original DataPipe, where n is calculated based on the instance id across distributed workers.
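A minimal sketch of the round-robin behavior described here, under the assumption that each instance keeps every n-th element; the real implementation lives in torch.utils.data and operates on a sharding_filter DataPipe, not a plain generator:

```python
# Sketch only: conceptual round-robin sharding across worker instances.
def shard(elements, num_instances, instance_id):
    """Yield every num_instances-th element, starting at instance_id."""
    for i, elem in enumerate(elements):
        if i % num_instances == instance_id:
            yield elem

# With num_instances=3 and instance_id=1, this instance sees elements 1, 4, 7.
print(list(shard(range(10), 3, 1)))  # [1, 4, 7]
```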
```python
    protocol_type = communication.protocol.MapDataPipeQueueProtocolServer  # type: ignore[assignment]
else:
    raise Exception("Only supports IterDataPipe or MapDataPipe, got", source_datapipe)
if call_on_process_init is not None:
```
Curious: why is call_on_process_init invoked later? ~
I moved it to make sure the error is raised before call_on_process_init is invoked.
```python
Otherwise, method should be 'fork', 'spawn'.
prefetch_worker: (int, 10 by default): Number of data will be prefetched at
    the end of each worker process.
prefetch_mainloop: (int, 10 by default): Number of data will be prefetched
```
Curious: what's the relationship between prefetch_mainloop and prefetch_worker? Is it expected that prefetch_mainloop == prefetch_worker * num_workers?
prefetch_mainloop is another prefetch in the main process that fetches data from the mp.Queue. So, overall, the pipeline prefetches prefetch_mainloop + prefetch_worker * num_workers elements in total.
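A quick illustration of the buffering arithmetic just described (the variable names mirror the parameters above; the values are only the defaults mentioned in the docstring):

```python
# Each of num_workers worker processes keeps prefetch_worker elements
# buffered in its queue, and the main loop keeps prefetch_mainloop
# elements buffered on top of that.
num_workers = 3
prefetch_worker = 10
prefetch_mainloop = 10

total_prefetched = prefetch_mainloop + prefetch_worker * num_workers
print(total_prefetched)  # 40
```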
```python
if np_seed < 0:
    np_seed = 2 ** 32 + np_seed
numpy.random.seed(np_seed)
self._dist_info = DistInfo(1, 0)
```
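For context on the snippet above: numpy.random.seed only accepts seeds in [0, 2**32), so a negative seed (e.g. one drawn from a signed torch generator) is shifted into range by adding 2**32. A small standalone sketch of that mapping (the helper name is illustrative, not from the PR):

```python
# Map a possibly-negative seed into numpy's accepted range [0, 2**32).
def to_numpy_seed(seed):
    if seed < 0:
        seed = 2 ** 32 + seed
    return seed

print(to_numpy_seed(-1))         # 4294967295
print(to_numpy_seed(123456789))  # 123456789
```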
Do we plan to support world_size > 1 in the future? (I understand it's a prototype for now~)
It does support world_size > 1. And, in some cases like Lightning, the distributed environment is initialized after the DataLoader is constructed. So, I have to make sure the distributed call is moved to __iter__.
```diff
    self,
    num_workers: int = 0,
-   multiprocessing_context=None,
+   multiprocessing_context: Optional[str] = None,
```
Is this multiprocessing_context expected to be one of 'fork', 'spawn', and 'forkserver'?
Not always. On Windows, only spawn is supported. That's the reason I checked the validity of mp_context using mp.get_all_start_methods.
See: https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_all_start_methods
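A sketch of validating a user-supplied start method against what the current platform actually supports, using the documented mp.get_all_start_methods API; the helper name is illustrative, not the PR's exact code:

```python
import multiprocessing as mp

def check_mp_context(ctx):
    # On Linux this is typically ['fork', 'spawn', 'forkserver'];
    # on Windows it is just ['spawn'].
    valid = mp.get_all_start_methods()
    if ctx is not None and ctx not in valid:
        raise ValueError(
            f"multiprocessing_context {ctx!r} is not supported on this "
            f"platform; choose one of {valid}"
        )
    return ctx

check_mp_context("spawn")  # 'spawn' is available on every platform
```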
Right, I mean it cannot be outside of these three options~ (although not all of them are supported on every platform~)
```python
    worker_seed_generator,
)
# Set different seeds across distributed workers
global_worker_id = worker_info.worker_id * dist_info.world_size + dist_info.rank
```
If I understand correctly, say we have 2 distributed training processes (world_size=2) and 3 data-loading subprocesses per training process (num_workers=3). The global_worker_id assignment is like the following:

    rank      0   1
    -----------------
    worker0   0   1
    worker1   2   3
    worker2   4   5

Essentially the global worker ids with rank 0 are (0, 2, 4) while the global worker ids with rank 1 are (1, 3, 5). Wondering why not make the global worker ids within the same rank continuous? ~
> Essentially the global worker ids with rank 0 are (0, 2, 4) while the global worker ids with rank 1 are (1, 3, 5). Wondering why not make the global worker ids within the same rank continuous? ~
Aha, you found a tricky problem we ran into earlier. We can't make ids within the same rank continuous because we want the data sharded as evenly as possible. Using the same example with world_size=2 and num_workers=3: if the total number of elements is 8 and we assign continuous global worker ids per rank, rank 0 would receive 5 elements but rank 1 only 3. With the current global_worker_id assignment, both ranks fetch 4 elements each.
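This imbalance can be checked with a small simulation, assuming element i is routed to worker i % total_num_workers (the round-robin sharding discussed earlier); the helper and lambdas below are illustrative, not code from the PR:

```python
from collections import Counter

world_size, num_workers, n_elems = 2, 3, 8
total = world_size * num_workers  # 6 workers overall

def per_rank_counts(rank_of_worker):
    """Count how many of n_elems elements land on each rank."""
    counts = Counter()
    for i in range(n_elems):
        counts[rank_of_worker(i % total)] += 1
    return dict(counts)

# Contiguous ids per rank: ids 0-2 belong to rank 0, ids 3-5 to rank 1.
contiguous = per_rank_counts(lambda wid: wid // num_workers)
# Interleaved ids (this PR): id = worker_id * world_size + rank,
# so rank = id % world_size.
interleaved = per_rank_counts(lambda wid: wid % world_size)

print(contiguous)   # {0: 5, 1: 3} -- imbalanced
print(interleaved)  # {0: 4, 1: 4} -- balanced
```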
@ejguan I see. Because the sharding mechanism is something like tuple_id % total_num_workers == global_worker_id? So when the total isn't divisible, lower global_worker_ids get the extra rows, resulting in more rows for low-rank workers?~
I am not 100% sure that I understand. We try to fill the low worker ids across all nodes, going from low rank to high rank.
```python
_world_size = dist.get_world_size()
_rank = dist.get_rank()
self._dist_info = DistInfo(_world_size, _rank)
```
cc @wenleix: here is the place where distributed information is fetched lazily.
```python
shared_seed: int


def process_init_fn(
```
We have to keep process_init_fn separate because apply_sharding can be called once per sharding level. cc @NivekT
```python
worker_prefetch_cnt: int = 10,
main_prefetch_cnt: int = 10,
```
Let me know if you have a different opinion on those names.
NivekT left a comment
I was wondering if we will need something similar for DistributedRS, but it seems like it may not be necessary. We can do that when needed.
Agree. Looks like if we just want to have "co-located data reading and training" (a.k.a. "onbox"), this multi-process reading service should be good enough~
@wenleix
Changes
- Add worker_init_fn and worker_reset_fn to PrototypeMultiprocessingReadingService. This should help SequentialRS to extend worker_reset_fn per epoch (still need to hack by assigning worker_reset_fn manually)
- Add _process_init_fn and _process_reset_fn and move random number/tensor generation to torchdata.dataloader2.utils
- For init_fn or reset_fn, add DistInfo and WorkerInfo to pass relevant information to workers. This should help extensibility to add extra information to those Info classes