
Add worker_init_fn and worker_reset_fn to ProtoMPRS#907

Closed
ejguan wants to merge 4 commits intometa-pytorch:mainfrom
ejguan:reset_fn
Conversation

@ejguan
Contributor

@ejguan ejguan commented Nov 23, 2022

Changes

  • Add worker_init_fn and worker_reset_fn to PrototypeMultiprocessingReadingService. This should help SequentialRS extend worker_reset_fn per epoch (it still needs to be hacked by assigning worker_reset_fn manually)
  • Move _process_init_fn, _process_reset_fn, and the random number/tensor generation to torchdata.dataloader2.utils
  • Instead of using explicit arguments for init_fn or reset_fn, add DistInfo and WorkerInfo to pass relevant information to each worker. This should make it easier to extend those Info classes with extra information later
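The Info-class approach in the last bullet can be sketched as plain dataclasses. This is an illustrative sketch, not the actual torchdata definitions; the field names here (`world_size`, `rank`, `num_workers`, `worker_id`) follow the usage visible in the diff, and `my_init_fn` is a hypothetical user callback:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DistInfo:
    world_size: int  # number of distributed training processes
    rank: int        # rank of the current training process


@dataclass(frozen=True)
class WorkerInfo:
    num_workers: int  # data-loading workers per training process
    worker_id: int    # id of the current worker within its process


def my_init_fn(datapipe, worker_info: WorkerInfo, dist_info: DistInfo):
    # New fields can be added to the Info classes without changing this
    # signature, which is the extensibility benefit described above.
    print(f"worker {worker_info.worker_id}/{worker_info.num_workers} "
          f"on rank {dist_info.rank}/{dist_info.world_size}")
    return datapipe
```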

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 23, 2022
@facebook-github-bot
Contributor

@ejguan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


Contributor

@wenleix wenleix left a comment

Thanks!

) -> DataPipe:
global_worker_id = worker_info.worker_id * dist_info.world_size + dist_info.rank
total_num_workers = worker_info.num_workers * dist_info.world_size
torch.utils.data.graph_settings.apply_sharding(datapipe, total_num_workers, global_worker_id)
Contributor

curious what does this "apply_sharding" do? ~

Contributor Author

apply_sharding will set the sharding_filter DataPipe to yield every n-th element of the original DataPipe, where n is calculated from the instance id across distributed workers.
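A minimal sketch (not the torchdata implementation) of the behavior described here: instance k of n total instances keeps every n-th element, starting at offset k.

```python
def shard(data, num_instances, instance_id):
    # Round-robin sharding: keep elements whose index maps to this instance.
    return [x for i, x in enumerate(data) if i % num_instances == instance_id]


# world_size=2, num_workers=3 -> 6 total instances over 8 elements
shards = [shard(list(range(8)), 6, i) for i in range(6)]
# shards[0] == [0, 6], shards[1] == [1, 7], shards[2] == [2], ...
```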

protocol_type = communication.protocol.MapDataPipeQueueProtocolServer # type: ignore[assignment]
else:
raise Exception("Only supports IterDataPipe or MapDataPipe, got", source_datapipe)
if call_on_process_init is not None:
Contributor

@wenleix wenleix Nov 24, 2022

curious: why make call_on_process_init call later? ~

Contributor Author

I moved it to make sure the error is raised before call_on_process_init is invoked.

Otherwise, method should be 'fork' or 'spawn'.
prefetch_worker: (int, 10 by default): Number of data elements to be prefetched
    at the end of each worker process.
prefetch_mainloop: (int, 10 by default): Number of data elements to be prefetched
Contributor

curious: what's the relationship between prefetch_mainloop and prefetch_worker? is it expected that prefetch_mainloop == prefetch_worker * num_workers?

Contributor Author

prefetch_mainloop is a separate prefetch in the main process that pulls data from the mp.Queue. So, overall, the pipeline prefetches prefetch_mainloop + prefetch_worker * num_workers elements in total.
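The arithmetic above can be written out as a toy calculation (illustrative, not torchdata code):

```python
def total_prefetched(prefetch_mainloop, prefetch_worker, num_workers):
    # One buffer of prefetch_worker elements per worker process, plus one
    # buffer of prefetch_mainloop elements in the main process.
    return prefetch_mainloop + prefetch_worker * num_workers


# With the defaults of 10 each and 4 workers, up to 50 elements are in flight.
total = total_prefetched(10, 10, 4)
```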

if np_seed < 0:
np_seed = 2 ** 32 + np_seed
numpy.random.seed(np_seed)
self._dist_info = DistInfo(1, 0)
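The negative-seed wrap in the hunk above exists because numpy.random.seed only accepts seeds in the range [0, 2**32). A standalone sketch of that normalization (illustrative helper name):

```python
def to_numpy_seed(seed: int) -> int:
    # Map a possibly-negative seed into numpy's accepted range [0, 2**32).
    if seed < 0:
        seed = 2 ** 32 + seed
    return seed
```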
Contributor

do we plan to support world_size > 1 in the future? (understand it's a prototype now~)

Contributor Author

It does support world_size > 1. In some cases, like Lightning, the distributed environment is initialized after the DataLoader is constructed, so I had to move the distributed calls into __iter__.

self,
num_workers: int = 0,
multiprocessing_context=None,
multiprocessing_context: Optional[str] = None,
Contributor

is this multiprocessing_context expected to be one of the 'fork', 'spawn' and 'forkserver' ?

Contributor Author

Not always. On Windows, only spawn is supported. That's the reason I check the validity of mp_context using mp.get_all_start_methods.
See: https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_all_start_methods
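A sketch of that validity check (the helper name is illustrative, not the torchdata API); mp.get_all_start_methods returns only the start methods the current platform supports, e.g. just "spawn" on Windows:

```python
import multiprocessing as mp


def check_mp_context(method):
    # Reject start methods the current platform does not support.
    if method is not None and method not in mp.get_all_start_methods():
        raise ValueError(
            f"multiprocessing_context {method!r} is not supported on this "
            f"platform; choose from {mp.get_all_start_methods()}"
        )
    return method
```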

Contributor

Right, I mean it cannot be outside these three options~ (although not all of them are supported on every platform~)

Contributor Author

You are right

worker_seed_generator,
)
# Set different seeds across distributed workers
global_worker_id = worker_info.worker_id * dist_info.world_size + dist_info.rank
Contributor

If I understand correctly, say we have 2 distributed training processes (world_size=2) and 3 data-loading subprocesses per training process (num_workers=3), the global_worker_id assignment is like the following:

rank        0   1
------------------
worker0     0   1     
worker1     2   3
worker2     4   5

Essentially the global worker ids with rank 0 are (0, 2, 4) while the global worker ids with rank 1 are (1, 3, 5). Wondering why not make the global worker ids within the same rank continuous? ~

Contributor Author

Essentially the global worker ids with rank 0 are (0, 2, 4) while the global worker ids with rank 1 are (1, 3, 5). Wondering why not make the global worker ids within the same rank continuous? ~

Aha, you found a tricky problem we hit earlier. We can't make the ids within the same rank continuous because we want data to be sharded as evenly as possible. Using the same example with world_size=2 and num_workers=3: if the total number of elements is 8 and we assign continuous global worker ids per rank, rank 0 gets 5 elements while rank 1 gets only 3. With the interleaved global_worker_id in the current implementation, both ranks fetch 4 elements.
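The even-sharding argument can be checked numerically. This toy helper (illustrative names, not torchdata code) counts how many elements land on a given rank under round-robin sharding, for both id schemes:

```python
def rank_share(n_elems, world_size, num_workers, rank, interleaved=True):
    # Count elements assigned to this rank's workers under round-robin sharding.
    total = world_size * num_workers
    if interleaved:  # worker_id * world_size + rank (current implementation)
        ids = {w * world_size + rank for w in range(num_workers)}
    else:            # contiguous per rank: rank * num_workers + worker_id
        ids = {rank * num_workers + w for w in range(num_workers)}
    return sum(1 for i in range(n_elems) if i % total in ids)


# world_size=2, num_workers=3, 8 elements:
# interleaved -> 4 and 4; contiguous -> 5 and 3
```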

Contributor

@ejguan I see. Because the sharding mechanism is something like tuple_id % total_workers == global_worker_id? So when it's not divisible, the lower global_worker_ids get the "extra" rows, resulting in more rows in the low-rank workers?~

Contributor Author

I am not 100% sure that I understand.
We try to fill the low worker ids on all nodes first, going from low rank to high rank.

Comment on lines +243 to +245
_world_size = dist.get_world_size()
_rank = dist.get_rank()
self._dist_info = DistInfo(_world_size, _rank)
Contributor Author

cc: @wenleix Here is the place where distributed information is fetched lazily.
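A sketch of that lazy fetch (illustrative helper, assuming torch is available; it falls back to a single-process view when the distributed environment is absent or not yet initialized, which is what makes late initialization by frameworks like Lightning work):

```python
try:
    import torch.distributed as dist
except ImportError:  # allow running this sketch without torch installed
    dist = None


def get_dist_info():
    # Called lazily (e.g. inside __iter__) so the process group can be
    # initialized after the DataLoader is constructed.
    if dist is not None and dist.is_available() and dist.is_initialized():
        return dist.get_world_size(), dist.get_rank()
    return 1, 0  # (world_size, rank) for the non-distributed case
```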

@ejguan ejguan requested a review from NivekT November 28, 2022 15:23
shared_seed: int


def process_init_fn(
Contributor Author

We have to keep process_init_fn separate because apply_sharding can only be called once per sharding level. cc: @NivekT

@facebook-github-bot
Contributor

@ejguan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Comment on lines +213 to +214
worker_prefetch_cnt: int = 10,
main_prefetch_cnt: int = 10,
Contributor Author

Let me know if you have a different opinion on those names.

Contributor

@NivekT NivekT left a comment

I was wondering if we will need something similar for DistributedRS but it seems like it may not be necessary. And we can do that when needed.

@wenleix
Contributor

wenleix commented Dec 1, 2022

@NivekT

I was wondering if we will need something similar for DistributedRS but it seems like it may not be necessary. And we can do that when needed.

Agree. Looks like if we just want to have "co-located data reading and training" (a.k.a "onbox"), this multi-process reading service should be good enough~

By DistributedRS, do you mean "disaggregated data reading and training" (or at least hybrid)? One scenario where I can see this being useful is when we cannot saturate the trainer with just local CPUs.

@ejguan
Contributor Author

ejguan commented Dec 1, 2022

By DistributedRS, do you mean "disaggregated data reading and training" (or at least hybrid)? One scenario where I can see this being useful is when we cannot saturate the trainer with just local CPUs.

@wenleix
DistributedRS means this one. It's just on-trainer distributed, not disaggregated. Disaggregated requires more integration with customers' backend systems, over which we have zero control (I am not sure if we can work with TorchX to enable this from TorchData directly).
We currently move all distributed code into ProtoRS as well to support users' requests (vision). In the long term, we will rely on SequentialRS to combine ProtoMPRS and DistributedRS, so users can choose MP + Dist, just MP, or just Dist.

@ejguan ejguan added the topic: new feature topic category label Dec 13, 2022