
[dcp] add new checkpoint staging to preserve storage sharing and support mutable state_dicts#155192

Closed
teja-rao wants to merge 1 commit into pytorch:main from teja-rao:export-D75993324

Conversation


@teja-rao teja-rao commented Jun 5, 2025

Summary:
This implements staging in a way that doesn't break checkpointing semantics. We want to stay close to torch.save/load semantics, but when async checkpointing is used today it breaks shared storages and doesn't handle custom objects or tensors well. E.g.: a user passes a state_dict with a CUDA tensor inside a custom data type; this is deep-cloned, causing the staging tensor to be created on GPU. This can cause OOMs and is hard to debug.

This diff hooks into deepcopy of storages to move them to CPU, using the cached storages created for async checkpoint staging. This allows reusing the storages created for staging rather than recreating them on each checkpoint, while remaining flexible enough to handle any changes: old storages are cleaned up and new ones created as needed.

The lifetime of a staging storage is tied to the original storage object: when the original storage object is gc-ed, we delete the corresponding staging storage from the cache, possibly causing it to be gc-ed if there are no other references. I am using the data_ptr of the storage to keep track of this. Please share thoughts on this.
The alternative is to use FQNs instead of storage_id and verify that the underlying storage object has the same shape/size, etc., to make the caching logic work. The current implementation is much simpler and cleaner.

The API:

```python
# construct a stager once per job in checkpointing
stager = StateDictStager(pin_memory=pin_memory, share_memory=share_memory)

# do this on every checkpoint:
with staging_context(stager):
    cpu_state_dict = copy.deepcopy(state_dict)
```
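To see why routing staging through deepcopy preserves storage sharing, here is a small self-contained illustration of the memo mechanism deepcopy uses. It deliberately uses plain Python lists as stand-in storages rather than the PR's `StateDictStager`:

```python
import copy

# Two "tensors" sharing one underlying "storage" (a list here).
shared_storage = [0.0] * 4
state_dict = {"weight": shared_storage, "weight_alias": shared_storage}

staged = copy.deepcopy(state_dict)

# deepcopy's memo ensures the alias relationship survives the copy...
assert staged["weight"] is staged["weight_alias"]
# ...while the copy is fully decoupled from the original.
assert staged["weight"] is not state_dict["weight"]
```

A naive per-entry clone of each tensor would break this aliasing and double the staged memory for shared storages, which is exactly what the stager avoids.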

Also, adds support for pinned-memory.

One problem this implementation does not address is that we lose the original device.

The only alternative here is to pickle synchronously like torch.save, but with special handling for storages. It is valuable to keep the state_dict throughout the checkpointing process so users can manipulate and debug it as needed, which means we would need to unpickle in the background process. I think this is flexible but not performant, and not very different from the current solution while needing more code. One idea, if we really want to address this, is to stash the original device in a variable on the storage and use it to recover the device on the load side. I think we do not need this for now and can be explicit that async checkpointing loses the device type.

Update:
Note: Due to reservations about hooking into deepcopy to customize it, the PR is now updated to use deepcopy-like logic to clone the state_dict. There are some caveats to this solution:

  1. The deepcopy code is duplicated so we can hook in for tensors. There is a risk of this code getting outdated with Python version changes. This is needed to handle several different types like NamedTuples, frozen dataclasses, and nested dataclasses; the deepcopy logic relies on `__reduce_ex__` to get a function with which these can be constructed.
  2. Since we are bypassing deepcopy and adding custom logic to clone a tensor, we are missing some of the functionality that exists in `__deepcopy__` for torch.Tensor, like `_clear_non_serializable_cached_data()` and other logic. I would like thoughts on which logic, or whether everything, should be copied.
  3. If any object implements `__deepcopy__`, we will not be able to handle tensors in its attrs with this logic, because such objects likely call copy.deepcopy on their attrs instead of this deepcopy logic. We take care of subclasses of torch.Tensor to work around this.
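The `__reduce_ex__` mechanism from caveat 1 can be illustrated with a small, deliberately simplified sketch; a full implementation (like deepcopy's, or the PR's scaffolding) also has to apply the state, list items, and dict items that a complete reduce recipe can carry:

```python
from collections import namedtuple

Point = namedtuple("Point", ["x", "y"])

def clone_via_reduce(obj):
    # __reduce_ex__ returns a "recipe": a callable plus the args needed to
    # rebuild the object -- the same fallback deepcopy uses for NamedTuples,
    # frozen dataclasses, and similar types it has no special case for.
    rv = obj.__reduce_ex__(4)
    if isinstance(rv, str):  # module-level singletons reduce to a name
        return obj
    func, args = rv[0], rv[1]
    return func(*args)

p = Point(1, 2)
q = clone_via_reduce(p)
assert q == p and q is not p and isinstance(q, Point)
```

The caveat in the PR description follows directly: this recipe format is an implementation detail of CPython's copy/pickle machinery, so duplicated code that consumes it can drift out of date across Python versions.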

The new API:

```python
# construct a stager once per job in checkpointing
stager = StateDictStager(pin_memory=pin_memory, share_memory=share_memory)

# do this on every checkpoint:
cpu_state_dict = stager.stage(state_dict)
```

Test Plan:
unit tests

Differential Revision: D75993324

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k


pytorch-bot bot commented Jun 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155192

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 50b4627 with merge base 728cf67:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (checkpoint) labels Jun 5, 2025
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D75993324


@teja-rao teja-rao force-pushed the export-D75993324 branch from 20f8b37 to 715b2ee on June 5, 2025 05:22

@teja-rao teja-rao force-pushed the export-D75993324 branch from 715b2ee to f20d67f on June 5, 2025 05:26
Contributor

@fegin fegin left a comment


While I like the idea of using deepcopy, which is very simple, do we know the implications of deepcopying a DTensor/DeviceMesh that contains a process_group? I have to admit that I have never tried it before and we should verify this (e.g., whether more PGs will be created). Not saying this is necessarily an issue, but we should verify it.

Okay, DeviceMesh doesn't contain PGs. This shouldn't be a concern.

Comment on lines 39 to 42
Contributor


This doesn't match the current implementation. And are you going to implement the non_blocking feature?

Contributor


It's also good to tag the core member to review the code change as it touches the core part. cc., @albanD

Contributor Author

@teja-rao teja-rao Jun 5, 2025


I will update the PR to pass non_blocking to copy_. I intend to support zero-copy.

Contributor


FYI: We've had previous issues with deepcopy and had to rollback save plan caching feature because of deepcopy performance. More context: #149320

Contributor Author

@teja-rao teja-rao Jun 18, 2025


FYI: We've had previous issues with deepcopy and had to rollback save plan caching feature because of deepcopy performance. More context: #149320

@MeetVadakkanchery There is no way to avoid deepcopy if you want the data to be modifiable while a checkpoint is uploading. The best way to mitigate the cost of deepcopy is to have a simpler state_dict object.

@fegin
Contributor

fegin commented Jun 5, 2025

Another concern with deepcopy is that the state_dict may contain tensor subclasses other than DTensor, which I'm not sure will work well. I'm not an expert in this area. cc., @danielvegamyhre

@teja-rao
Contributor Author

teja-rao commented Jun 5, 2025

While I like the the idea of using deepcopy, which is very simple, do we know the implication of deepcopy a DTensor/DeviceMesh that contains a process_group? I have to admit that I have never tried it before and we should verify this (e.g., whether there will be more PGs created). Not saying this can be an issue, but we should verify it.

Okay, DeviceMesh doesn't contain PGs. This shouldn't be a concern.

Yes, this is a good point. If the tensors contain non-serializable objects this approach may cause issues, but the user also has the option to define `__deepcopy__` and customize the behavior for staging, just like we do in tensor/storage.py to define exactly how they want the staging to work.
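The escape hatch mentioned here can be sketched as follows. The class names (`Handle`, `ShardedThing`) are illustrative stand-ins for a user type holding a non-serializable resource such as a process group:

```python
import copy

class Handle:
    """Stand-in for a non-serializable resource (e.g. a process group)."""

class ShardedThing:
    def __init__(self):
        self.data = [1, 2, 3]
        self.handle = Handle()      # must NOT be cloned during staging

    def __deepcopy__(self, memo):
        clone = ShardedThing.__new__(ShardedThing)
        memo[id(self)] = clone      # register early so cycles resolve
        clone.data = copy.deepcopy(self.data, memo)
        clone.handle = self.handle  # share the resource, don't copy it
        return clone

obj = ShardedThing()
staged = copy.deepcopy(obj)
assert staged.data == obj.data and staged.data is not obj.data
assert staged.handle is obj.handle  # resource shared, not duplicated
```

Note the flip side called out in caveat 3 of the PR description: once a type owns its `__deepcopy__`, any tensors inside it are copied by that method's plain `copy.deepcopy` calls rather than the stager's tensor-offloading logic.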

@teja-rao teja-rao requested a review from albanD June 5, 2025 16:47
@teja-rao teja-rao force-pushed the export-D75993324 branch from f20d67f to 65b4c96 on June 5, 2025 17:24

@teja-rao teja-rao force-pushed the export-D75993324 branch from 65b4c96 to 96ceddf on June 5, 2025 17:28

@teja-rao teja-rao force-pushed the export-D75993324 branch from 96ceddf to 0b742d9 on June 5, 2025 17:33
Contributor


n00b question: how is storage.storage_deepcopy() invoked?

Contributor

@meetv18 meetv18 Jun 5, 2025


If a state_dict has many non-tensor items, state_dict.deep_copy() can run into perf issues.

Contributor


This may be worth verifying. Each DTensor will have a DeviceMesh, which is a non-tensor item.

Contributor Author


This is a bit outdated now that we changed the approach to mimic deepcopy instead of hooking into it.

Collaborator


Why not do this inside the pin_memory call directly since it is expected all callers will do that?

Contributor Author


weakref needs the storage object to be passed in, and that limits the API's applicability to storage objects alone. With the current APIs, the user can just allocate a raw memory chunk and pin it.

Contributor

@mikaylagawarecki mikaylagawarecki left a comment


only nits from me

Comment on lines 63 to 86
Contributor

@mikaylagawarecki mikaylagawarecki Jun 18, 2025


nit: (only if helpful) I think you can get most of these checks (with the exception of the storage_offset check) from assertEqual

Contributor

@mikaylagawarecki mikaylagawarecki Jun 18, 2025


Here I meant assertEqual on the test class, not torch.equal, but I now realize you define this function outside the test class, so you can ignore my previous comment, sorry about that:

```python
def assertEqual(
    self,
    x,
    y,
    msg: Optional[Union[str, Callable[[str], str]]] = None,
    *,
    atol: Optional[float] = None,
    rtol: Optional[float] = None,
    equal_nan=True,
    exact_dtype=True,
    # TODO: default this to True
    exact_device=False,
    exact_layout=False,
    exact_stride=False,
    exact_is_coalesced=False,
):
```

Comment on lines 825 to 826
Contributor


nit: remove print statements

Contributor


nit: Should we add a test for non_blocking=True?

(I know this arg is just forwarding non_blocking to copy_ but since users will have to explicitly synchronize the CPU in that case, a test might be helpful to demonstrate how this should be done)

Contributor Author


Good call on the test/example, will add it in a follow-up PR.

The reason for making it False by default is that non_blocking needs stream setup to be done properly, which is non-trivial.


…ort state_dict changes across checkpoints (pytorch#155192)

Summary:
Pull Request resolved: pytorch#155192


Test Plan:
unit tests

TBD: once I get initial feedback, I am planning to hook this into checkpointing and test it in some training jobs to validate across a variety of models.

Rollback Plan:

Reviewed By: mikaylagawarecki

Differential Revision: D75993324

@facebook-github-bot

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


```python
new_storage = type(storage)(storage.size(), device="cpu")

if self.pin_memory and new_storage.nbytes() > 0:
    pin_memory_utils.pin_memory(new_storage.data_ptr(), new_storage.nbytes())
```
Contributor


@teja-rao Should this be

```python
if not new_storage.is_pinned():
    pin_memory_utils.pin_memory(new_storage.data_ptr(), new_storage.nbytes())
```

Otherwise, the following gets raised on the second checkpointing call in my training script

      self.async_wait()
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
      return func(*args, **kwargs)
    File "/home/ubuntu/carlos/titan/poolside/titan/checkpoint/fs_checkpoint.py", line 398, in async_wait
      self.save_future.result()
    File "/home/ubuntu/.local/share/uv/python/cpython-3.10.12-linux-x86_64-gnu/lib/python3.10/concurrent/futures/_base.py", line 451, in result
      return self.__get_result()
    File "/home/ubuntu/.local/share/uv/python/cpython-3.10.12-linux-x86_64-gnu/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
      raise self._exception
    File "/home/ubuntu/.local/share/uv/python/cpython-3.10.12-linux-x86_64-gnu/lib/python3.10/concurrent/futures/thread.py", line 58, in run
      result = self.fn(*self.args, **self.kwargs)
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_async_process_executor.py", line 283, in _execute_save_impl
      staging_future_or_state_dict.result()
    File "/home/ubuntu/.local/share/uv/python/cpython-3.10.12-linux-x86_64-gnu/lib/python3.10/concurrent/futures/_base.py", line 451, in result
      return self.__get_result()
    File "/home/ubuntu/.local/share/uv/python/cpython-3.10.12-linux-x86_64-gnu/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
      raise self._exception
    File "/home/ubuntu/.local/share/uv/python/cpython-3.10.12-linux-x86_64-gnu/lib/python3.10/concurrent/futures/thread.py", line 58, in run
      result = self.fn(*self.args, **self.kwargs)
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/staging.py", line 228, in _stage
      state_dict = self._state_dict_stager.stage(
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
      return func(*args, **kwargs)
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 168, in stage
      return self.deepcopy_with_tensor_offload(state_dict, non_blocking=non_blocking)
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
      return func(*args, **kwargs)
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 271, in deepcopy_with_tensor_offload
      y = copier(x, memo)
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 81, in _deepcopy_dict
      self.deepcopy_with_tensor_offload(value, memo)
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
      return func(*args, **kwargs)
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 271, in deepcopy_with_tensor_offload
      y = copier(x, memo)
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 81, in _deepcopy_dict
      self.deepcopy_with_tensor_offload(value, memo)
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
      return func(*args, **kwargs)
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 266, in deepcopy_with_tensor_offload
      y = self._offload_tensor(x, memo, non_blocking=non_blocking)
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 196, in _offload_tensor
      copied_storage = self._stage_untyped_storage(
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/_state_dict_stager.py", line 147, in _stage_untyped_storage
      pin_memory_utils.pin_memory(new_storage.data_ptr(), new_storage.nbytes())
    File "/home/ubuntu/carlos/titan/.venv/lib/python3.10/site-packages/torch/cuda/_pin_memory_utils.py", line 15, in pin_memory
      raise RuntimeError(
  RuntimeError: Registering memory failed with cudaError: 712. It's possible that this is an asynchronous error raised from a previous cuda operation. Consider launching with CUDA_LAUNCH_BLOCKING=1 to debug.

I don't know why new_storage would be pinned on creation (some caching somewhere?). Also, I'm not sure whether the following finalize should be done in that case.


Labels

ciflow/trunk Trigger trunk jobs on your pull request fb-exported Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (checkpoint)


9 participants