[ReadingService] Add round robin sharding to support non-replicable DataPipe for Multiprocessing #919
ejguan wants to merge 18 commits into meta-pytorch:main
Conversation
```python
if not isinstance(response, communication.messages.ResetIteratorResponse):
    raise Exception("Invalid response received")

def get_response_reset_epoch(self, block=False):
```
Not sure why we didn't have this response before
Non-blocking, but - hmmm... did this cause any bug or unhandled messages? Do you happen to know why? I know you were looking into unusual messages and responses.
When a new iteration starts, reset will be called for `_IterateQueueDataPipes`. And, an extra `get_response_next` is invoked to drop the response. So, all requests are served.
https://github.com/pytorch/data/blob/0a0ae5d35be439c786c9ab798ea1596238d8640a/torchdata/dataloader2/reading_service.py#L152-L154
I can try to remove this part.
It seems removing it is fine.
noob question: are there any docs / entry pointers to understand how the dataloader2/communication design works?
Unfortunately no doc. I can talk about the components based on my understanding:
- `ProtocolClient` remains in the main process to pass `Request` via the `request_queue` to the corresponding worker process
- `ProtocolServer` is created in the worker process; it takes the request and sends `Response` back to the main process via the `response_queue`
- `DataPipeBehindQueues` is the worker loop that holds a `ProtocolServer` to manipulate the `DataPipe` based on the `Request`
- `QueueWrapper` is the `DataPipe` that holds a `ProtocolClient` instance to issue `Request`s and yield data from the `response_queue` to the subsequent `DataPipe` graph
We can talk about it in more detail offline if you want.
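To make the request/response flow concrete, here is a toy sketch of the pattern using bare multiprocessing queues and string messages; the real torchdata protocol uses `Request`/`Response` message classes and non-blocking variants, so this is only an illustration:

```python
import multiprocessing as mp

def server_loop(source, req_queue, res_queue):
    # ProtocolServer role: take a request off the queue, act on the
    # wrapped iterable, and send a response back.
    it = iter(source)
    while True:
        req = req_queue.get()
        if req == "Terminate":
            res_queue.put("Terminated")
            break
        try:  # assume req == "GetNext"
            res_queue.put(("Data", next(it)))
        except StopIteration:
            res_queue.put(("StopIteration", None))

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    req_q, res_q = ctx.Queue(), ctx.Queue()
    proc = ctx.Process(target=server_loop, args=(range(3), req_q, res_q))
    proc.start()
    # ProtocolClient role: issue a request, wait for the response
    req_q.put("GetNext")
    print(res_q.get())  # ("Data", 0)
    req_q.put("Terminate")
    proc.join()
```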
Let me try to draw it with mermaid:
```mermaid
graph TD;
    Worker_1_1-->ProtocolServer_1_1-->ProtocolClient_1
    Worker_1_2-->ProtocolServer_1_2-->ProtocolClient_1
    Worker_1_3-->ProtocolServer_1_3-->ProtocolClient_1
    ProtocolClient_1-->GPU1
    ProtocolClient_2-->GPU2
    ProtocolClient_3-->GPU3
```
@ejguan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
NivekT
left a comment
I am going to need more time to look through this. What is the use case of a non-shardable DataPipe?
NivekT
left a comment
I think we should provide a definition of a non-shardable DataPipe and explain why it may occur (an abstract example would be helpful as well). In particular, as you mentioned, note that we do not want to duplicate such a DataPipe into multiple workers, and the sharding filter should not be applied to it. Instead, it should be read round-robin (or something else) by downstream DataPipes that are sharded?
Makes sense. I will add docs regarding non-shardable/shardable DataPipes to the dataloader2 documentation.
Actually, this is not true. Edit: Updated the summary with a few topics that need to be covered. Let me know if there is any other concern about the documentation.
| """ | ||
| # Reset non-sharding process first | ||
| graph = traverse_dps(datapipe) | ||
| non_sharding_process_dps = find_dps(graph, communication.iter._IterateQueueDataPipes) |
Question: Is it always the case that `_IterateQueueDataPipes` is the only non-sharding process?
Yes, because we will find the lowest common ancestor of all non-shardable DataPipes in the main process and replace it with this `_IterateQueueDataPipes` in the worker process.
So, it's guaranteed that there is only a single non-sharding process.
@NivekT I propose changing it to find the lowest common ancestor of the non-replicable Data Sources and send them to the non-sharding process.
```mermaid
graph TD;
    DP1(non-replicable DP1)-->DP2;
    DP2-->DP5;
    DP3(non-replicable DP3)-->DP4;
    DP4-->DP5;
    DP5-->DP6;
    DP6-->fullsync;
    fullsync-->output;
```
The lowest common ancestor of all non-shardable Data Sources is `DP5`.
That makes sense. How do you plan to implement that? Will we still have users calling
We should allow either users calling
Replying to #919 :
An alternative approach for distributed sharding would be distribute the workload based on filename or some compression/encoding unit in file (in Parquet it's called "Page": https://parquet.apache.org/docs/concepts/) and in ORC I think it's called "Stripe". So it avoid reading the original data multiple times? |
Non-shardable is an extremely bad name, as the data is actually being sharded by round-robin dispatching. The actual meaning here is to prevent copying the DataPipe to multiple processes. I might rename it to non-replicable DataPipe/dispatching process.
I agree with renaming "non-shardable" to "non-replicable" DataPipe. I suppose sometimes it is replicable but the users won't want it to be?
@wenleix @NivekT This PR has been updated, and the updated document can be found at https://ejguan.github.io/dataloader2.html#dynamic-sharding
```python
# Lazily import to prevent circular import
from torchdata.dataloader2 import communication
```
This is my temporary fix for the circular import problem.
cc: @NivekT
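For reference, the general shape of the lazy-import workaround, with hypothetical modules `a` and `b` that depend on each other:

```python
# a.py (hypothetical) -- importing b at module load time would fail
# while b is itself mid-import of a, so the import is deferred:
def use_helper():
    from b import helper  # resolved at call time, breaking the cycle
    return helper()
```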
ejguan
left a comment
The following review steps are added to make it easier to review the main logic in this PR. Let me know if anything is unclear to you.
```python
@functional_datapipe("sharding_round_robin_dispatch")
class ShardingRoundRobinDispatcherIterDataPipe(IterDataPipe):
```
Review Step 1: `ShardingRoundRobinDispatcher` is introduced to indicate where the pipeline should become non-replicable (a usage sketch follows below).
I am open to any suggestion on the name/functional name.
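A hedged usage sketch; the pipeline shape and the `SHARDING_PRIORITIES` import path are assumptions for illustration, not taken from this diff:

```python
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.datapipes.iter import IterableWrapper

# Everything up to the dispatcher runs once in the dispatching process;
# its output is routed round-robin to worker processes, where the
# replicable tail of the pipeline runs as usual.
dp = IterableWrapper(["a.tar", "b.tar"])  # e.g. a source that must not be replicated
dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
dp = dp.map(lambda name: name.upper())   # replicable per-worker work
```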
```python
def __iter__(self) -> Iterator[T_co]:
    yield from self.source_datapipe
```
Review Step 1.1: Keep `__iter__` as a pass-through here rather than raising an error, to support the single-process use case.
noob question: by "single-process use case", do you mean "eager mode"?
Yeah, pure eager currently. In the future, we might provide a default SingleProcessReadingService for users.
```python
res_queue: Queue


def find_lca_non_replicable_dp(graph: DataPipeGraph) -> Optional[DataPipe]:
```
Review Step 2: Add this graph function to find the lowest common ancestor of the non-replicable DataPipes (`ShardingRoundRobinDispatcher`); see the sketch below.
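A minimal sketch of the LCA idea, assuming a tree-shaped `DataPipeGraph` of the form `{id(dp): (dp, subgraph_of_sources)}` as returned by `traverse_dps`; this is an illustration, not the PR's exact implementation:

```python
def find_lca_sketch(graph, targets):
    # `targets` is the non-empty set of non-replicable DataPipes. The
    # graph is rooted at the output DataPipe, so the LCA is the deepest
    # node whose subtree still contains every target.
    result = None

    def count_targets(dp, subgraph):
        nonlocal result
        total = 1 if dp in targets else 0
        for _, (src, sub) in subgraph.items():
            total += count_targets(src, sub)
        # Children report before their parent, so the first node that
        # covers all targets is the lowest common ancestor.
        if result is None and total == len(targets):
            result = dp
        return total

    for _, (dp, sub) in graph.items():
        count_targets(dp, sub)
    return result
```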
```python
graph = traverse_dps(end_dp)
return single_br_dp, multi_br_dp, ch1, ch2, fork_zip_dp, cir_br_dp, cir_map_dp, end_dp, graph


def test_single_non_replicable_dp(self):
```
Review Step 3.1: Tests for a single non-replicable DataPipe.
```python
graph, cir_map_dp = make_dp_non_replicable(graph, cir_map_dp)
self.assertEqual(find_lca_non_replicable_dp(graph), cir_map_dp)


def test_multi_non_replicable_dps(self):
```
Review Step 3.2: Tests for multiple non-replicable DataPipes.
```python
return process, req_queue, res_queue, new_datapipe


def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes):
```
Review Step 7.1: Create `num_workers` pairs of `req_queue` and `res_queue`, and launch `MultipleDataPipesToQueuesLoop` to iterate over the non-replicable DataPipe; see the wiring sketch below.
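A hedged sketch of the wiring, assuming a multiprocessing context `ctx`; the names and signatures are illustrative, not the PR's exact API:

```python
def dispatch_loop(datapipes, req_queues, res_queues):
    ...  # would serve requests round-robin, as MultipleDataPipesToQueuesLoop does

def create_dispatch_process(ctx, datapipes, num_workers):
    # One request/response queue pair per worker process; a single
    # dispatching process serves all of the pairs.
    req_queues = [ctx.Queue() for _ in range(num_workers)]
    res_queues = [ctx.Queue() for _ in range(num_workers)]
    process = ctx.Process(target=dispatch_loop, args=(datapipes, req_queues, res_queues))
    return process, req_queues, res_queues
```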
```python
]


def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call_on_process_init=None):
```
Review 7.2: Launch a non-blocking `DataPipeBehindQueues` while-loop per child DataPipe from `round_robin_demux`, using `zip_longest` to mimic round-robin calling of `next` over each child DataPipe (sketched below).
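A minimal sketch of the `zip_longest` driving pattern, assuming each element of `loops` is a non-blocking generator that serves at most one request per `next()` call; illustrative, not the PR's exact code:

```python
from itertools import zip_longest

def run_round_robin(loops):
    # Each zip_longest round advances every generator once, so every
    # child DataPipe serves one request before any serves a second;
    # exhausted generators are padded with None until all finish.
    for _ in zip_longest(*loops):
        pass
```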
```python
# Dispatching process for non-replicable DataPipes exists
if self._dispatch_process is not None:
    # Use the placeholder to pass request/response queues to each worker process
    dummy_dp.req_queue = self._dispatch_process[1][worker_id]
    dummy_dp.res_queue = self._dispatch_process[2][worker_id]
```
Review 6.2: We only have one `_DummyIterDataPipe` in the main process, but `num_workers` pairs of `req_queue` and `res_queue`. To connect a pair to the corresponding worker process, inject the attributes into `_DummyIterDataPipe` before sending it to the worker process.
```python
# Find if there is a non-replicable DataPipe
graph = traverse_dps(datapipe)
non_replicable_dp = find_dps(graph, _DummyIterDataPipe)  # type: ignore

# There are two cases for the DataPipe graph in terms of mp sharding:
# 1) All DataPipes are replicable, apply mp sharding to the whole graph
if len(non_replicable_dp) == 0:
    torch.utils.data.graph_settings.apply_sharding(
        datapipe, worker_info.num_workers, worker_info.worker_id, SHARDING_PRIORITIES.MULTIPROCESSING
    )
# 2) There is a non-replicable DataPipe. Since we have replaced the lowest common
#    ancestor by a `_DummyIterDataPipe`, we would only apply mp sharding
#    to replicable branches that don't have `_DummyIterDataPipe`.
else:
    assert len(non_replicable_dp) == 1
    replicable_branches = find_replicable_branches(graph)
    for dp in replicable_branches:
        torch.utils.data.graph_settings.apply_sharding(
            dp, worker_info.num_workers, worker_info.worker_id, SHARDING_PRIORITIES.MULTIPROCESSING
        )

    req_queue = non_replicable_dp[0].req_queue
    res_queue = non_replicable_dp[0].res_queue

    queue_wrapper = communication.iter.QueueWrapper(
        communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)
    )
    dispatch_process_dp = communication.iter._IterateQueueDataPipes([queue_wrapper])
    graph = replace_dp(graph, non_replicable_dp[0], dispatch_process_dp)
    datapipe = list(graph.values())[0][0]
```
Review Step 8: In the worker process, check whether a `_DummyIterDataPipe` is present.
If not, the whole pipeline is replicable and we shard it via the sharding filter.
If it is present, we apply sharding only over the replicable branches.
`QueueWrapper` and `_IterateQueueDataPipes` are used to wrap `res_queue` and `req_queue` as a DataPipe that can handle `Request` and `Response` based on the protocol.
```python
def dispatch_process_reset_fn(
    datapipe: DataPipe,
    worker_info: WorkerInfo,
    dist_info: _DistInfo,
) -> DataPipe:
    r"""
    Based on the distributed shared random seed, this function is used to set the random state
    of the non-replicable ``DataPipe`` graph and the global random states for the dispatch process.
    This function guarantees that all distributed non-sharding processes share the
    same random states to ensure the same shuffle order.
    """
    worker_seed_generator = torch.Generator()
    worker_seed_generator.manual_seed(dist_info.shared_seed)
    torch.utils.data.graph_settings.apply_random_seed(
        datapipe,
        worker_seed_generator,
    )

    # Set global random states
    _set_global_random_state(worker_seed_generator)

    return datapipe
```
Review Step 9: When a new epoch starts, we want to control the random seed based on distributed information.
We need to guarantee that all distributed dispatching processes share the same random seed.
wenleix
left a comment
"Review Step 1: `ShardingRoundRobinDispatcher` is introduced to indicate where the pipeline should be non-replicable."
LGTM. The name (`sharding_round_robin_dispatch`) is a bit long, but let's keep it for now...
```python
    res_queues
), "``MultipleDataPipesToQueuesLoop`` requires the same number of datapipes, request queues and response queues"

torch.set_num_threads(1)
```
noob question: what's this for?
IIRC, this was introduced to disable OpenMP parallelism in DataLoader workers. By default, OpenMP creates a number of threads equal to the number of CPU cores, so with multiprocessing enabled, `num_workers x num_threads_per_worker` threads would be created without providing any further benefit.
Besides, OpenMP can misbehave in subprocesses if OpenMP features were already used in the main process before the subprocesses are forked.
Any suggestion?
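For illustration, the same thread-capping technique in a classic DataLoader setting (a sketch, not this PR's code):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def worker_init_fn(worker_id: int) -> None:
    # Cap intra-op parallelism so the process tree ends up with roughly
    # num_workers threads instead of num_workers * num_cpu_cores.
    torch.set_num_threads(1)

ds = TensorDataset(torch.arange(8))
loader = DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)
```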
wenleix
left a comment
"Review Step 8: Worker process handling (graph rewrite and receiving demux result from dispatch process)"
LGTM.
```python
queue_wrapper = communication.iter.QueueWrapper(
    communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)
)
dispatch_process_dp = communication.iter._IterateQueueDataPipes([queue_wrapper])
```
IIUC: `IterDataPipeQueueProtocolClient` will be wrapped into a `QueueWrapper` (but still not an `IterDataPipe`), and further wrapped into an `_IterateQueueDataPipes`, which is an `IterDataPipe`?
Nope. Both `QueueWrapper` and `_IterateQueueDataPipes` are `IterDataPipe`s; this is one of the things that we can optimize later.
wenleix
left a comment
Review Step 9: LGTM % minor question...
@ejguan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```python
        yield response.value

    def reset(self):
        # NonBlocking DataPipes do not reset automatically, have to do it manually
```
@ejguan Just noticed that the reset method has changed after the move. It used to have this:
```python
# Collect all existing requests results to clear queues
for dp in self.datapipes:
    if dp.protocol.waiting_for_response():
        dp.protocol.get_response_next(block=True)
```
Is this no longer necessary?
I don't think it's necessary, because we always want to do `reset_epoch` as well for NonBlocking, which will discard all existing requests. So, at the point of `reset`, we should expect no requests within the worker process queues.

This PR is created on top of #555 and extends `PrototypeMultiprocessingReadingService` to accept non-replicable DataPipes. It also depends on pytorch/pytorch#90769.

Main Changes

- `ShardingRoundRobinDispatcher` (functional name `sharding_round_robin_dispatch`) to indicate a non-replicable DataPipe
- `MultipleDataPipesToQueuesLoop` to connect the non-sharding process to request/response queues
- `find_lca_non_replicable_dp` as a graph function to determine the lowest common ancestor of all non-replicable DataPipes. This guarantees that all non-replicable DataPipes run in a single dispatching process
- `find_replicable_branches` to apply mp sharding to the replicable branches, because all non-replicable branches have been properly sharded by routing data round-robin to worker processes
- `ResetEpochResponse` from the protocol via `get_response_reset_epoch`

Please check the link for the doc: https://ejguan.github.io/dataloader2.html#dynamic-sharding

nit Changes

- Rename `Spawn` to `Create` as the process has not been started