[DataPipe] Simple graph snapshotting#79479
[DataPipe] Simple graph snapshotting#79479NivekT wants to merge 21 commits intogh/nivekt/49/basefrom
Conversation
[ghstack-poisoned]
🔗 Helpful links
✅ No Failures (0 Pending)As of commit 7604dfe (more details on the Dr. CI page): Expand to see more💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Please report bugs/suggestions to the (internal) Dr. CI Users group. |
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
ejguan
left a comment
There was a problem hiding this comment.
This is the question I asked before. Do we need to make DataPipe itself supporting fast forward.
If we directly use DataLoader2 to save/restore the state, we can simply fast forward the iterator of DataPipe graph directly in DataLoader2.
I discuss part of the design here.
|
|
I hesitate to expose fast-forwarding as the API at DataPipe for two reasons:
|
That is fine with me. We can definitely make this private and only have
They serve the same purpose. I can rename fast-forward as "poor man's snapshotting" if that is better. I also agree that the API should be controlled by |
This mostly completes the poor man's snapshotting implementation (named simple snapshot). This is the most basic version of snapshotting but it should work for all DataPipes. I will be adding more efficient implementation for different types of DataPipes in future PRs.
### Implementation
The general idea of the simple snapshot is that we will:
1. Create a new iterator
2. Move that iterator forward by `n_iterations`
3. Save that as the `_fast_forward_iterator` pf the DataPipe
4. The next time `iter` is called on the DataPipe, use the `_fast_forward_iterator`
### Usage
As of this implementation, the usage will something like:
```python
rng = torch.Generator()
initial_rng_state = rng.get_state()
datapipe: IterDataPipe = ...
# Some usage of the DataPipe, here maybe yielding the first 5 values
n_iter = 5
it = iter(datapipe)
for _ in range(n_iter):
next(it)
serialized_graph = pickle.dumps(datapipe)
# The serialized object has most of the sufficient information for simple snapshot (except for initial RNG state)
# It can be deserialized at a later point in time or by a different process
deserialized_graph = pickle.loads(serialized_graph)
# I think `DataLoader2` or `ReadingService` should store `initial_rng_state` that can be saved by the API that we later use
rng_for_deserialized = torch.Generator()
rng_for_deserialized.set_state(initial_rng_state)
n_iterations = deserialized_graph._number_of_samples_yielded
_simple_snapshot_graph(deserialized_graph, n_iterations, rng=rng_for_deserialized)
# The while DataPipe graph should have the same state as before serialization, such that:
self.assertEqual(list(it), list(deserialized_graph)) # True
```
### Next Steps
If this looks acceptable, the next step is I will modify `DataLoader2`'s prototype ReadingService (the one with queues) to remember things like `initial_rng_state` and to have methods `save_snapshot` that will return the `(serialized graph, initial_rng)` and `restore_snapshot`. This should work for single worker data loading.
Note that, in the long term, `initial_rng_state` may not be necessary if we are able to directly save/restore the buffer and RNG state of `Shuffler` (that is work in progress). However, `initial_rng_state` and simple snapshot is still a good fall-back option for some edge cases where the buffer can't be stored.
[ghstack-poisoned]
This mostly completes the "poor man's snapshotting" implementation (named "simple snapshotting"). This is the most basic version of snapshotting but it should work for all DataPipes. I will be adding more efficient implementation for different types of DataPipes in future PRs.
### Implementation
The general idea of the simple snapshot is that we will:
1. Create a new iterator
2. Move that iterator forward by `n_iterations`
3. Save that as the `_fast_forward_iterator` of the DataPipe
4. The next time `iter` is called on the DataPipe, use the `_fast_forward_iterator`
### Usage
As of this implementation, the usage will something like:
```python
rng = torch.Generator()
initial_rng_state = rng.get_state()
datapipe: IterDataPipe = ...
# Some usage of the DataPipe, here maybe yielding the first 5 values
n_iter = 5
it = iter(datapipe)
for _ in range(n_iter):
next(it)
serialized_graph = pickle.dumps(datapipe)
# The serialized object has most of the sufficient information for simple snapshot (except for initial RNG state)
# It can be deserialized at a later point in time or by a different process
deserialized_graph = pickle.loads(serialized_graph)
# I think `DataLoader2` or `ReadingService` should store `initial_rng_state` that can be saved by the API that we later use
rng_for_deserialized = torch.Generator()
rng_for_deserialized.set_state(initial_rng_state)
n_iterations = deserialized_graph._number_of_samples_yielded
_simple_snapshot_graph(deserialized_graph, n_iterations, rng=rng_for_deserialized)
# The while DataPipe graph should have the same state as before serialization, such that:
self.assertEqual(list(it), list(deserialized_graph)) # True
```
### Next Steps
If this looks acceptable, the next step is I will modify `DataLoader2`'s prototype ReadingService (the one with queues) to remember things like `initial_rng_state` and to have methods `save_snapshot` that will return the `(serialized graph, initial_rng)` and `restore_snapshot`. This should work for single worker data loading.
Note that, in the long term, `initial_rng_state` may not be necessary if we are able to directly save/restore the buffer and RNG state of `Shuffler` (that is work in progress). However, `initial_rng_state` and simple snapshot is still a good fall-back option for some edge cases where the buffer can't be stored.
[ghstack-poisoned]
|
@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
This mostly completes the "poor man's snapshotting" implementation (named "simple snapshotting"). This is the most basic version of snapshotting but it should work for all DataPipes. I will be adding more efficient implementation for different types of DataPipes in future PRs.
### Implementation
The general idea of the simple snapshot is that we will:
1. Create a new iterator
2. Move that iterator forward by `n_iterations`
3. Save that as the `_fast_forward_iterator` of the DataPipe
4. The next time `iter` is called on the DataPipe, use the `_fast_forward_iterator`
### Usage
As of this implementation, the usage will something like:
```python
rng = torch.Generator()
initial_rng_state = rng.get_state()
datapipe: IterDataPipe = ...
# Some usage of the DataPipe, here maybe yielding the first 5 values
n_iter = 5
it = iter(datapipe)
for _ in range(n_iter):
next(it)
serialized_graph = pickle.dumps(datapipe)
# The serialized object has most of the sufficient information for simple snapshot (except for initial RNG state)
# It can be deserialized at a later point in time or by a different process
deserialized_graph = pickle.loads(serialized_graph)
# I think `DataLoader2` or `ReadingService` should store `initial_rng_state` that can be saved by the API that we later use
rng_for_deserialized = torch.Generator()
rng_for_deserialized.set_state(initial_rng_state)
n_iterations = deserialized_graph._number_of_samples_yielded
_simple_snapshot_graph(deserialized_graph, n_iterations, rng=rng_for_deserialized)
# The while DataPipe graph should have the same state as before serialization, such that:
self.assertEqual(list(it), list(deserialized_graph)) # True
```
### Next Steps
If this looks acceptable, the next step is I will modify `DataLoader2`'s prototype ReadingService (the one with queues) to remember things like `initial_rng_state` and to have methods `save_snapshot` that will return the `(serialized graph, initial_rng)` and `restore_snapshot`. This should work for single worker data loading.
Note that, in the long term, `initial_rng_state` may not be necessary if we are able to directly save/restore the buffer and RNG state of `Shuffler` (that is work in progress). However, `initial_rng_state` and simple snapshot is still a good fall-back option for some edge cases where the buffer can't be stored.
Differential Revision: [D37943406](https://our.internmc.facebook.com/intern/diff/D37943406)
[ghstack-poisoned]
|
@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
NivekT
left a comment
There was a problem hiding this comment.
Updated this PR. I am expecting this to be final.
This mostly completes the "poor man's snapshotting" implementation (named "simple snapshotting"). This is the most basic version of snapshotting but it should work for all DataPipes. I will be adding more efficient implementation for different types of DataPipes in future PRs.
### Implementation
The general idea of the simple snapshot is that we will:
1. Create a new iterator
2. Move that iterator forward by `n_iterations`
3. Save that as the `_fast_forward_iterator` of the DataPipe
4. The next time `iter` is called on the DataPipe, use the `_fast_forward_iterator`
### Usage
As of this implementation, the usage will something like:
```python
rng = torch.Generator()
initial_rng_state = rng.get_state()
datapipe: IterDataPipe = ...
# Some usage of the DataPipe, here maybe yielding the first 5 values
n_iter = 5
it = iter(datapipe)
for _ in range(n_iter):
next(it)
serialized_graph = pickle.dumps(datapipe)
# The serialized object has most of the sufficient information for simple snapshot (except for initial RNG state)
# It can be deserialized at a later point in time or by a different process
deserialized_graph = pickle.loads(serialized_graph)
# I think `DataLoader2` or `ReadingService` should store `initial_rng_state` that can be saved by the API that we later use
rng_for_deserialized = torch.Generator()
rng_for_deserialized.set_state(initial_rng_state)
n_iterations = deserialized_graph._number_of_samples_yielded
_simple_snapshot_graph(deserialized_graph, n_iterations, rng=rng_for_deserialized)
# The while DataPipe graph should have the same state as before serialization, such that:
self.assertEqual(list(it), list(deserialized_graph)) # True
```
### Next Steps
If this looks acceptable, the next step is I will modify `DataLoader2`'s prototype ReadingService (the one with queues) to remember things like `initial_rng_state` and to have methods `save_snapshot` that will return the `(serialized graph, initial_rng)` and `restore_snapshot`. This should work for single worker data loading.
Note that, in the long term, `initial_rng_state` may not be necessary if we are able to directly save/restore the buffer and RNG state of `Shuffler` (that is work in progress). However, `initial_rng_state` and simple snapshot is still a good fall-back option for some edge cases where the buffer can't be stored.
Differential Revision: [D37943406](https://our.internmc.facebook.com/intern/diff/D37943406)
[ghstack-poisoned]
|
@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
ejguan
left a comment
There was a problem hiding this comment.
LGTM based on the test cases
|
@pytorchbot rebase |
|
@pytorchbot successfully started a rebase job. Check the current status here |
This mostly completes the "poor man's snapshotting" implementation (named "simple snapshotting"). This is the most basic version of snapshotting but it should work for all DataPipes. I will be adding more efficient implementation for different types of DataPipes in future PRs.
### Implementation
The general idea of the simple snapshot is that we will:
1. Create a new iterator
2. Move that iterator forward by `n_iterations`
3. Save that as the `_fast_forward_iterator` of the DataPipe
4. The next time `iter` is called on the DataPipe, use the `_fast_forward_iterator`
### Usage
As of this implementation, the usage will something like:
```python
rng = torch.Generator()
initial_rng_state = rng.get_state()
datapipe: IterDataPipe = ...
# Some usage of the DataPipe, here maybe yielding the first 5 values
n_iter = 5
it = iter(datapipe)
for _ in range(n_iter):
next(it)
serialized_graph = pickle.dumps(datapipe)
# The serialized object has most of the sufficient information for simple snapshot (except for initial RNG state)
# It can be deserialized at a later point in time or by a different process
deserialized_graph = pickle.loads(serialized_graph)
# I think `DataLoader2` or `ReadingService` should store `initial_rng_state` that can be saved by the API that we later use
rng_for_deserialized = torch.Generator()
rng_for_deserialized.set_state(initial_rng_state)
n_iterations = deserialized_graph._number_of_samples_yielded
_simple_snapshot_graph(deserialized_graph, n_iterations, rng=rng_for_deserialized)
# The while DataPipe graph should have the same state as before serialization, such that:
self.assertEqual(list(it), list(deserialized_graph)) # True
```
### Next Steps
If this looks acceptable, the next step is I will modify `DataLoader2`'s prototype ReadingService (the one with queues) to remember things like `initial_rng_state` and to have methods `save_snapshot` that will return the `(serialized graph, initial_rng)` and `restore_snapshot`. This should work for single worker data loading.
Note that, in the long term, `initial_rng_state` may not be necessary if we are able to directly save/restore the buffer and RNG state of `Shuffler` (that is work in progress). However, `initial_rng_state` and simple snapshot is still a good fall-back option for some edge cases where the buffer can't be stored.
Differential Revision: [D37943406](https://our.internmc.facebook.com/intern/diff/D37943406)
[ghstack-poisoned]
|
Successfully rebased |
|
@pytorchbot merge |
|
@pytorchbot successfully started a merge job. Check the current status here |
|
Merge failed due to This PR has internal changes and must be landed via Phabricator |
|
@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
|
@pytorchbot merge |
|
@pytorchbot successfully started a merge job. Check the current status here |
Summary:
This mostly completes the "poor man's snapshotting" implementation (named "simple snapshotting"). This is the most basic version of snapshotting but it should work for all DataPipes. I will be adding more efficient implementation for different types of DataPipes in future PRs.
### Implementation
The general idea of the simple snapshot is that we will:
1. Create a new iterator
2. Move that iterator forward by `n_iterations`
3. Save that as the `_fast_forward_iterator` of the DataPipe
4. The next time `iter` is called on the DataPipe, use the `_fast_forward_iterator`
### Usage
As of this implementation, the usage will something like:
```python
rng = torch.Generator()
initial_rng_state = rng.get_state()
datapipe: IterDataPipe = ...
# Some usage of the DataPipe, here maybe yielding the first 5 values
n_iter = 5
it = iter(datapipe)
for _ in range(n_iter):
next(it)
serialized_graph = pickle.dumps(datapipe)
# The serialized object has most of the sufficient information for simple snapshot (except for initial RNG state)
# It can be deserialized at a later point in time or by a different process
deserialized_graph = pickle.loads(serialized_graph)
# I think `DataLoader2` or `ReadingService` should store `initial_rng_state` that can be saved by the API that we later use
rng_for_deserialized = torch.Generator()
rng_for_deserialized.set_state(initial_rng_state)
n_iterations = deserialized_graph._number_of_samples_yielded
_simple_snapshot_graph(deserialized_graph, n_iterations, rng=rng_for_deserialized)
# The while DataPipe graph should have the same state as before serialization, such that:
self.assertEqual(list(it), list(deserialized_graph)) # True
```
### Next Steps
If this looks acceptable, the next step is I will modify `DataLoader2`'s prototype ReadingService (the one with queues) to remember things like `initial_rng_state` and to have methods `save_snapshot` that will return the `(serialized graph, initial_rng)` and `restore_snapshot`. This should work for single worker data loading.
Note that, in the long term, `initial_rng_state` may not be necessary if we are able to directly save/restore the buffer and RNG state of `Shuffler` (that is work in progress). However, `initial_rng_state` and simple snapshot is still a good fall-back option for some edge cases where the buffer can't be stored.
Pull Request resolved: #79479
Approved by: https://github.com/ejguan
Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/35d97e21c8e2e28fae4a4744ce42f1544972ba1f
Original Phabricator Test Plan:
Imported from OSS
Reviewed By: osalpekar, ejguan
Differential Revision: D37943406
Pulled By: NivekT
fbshipit-source-id: 7c50fd277a3d855a1b25d4c0f15face941d47a9d
Stack from ghstack (oldest at bottom):
This mostly completes the "poor man's snapshotting" implementation (named "simple snapshotting"). This is the most basic version of snapshotting but it should work for all DataPipes. I will be adding more efficient implementation for different types of DataPipes in future PRs.
Implementation
The general idea of the simple snapshot is that we will:
n_iterations_fast_forward_iteratorof the DataPipeiteris called on the DataPipe, use the_fast_forward_iteratorUsage
As of this implementation, the usage will something like:
Next Steps
If this looks acceptable, the next step is I will modify
DataLoader2's prototype ReadingService (the one with queues) to remember things likeinitial_rng_stateand to have methodssave_snapshotthat will return the(serialized graph, initial_rng)andrestore_snapshot. This should work for single worker data loading.Note that, in the long term,
initial_rng_statemay not be necessary if we are able to directly save/restore the buffer and RNG state ofShuffler(that is work in progress). However,initial_rng_stateand simple snapshot is still a good fall-back option for some edge cases where the buffer can't be stored.Differential Revision: D37943406