FSDP state_dict transformations of modules with persistent buffers fail with mixed precision enabled #93391
Description
🐛 Describe the bug
Firstly, thank you to the PyTorch Distributed team for your invaluable contributions to the PyTorch ecosystem; your work is immensely impressive and inspiring!
In preparing the downstream package I maintain (finetuning-scheduler) to support PyTorch 2.0's version of FSDP, I noticed that modules with multiple persistent buffers were not having their state properly transformed when saving `state_dict`s.

The issue was that the post-state_dict hook codepath shared by the `FULL_STATE_DICT` and `SHARDED_STATE_DICT` `_state_dict_type`s (`_common_unshard_post_state_dict_hook`) was inadvertently referencing a local variable (`buffer`) left over from a prior transformation, instead of the `buffers` variable that should have been referenced in the iteration context:
pytorch/torch/distributed/fsdp/_state_dict_utils.py, lines 251 to 253 at 332d55d:

```python
for buffers, clean_fqn in zip(buffers, buffer_clean_fqns):
    fqn = f"{prefix}{clean_fqn}"
    state_dict[fqn] = buffer.clone()
```
Modules with a single persistent buffer, or without mixed precision enabled, are unaffected. With multiple buffers and mixed precision enabled, however, the issue may appear stochastically, in proportion to the ratio of persistent buffers that have compatible dimensions (since the value of the last buffer visited in the `buffer_names` set is copied to all buffers, and the set's iteration order will of course vary).
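The failure mode can be illustrated with plain Python objects standing in for tensors (the class and helper names below are hypothetical, not the actual FSDP code): because the loop body references the stale `buffer` local rather than the loop variable, every FQN receives a clone of the last buffer visited by the earlier pass.

```python
# Minimal sketch of the shadowing bug; FakeTensor is a hypothetical
# stand-in for torch.Tensor, carrying just a shape.
class FakeTensor:
    def __init__(self, shape):
        self.shape = shape

    def clone(self):
        return FakeTensor(self.shape)


def buggy_post_state_dict(buffers, buffer_clean_fqns, prefix):
    state_dict = {}
    # An earlier transformation pass leaves a local named `buffer` behind,
    # bound to the last buffer it visited:
    for buffer in buffers:
        pass
    # The buggy loop rebinds `buffers` as its loop variable, but the body
    # still references the stale `buffer` from above, so every FQN gets a
    # clone of the same (last-visited) buffer.
    for buffers, clean_fqn in zip(buffers, buffer_clean_fqns):
        state_dict[f"{prefix}{clean_fqn}"] = buffer.clone()
    return state_dict


sd = buggy_post_state_dict(
    [FakeTensor((10,)), FakeTensor(())],  # e.g. running_mean, num_batches_tracked
    ["running_mean", "num_batches_tracked"],
    "_fsdp_wrapped_module.",
)
# Both entries end up with the scalar shape () of the last buffer, which is
# exactly the size mismatch reported when the state_dict is loaded back.
```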
On load, this manifests as:

```
File ".../pytorch/torch/nn/modules/module.py", line 2028, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for FullyShardedDataParallel:
	size mismatch for _fsdp_wrapped_module.1._fsdp_wrapped_module.running_mean: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([10]).
```

To both address this issue and guard against similar regressions, I'll be opening a PR momentarily that fixes this typo and adds a basic set of tests validating `state_dict` saving and loading for modules with persistent buffers in various contexts.
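For reference, the expected shape of the fix is simply to name the loop variable `buffer` and clone it, so each FQN receives its own buffer's value. A self-contained sketch (again with a hypothetical `FakeTensor` standing in for `torch.Tensor`):

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor, just enough for this sketch."""
    def __init__(self, value):
        self.value = value

    def clone(self):
        return FakeTensor(list(self.value))


buffers = [FakeTensor([0.0] * 10), FakeTensor([0.0])]
buffer_clean_fqns = ["running_mean", "num_batches_tracked"]
prefix = "_fsdp_wrapped_module."
state_dict = {}

# Corrected loop: reference the per-iteration variable `buffer`, not a
# stale local, so each FQN is paired with the right buffer.
for buffer, clean_fqn in zip(buffers, buffer_clean_fqns):
    fqn = f"{prefix}{clean_fqn}"
    state_dict[fqn] = buffer.clone()
```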
Thanks again to the PyTorch Distributed team and PyTorch community more broadly for all your work!
Versions
Collecting environment information...
PyTorch version: 2.0.0.dev20230131
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.31
Python version: 3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-137-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2070 SUPER
GPU 1: NVIDIA GeForce RTX 2070
Nvidia driver version: 525.60.13
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn.so
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy==0.981
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.5
[pip3] torch==1.13.0a0+git95112ca
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310h7f8727e_0
[conda] mkl_fft 1.3.1 py310hd6ae3a3_0
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] numpy 1.23.5 py310hd5efca6_0
[conda] numpy-base 1.23.5 py310h8e6c178_0
[conda] pytorch-cuda 11.8 h8dd9ede_2 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torchtriton 2.0.0+0d7e753227 py310 pytorch-nightly
cc @ezyang @gchanan @zou3519 @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu