
[DataPipe] Simple graph snapshotting#79479

Closed
NivekT wants to merge 21 commits into gh/nivekt/49/base from gh/nivekt/49/head

Conversation

@NivekT
Contributor

@NivekT NivekT commented Jun 13, 2022

Stack from ghstack (oldest at bottom):

This mostly completes the "poor man's snapshotting" implementation (named "simple snapshotting"). This is the most basic version of snapshotting, but it should work for all DataPipes. I will be adding more efficient implementations for different types of DataPipes in future PRs.

Implementation

The general idea of the simple snapshot is that we will:

  1. Create a new iterator
  2. Move that iterator forward by n_iterations
  3. Save that as the _fast_forward_iterator of the DataPipe
  4. The next time iter is called on the DataPipe, use the _fast_forward_iterator

Usage

As of this implementation, the usage will look something like:

```python
rng = torch.Generator()
initial_rng_state = rng.get_state()
datapipe: IterDataPipe = ...
# Some usage of the DataPipe, here yielding the first 5 values
n_iter = 5
it = iter(datapipe)
for _ in range(n_iter):
    next(it)
serialized_graph = pickle.dumps(datapipe)

# The serialized object carries most of the information needed for a simple snapshot (except the initial RNG state)
# It can be deserialized at a later point in time or by a different process
deserialized_graph = pickle.loads(serialized_graph)
# I think `DataLoader2` or `ReadingService` should store `initial_rng_state` so it can be saved by the API that we later use
rng_for_deserialized = torch.Generator()
rng_for_deserialized.set_state(initial_rng_state)
n_iterations = deserialized_graph._number_of_samples_yielded

_simple_snapshot_graph(deserialized_graph, n_iterations, rng=rng_for_deserialized)
# The whole DataPipe graph should have the same state as before serialization, such that:
self.assertEqual(list(it), list(deserialized_graph))  # True
```

Next Steps

If this looks acceptable, the next step is to modify `DataLoader2`'s prototype ReadingService (the one with queues) to remember things like `initial_rng_state` and to add a `save_snapshot` method that returns the `(serialized graph, initial_rng)` pair, plus a `restore_snapshot` method. This should work for single-worker data loading.

Note that, in the long term, `initial_rng_state` may not be necessary if we are able to directly save/restore the buffer and RNG state of `Shuffler` (that work is in progress). However, `initial_rng_state` and the simple snapshot are still a good fall-back option for edge cases where the buffer can't be stored.

Differential Revision: D37943406

@facebook-github-bot
Contributor

facebook-github-bot commented Jun 13, 2022

✅ No Failures (0 Pending)

As of commit 7604dfe (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI.

@NivekT NivekT marked this pull request as draft June 13, 2022 23:58
@NivekT NivekT added the module: data, release notes: dataloader, and topic: new features labels Jun 13, 2022
@NivekT NivekT changed the title from "[DataPipe] Fast-forwarding feature" to "[DataPipe] Simple graph fast-forwarding" Jun 16, 2022
@NivekT NivekT requested review from VitalyFedyunin and ejguan June 24, 2022 21:57
@NivekT NivekT marked this pull request as ready for review June 24, 2022 21:57
Contributor

@ejguan ejguan left a comment

This is the question I asked before: do we need to make DataPipe itself support fast-forwarding?
If we directly use DataLoader2 to save/restore the state, we can simply fast-forward the iterator of the DataPipe graph directly in DataLoader2.

@NivekT
Contributor Author

NivekT commented Jun 27, 2022

This is the question I asked before: do we need to make DataPipe itself support fast-forwarding? If we directly use DataLoader2 to save/restore the state, we can simply fast-forward the iterator of the DataPipe graph directly in DataLoader2.

I discuss part of the design here.

  1. For the most basic implementation like the one here, DataLoader2 can directly call the fast-forward function in this PR and get the iterator to the desired state. However, this requires re-doing all the computations.
  2. For more efficient implementations (in a future PR), DataLoader2 will still call a fast-forward API, but individual DataPipes will need custom fast-forward methods that allow efficient state restoration.
    • Notably, 1) stateless DataPipes do not need to do anything, 2) stateful ones may be able to restore from a serialized buffer, 3) source DataPipes (e.g. IterableWrapper) may still do a simple fast-forward

@ejguan
Contributor

ejguan commented Jun 27, 2022

I hesitate to expose fast-forwarding as an API on DataPipe for two reasons:

  • It makes us maintain two different APIs on DataPipe (one for snapshotting, the other for fast-forwarding). Besides, based on my understanding, the last DataPipe of the graph is used to determine the number of iterations to fast-forward. That is essentially the same approach as using DataLoader2 to determine the number of iterations, and I don't see the benefit of doing so, since we will remove fast-forwarding ultimately.
  • It is DataLoader2 that controls determinism, i.e. it creates the RNG per process and saves initial_random_state. Given your current design, it requires an RNG saved by DataLoader2. But exposing fast-forwarding on DataPipe would allow users to do fast-forwarding directly.

@NivekT
Contributor Author

NivekT commented Jun 27, 2022

I hesitate to expose fast-forwarding as an API on DataPipe for two reasons:

That is fine with me. We can definitely make this private and only have DataLoader2 call it.

maintain two different APIs on DataPipe (one for snapshotting, the other for fast-forwarding)

They serve the same purpose. I can rename fast-forward to "poor man's snapshotting" if that is better.

I also agree that the API should be controlled by DataLoader2 (especially to get the randomness-related states), unless there is user demand for this API to be called outside of it.

@NivekT
Contributor Author

NivekT commented Jul 18, 2022

@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@NivekT
Contributor Author

NivekT commented Jul 19, 2022

@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor Author

@NivekT NivekT left a comment

Updated this PR. I am expecting this to be final.

@NivekT
Contributor Author

NivekT commented Jul 20, 2022

@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@NivekT NivekT requested review from VitalyFedyunin and ejguan July 22, 2022 19:04
Contributor

@ejguan ejguan left a comment

LGTM based on the test cases

@NivekT
Contributor Author

NivekT commented Jul 22, 2022

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased gh/nivekt/49/orig onto refs/remotes/origin/master, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/79479)

pytorchmergebot pushed a commit that referenced this pull request Jul 22, 2022
ghstack-source-id: 6d3120b
Pull Request resolved: #79479

[DataPipe] Snapshotting with simple fast-forwarding

ghstack-source-id: 6d3120b
Pull Request resolved: #80250
@NivekT
Contributor Author

NivekT commented Jul 22, 2022

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@pytorchmergebot
Collaborator

Merge failed due to This PR has internal changes and must be landed via Phabricator
Raised by https://github.com/pytorch/pytorch/actions/runs/2721018437

@NivekT
Contributor Author

NivekT commented Jul 23, 2022

@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@NivekT
Contributor Author

NivekT commented Jul 23, 2022

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

facebook-github-bot pushed a commit that referenced this pull request Jul 26, 2022
Pull Request resolved: #79479
Approved by: https://github.com/ejguan

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/35d97e21c8e2e28fae4a4744ce42f1544972ba1f

Original Phabricator Test Plan:
Imported from OSS

Reviewed By: osalpekar, ejguan

Differential Revision: D37943406

Pulled By: NivekT

fbshipit-source-id: 7c50fd277a3d855a1b25d4c0f15face941d47a9d
@facebook-github-bot facebook-github-bot deleted the gh/nivekt/49/head branch July 26, 2022 14:20