[dcp] add new checkpoint staging to preserve storage sharing and support mutable state_dicts #155192
teja-rao wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155192
Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 50b4627 with merge base 728cf67. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
This pull request was exported from Phabricator. Differential Revision: D75993324
While I like the idea of using deepcopy, which is very simple, do we know the implications of deepcopying a DTensor/DeviceMesh that contains a process_group? I have to admit I have never tried it before, and we should verify this (e.g., whether more PGs will be created). Not saying this will be an issue, but we should verify it.
Okay, DeviceMesh doesn't contain PGs. This shouldn't be a concern.
This doesn't match the current implementation. And are you going to implement the non_blocking feature?
It's also good to tag a core member to review the code change, as it touches the core part. cc @albanD
I will update the PR to pass non_blocking to copy_. I intend to support zero-copy.
FYI: We've had previous issues with deepcopy and had to rollback save plan caching feature because of deepcopy performance. More context: #149320
@MeetVadakkanchery There is no way to avoid deepcopy if you want the data to be modifiable while a checkpoint is uploading. The best way to mitigate the cost of deepcopy is to have a simpler state_dict object.
Another concern of deepcopy is that the
Yes, this is a good point. If the tensors contain non-serializable objects, this approach may cause issues, but the user also has the option to define deep_copy and customize the staging behavior, just as we are doing in tensor/storage.py, to define exactly how they want staging to work.
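As a toy illustration of that escape hatch (the class and field names here are hypothetical, not from the PR): an object holding non-serializable state can define `__deepcopy__` to control exactly what a deepcopy-based staging step produces.

```python
import copy

class CachedFeature:
    """Illustrative object with a live handle that cannot be copied/pickled."""

    def __init__(self, data, handle=None):
        self.data = data
        self.handle = handle  # pretend this is a socket, PG, or other live resource

    def __deepcopy__(self, memo):
        # Copy only the serializable payload; drop the live handle so that
        # staging never tries to clone a non-serializable object.
        return CachedFeature(copy.deepcopy(self.data, memo), handle=None)

obj = CachedFeature([1, 2, 3], handle=object())
staged = copy.deepcopy(obj)
assert staged.data == [1, 2, 3]
assert staged.handle is None
```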
n00b question: how is storage.storage_deepcopy() invoked?
If a state_dict has many non-tensor items, state_dict.deep_copy() can run into perf issues.
This may be worth verifying. Each DTensor has a DeviceMesh, which is a non-tensor item.
This is a bit outdated now that we changed the approach to mimic deepcopy instead of hooking into it.
Why not do this inside the pin_memory call directly since it is expected all callers will do that?
weakref needs the storage object to be passed in, which limits the API's applicability to storage objects alone. With the current API, the user can just allocate a raw memory chunk and pin it.
mikaylagawarecki left a comment:
only nits from me
nit: (only if helpful) I think you can get most of these checks from assertEqual (with the exception of the storage_offset check)
Here I meant assertEqual on the TestCase class, not torch.equal, but I now realize you define this function outside the test class, so you can ignore my previous comment; sorry about that.
pytorch/torch/testing/_internal/common_utils.py
Lines 4056 to 4071 in c74fd35
nit: remove print statements
nit: Should we add a test for non_blocking=True?
(I know this arg just forwards non_blocking to copy_, but since users will have to explicitly synchronize before using the CPU tensor in that case, a test might be helpful to demonstrate how this should be done.)
Good call on the test/example; will add it in a follow-up PR.
The reason for making it False by default is that non_blocking needs stream setup to be done properly, which is non-trivial.
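A minimal sketch of what such a test would demonstrate (the helper name `stage_to_cpu` is illustrative, not the PR's API): with `non_blocking=True` the copy may only be enqueued on the current CUDA stream, so the caller must synchronize before reading the CPU tensor.

```python
import torch

def stage_to_cpu(t: torch.Tensor, non_blocking: bool = False) -> torch.Tensor:
    """Copy a (possibly CUDA) tensor into (pinned, if CUDA) CPU memory.

    With non_blocking=True the copy may only be enqueued; the caller must
    synchronize before reading the result, which is why a blocking default
    is the safer choice.
    """
    use_cuda = t.is_cuda
    out = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=use_cuda)
    out.copy_(t, non_blocking=non_blocking)
    return out

device = "cuda" if torch.cuda.is_available() else "cpu"
src = torch.randn(4, device=device)
dst = stage_to_cpu(src, non_blocking=True)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # required: the async copy may not have finished
assert torch.equal(dst, src.cpu())
```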
…ort state_dict changes across checkpoints (pytorch#155192)

Summary:
Pull Request resolved: pytorch#155192

Update: the diff now avoids hooking into deepcopy by rolling out handwritten deepcopy-like scaffolding. There are some caveats:

1. Duplicated deepcopy code to hook into for tensors. There is a risk of this code getting outdated as Python versions change. This is needed to handle several different types, like NamedTuples, frozen dataclasses, and nested dataclasses; the deepcopy logic relies on `__reduce_ex__` to get a function with which these can be constructed.
2. Since we are bypassing deepcopy and adding custom logic to clone a tensor, we are missing some of the functionality that exists in `__deepcopy__` for torch.Tensor, like `_clear_non_serializable_cached_data()`, or other logic. I would like thoughts on which logic, or whether everything, should be copied.
3. If any object implements `__deepcopy__`, we will not be able to handle tensors in its attrs with this logic, because it likely just calls copy.deepcopy on the attrs instead of this deepcopy logic. We take care of subclasses of torch.Tensor to work around this.

This implements staging in a way that does not break checkpointing semantics. We want to stay close to torch.save/load semantics; when async checkpointing is used today, it breaks shared storages and does not handle custom objects or tensors well. E.g., a user passes a state_dict with a CUDA tensor in it; this is deep-cloned, causing the staging tensor to be created on GPU, which can cause OOMs and is hard to debug.

~~This diff hooks into deepcopy of storages to move them to CPU using the cached storages created for async checkpoint staging. This allows reusing storages created for staging, avoiding recreating them on each checkpoint, while also being flexible enough to handle any changes: clean up old storages or create new ones as needed.~~ This diff replicates deepcopy logic to clone the state_dict, with special handling for tensors to move them to CPU.

The lifetime of staging storages is tied to the original storage object: when the original storage object is gc-ed, we delete the corresponding staging storage from the cache, possibly causing it to be gc-ed if there are no other references. I am using the data_ptr of the storage to keep track of this. Please share thoughts on this. The alternative is to use FQNs instead of storage_id and verify that the underlying storage object has the same shape/size, etc., to make the caching logic work. The current implementation is much simpler and cleaner.

The API:

```
# construct a stager once per job in checkpointing.
stager = StateDictStager(pin_memory=pin_memory, share_memory=share_memory)

# do this on every checkpoint:
with staging_context(stager):
    cpu_state_dict = copy.deepcopy(state_dict)
```

Also adds support for pinned memory.

One problem this implementation does not address is that we lose the original device. The only alternative here is to pickle synchronously like torch.save, but with special handling for storages. It is valuable to keep the state_dict throughout the checkpointing process so users can manipulate and debug as needed, so we would need to unpickle in the background process. I think this is flexible, but not performant and not very different from the current solution, and it needs more code. One idea, if we really want to address this, is to store the original device in a variable on the storage and use it to recover the device on the load side. I think we do not need this for now and can be explicit about losing the device type for async checkpointing.

Test Plan: unit tests

TBD: once I get initial feedback, I am planning to hook this into checkpointing and test in some training jobs to validate across a variety of models.

Rollback Plan:

Reviewed By: mikaylagawarecki

Differential Revision: D75993324
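The "replicate deepcopy logic" approach above can be sketched as follows (illustrative only; the real stager handles far more cases, e.g. `__reduce_ex__` for NamedTuples and dataclasses, and caches staging storages): recurse like `copy.deepcopy`, preserve storage sharing through a memo table, but intercept tensors and offload them to CPU instead of cloning them on their original device.

```python
import copy

def stage(obj, memo=None):
    """Deepcopy-like staging sketch: copy everything, offload tensors to CPU."""
    if memo is None:
        memo = {}
    oid = id(obj)
    if oid in memo:              # preserve aliasing, like deepcopy's memo table
        return memo[oid]
    if hasattr(obj, "detach"):   # stand-in check for torch.Tensor
        out = obj.detach().to("cpu")   # offload instead of a plain clone
    elif isinstance(obj, dict):
        out = {}
        memo[oid] = out          # register before recursing to allow cycles
        for k, v in obj.items():
            out[stage(k, memo)] = stage(v, memo)
    elif isinstance(obj, (list, tuple, set)):
        out = type(obj)(stage(v, memo) for v in obj)
    else:
        out = copy.deepcopy(obj, memo)  # fall back for everything else
    memo[oid] = out
    return out

# Shared objects in the input stay shared in the staged copy:
d = {"a": [1, 2]}
d["b"] = d["a"]
staged = stage(d)
assert staged["a"] is staged["b"]
```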
@pytorchbot merge (initiating merge automatically since the Phabricator diff has merged)
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
```python
new_storage = type(storage)(storage.size(), device="cpu")

if self.pin_memory and new_storage.nbytes() > 0:
    pin_memory_utils.pin_memory(new_storage.data_ptr(), new_storage.nbytes())
```
@teja-rao Should this be

```python
if not new_storage.is_pinned():
    pin_memory_utils.pin_memory(new_storage.data_ptr(), new_storage.nbytes())
```

Otherwise, the following gets raised on the second checkpointing call in my training script:
```
self.async_wait()
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/carlos/titan/poolside/titan/checkpoint/fs_checkpoint.py", line 398, in async_wait
    self.save_future.result()
  File "/home/ubuntu/.local/share/uv/python/cpython-3.10.12-linux-x86_64-gnu/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/home/ubuntu/.local/share/uv/python/cpython-3.10.12-linux-x86_64-gnu/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/home/ubuntu/.local/share/uv/python/cpython-3.10.12-linux-x86_64-gnu/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_async_process_executor.py", line 283, in _execute_save_impl
    staging_future_or_state_dict.result()
  File "/home/ubuntu/.local/share/uv/python/cpython-3.10.12-linux-x86_64-gnu/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/home/ubuntu/.local/share/uv/python/cpython-3.10.12-linux-x86_64-gnu/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/home/ubuntu/.local/share/uv/python/cpython-3.10.12-linux-x86_64-gnu/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/staging.py", line 228, in _stage
    state_dict = self._state_dict_stager.stage(
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 168, in stage
    return self.deepcopy_with_tensor_offload(state_dict, non_blocking=non_blocking)
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 271, in deepcopy_with_tensor_offload
    y = copier(x, memo)
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 81, in _deepcopy_dict
    self.deepcopy_with_tensor_offload(value, memo)
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 271, in deepcopy_with_tensor_offload
    y = copier(x, memo)
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 81, in _deepcopy_dict
    self.deepcopy_with_tensor_offload(value, memo)
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 266, in deepcopy_with_tensor_offload
    y = self._offload_tensor(x, memo, non_blocking=non_blocking)
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 196, in _offload_tensor
    copied_storage = self._stage_untyped_storage(
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 147, in _stage_untyped_storage
    pin_memory_utils.pin_memory(new_storage.data_ptr(), new_storage.nbytes())
  File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/cuda/_pin_memory_utils.py", line 15, in pin_memory
    raise RuntimeError(
RuntimeError: Registering memory failed with cudaError: 712. It's possible that this is an asynchronous error raised from a previous cuda operation. Consider launching with CUDA_LAUNCH_BLOCKING=1 to debug.
```

I don't know why new_storage would be pinned on creation (some caching somewhere?). Also, I'm not sure if the following finalize should be done in that case.
Summary:
This implements staging in a way that does not break checkpointing semantics. We want to stay close to torch.save/load semantics; when async checkpointing is used today, it breaks shared storages and does not handle custom objects or tensors well. E.g., a user passes a state_dict with a CUDA tensor in it; this is deep-cloned, causing the staging tensor to be created on GPU, which can cause OOMs and is hard to debug.
This diff hooks into deepcopy of storages to move them to CPU using the cached storages created for async checkpoint staging. This allows reusing storages created for staging, avoiding recreating them on each checkpoint, while also being flexible enough to handle any changes: clean up old storages or create new ones as needed.
The lifetime of staging storages is tied to the original storage object: when the original storage object is gc-ed, we delete the corresponding staging storage from the cache, possibly causing it to be gc-ed if there are no other references. I am using the data_ptr of the storage to keep track of this. Please share thoughts on this.
The alternative is to use FQNs instead of storage_id and verify that the underlying storage object has the same shape/size, etc., to make the caching logic work. The current implementation is much simpler and cleaner.
The API:
Also adds support for pinned memory.
One problem this implementation does not address is that we lose the original device.
The only alternative here is to pickle synchronously like torch.save, but with special handling for storages. It is valuable to keep the state_dict throughout the checkpointing process so users can manipulate and debug as needed, so we would need to unpickle in the background process. I think this is flexible, but not performant and not very different from the current solution, and it needs more code. One idea, if we really want to address this, is to store the original device in a variable on the storage and use it to recover the device on the load side. I think we do not need this for now and can be explicit about losing the device type for async checkpointing.
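The "remember the original device" idea can be sketched like this (names `offload_with_device` and `restore` are illustrative, not the PR's API): pair each CPU copy with its source device string, so a load-side helper could move the tensor back.

```python
import torch

def offload_with_device(t: torch.Tensor) -> dict:
    """Keep the CPU copy together with the source device it came from."""
    return {"tensor": t.detach().to("cpu"), "device": str(t.device)}

def restore(entry: dict) -> torch.Tensor:
    """Load-side counterpart: move the staged tensor back to its device."""
    return entry["tensor"].to(entry["device"])

src = torch.arange(3)
entry = offload_with_device(src)
assert torch.equal(restore(entry), src)
```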
Update:
Note: due to reservations about hooking into deepcopy to customize it, the PR is now updated to use deepcopy-like logic to clone the state_dict. There are some caveats to this solution:
The new API:
Test Plan:
unit tests
Differential Revision: D75993324
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k