
[DataPipe] Snapshotting with simple fast-forwarding#80250

Closed
NivekT wants to merge 2 commits into gh/nivekt/54/base from gh/nivekt/54/head

Conversation

@NivekT
Contributor

@NivekT NivekT commented Jun 24, 2022

Stack from ghstack:

This mostly completes the poor man's snapshotting implementation (named simple fast-forward). It is the most basic version of snapshotting, but it should work for all DataPipes. I will be adding more efficient implementations for different types of DataPipes in future PRs.

As of this implementation, the usage will be something like:

import pickle

import torch
from torch.utils.data import IterDataPipe

rng = torch.Generator()
initial_rng_state = rng.get_state()
datapipe: IterDataPipe = ...
# Some usage of the DataPipe, here maybe yielding the first 5 values
n_iter = 5
it = iter(datapipe)
for _ in range(n_iter):
    next(it)

serialized_graph = pickle.dumps(datapipe)

# The serialized object has most of the information needed for simple fast-forward
# (except for the initial RNG state)
# It can be deserialized at a later point in time or by a different process
deserialized_graph = pickle.loads(serialized_graph)
# I think `DataLoader` should store `initial_rng_state` so that it can be saved by the API that we later use
rng_for_deserialized = torch.Generator()
rng_for_deserialized.set_state(initial_rng_state)
n_fastforward = deserialized_graph._number_of_samples_yielded
simple_fast_forward_graph(deserialized_graph, n_fastforward, rng=rng_for_deserialized)
# The whole DataPipe graph should have the same state as before serialization, such that:
self.assertEqual(list(it), list(deserialized_graph))  # True

If this looks acceptable, I can modify DataLoader2 to remember things like initial_rng_state and to have a save_snapshot method that returns the (serialized graph, initial RNG state) pair, along with a matching restore_snapshot method. This should work for single-worker data loading.
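To make the proposal above concrete, here is a hypothetical sketch (not the actual DataLoader2 API) of the shape that save_snapshot / restore_snapshot pair could take. Every name here (SnapshotMixin, its constructor arguments) is an illustrative assumption, not code from this PR:

```python
import pickle


class SnapshotMixin:
    """Hypothetical sketch: pairs a DataPipe graph with its initial RNG state."""

    def __init__(self, datapipe, initial_rng_state):
        self.datapipe = datapipe
        # Remembered up front, as the PR description proposes.
        self.initial_rng_state = initial_rng_state

    def save_snapshot(self):
        # Returns everything needed to fast-forward later: the pickled graph
        # (which records the number of samples yielded) plus the RNG state.
        return pickle.dumps(self.datapipe), self.initial_rng_state

    @classmethod
    def restore_snapshot(cls, snapshot):
        serialized_graph, initial_rng_state = snapshot
        graph = pickle.loads(serialized_graph)
        # A real implementation would call simple_fast_forward_graph here,
        # seeding a fresh RNG from initial_rng_state before replaying.
        return cls(graph, initial_rng_state)
```

The point of the design is that the snapshot is just two pieces of data, so it can be written to disk or sent to another process without any live iterator state.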

In the long term, initial_rng_state may not be necessary if we are able to directly save/restore the buffer and RNG state of Shuffler (that work is in progress). However, initial_rng_state with simple fast-forward remains a good fall-back for edge cases where the buffer can't be stored.
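The fall-back mechanism described above can be demonstrated end to end without PyTorch. This is a minimal sketch using only the standard library, with a toy shuffling generator standing in for a real Shuffler DataPipe; the function names are assumptions for illustration, not the PR's actual helpers:

```python
import pickle
import random


def shuffled_stream(data, rng):
    # Stand-in for a Shuffler DataPipe: deterministic given the RNG state.
    order = list(data)
    rng.shuffle(order)
    yield from order


def simple_fast_forward(it, n):
    # Replay the pipeline from scratch, discarding the first n samples.
    for _ in range(n):
        next(it)
    return it


data = list(range(10))

rng = random.Random()
initial_rng_state = rng.getstate()  # captured before any iteration

it = shuffled_stream(data, rng)
consumed = [next(it) for _ in range(5)]  # yield the first 5 values

# A snapshot is just (initial RNG state, number of samples yielded).
snapshot = pickle.dumps((initial_rng_state, len(consumed)))

# Later, possibly in a different process: rebuild and fast-forward.
saved_state, n_yielded = pickle.loads(snapshot)
rng_for_restored = random.Random()
rng_for_restored.setstate(saved_state)
restored = simple_fast_forward(shuffled_stream(data, rng_for_restored), n_yielded)

remaining_original = list(it)
remaining_restored = list(restored)
assert remaining_original == remaining_restored  # same state as before snapshot
```

Because the restored RNG replays the identical shuffle, discarding the first n_yielded samples lands the new iterator exactly where the old one left off, which is why this works even when the Shuffler's buffer itself can't be serialized.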

@facebook-github-bot
Contributor

facebook-github-bot commented Jun 24, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit 7cc63df (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.



@NivekT added the module: data (torch.utils.data), release notes: dataloader, and topic: new features labels Jun 24, 2022
@NivekT NivekT requested review from VitalyFedyunin and ejguan June 24, 2022 21:57
@NivekT
Contributor Author

NivekT commented Jun 27, 2022

Squashed with the other PR.

@NivekT NivekT closed this Jun 27, 2022
pytorchmergebot pushed a commit that referenced this pull request Jul 22, 2022
ghstack-source-id: 6d3120b
Pull Request resolved: #79479

[DataPipe] Snapshotting with simple fast-forwarding

ghstack-source-id: 6d3120b
Pull Request resolved: #80250
@facebook-github-bot facebook-github-bot deleted the gh/nivekt/54/head branch July 28, 2022 14:18