❓ Questions and Help
When we collect the state_dict of an FSDP-wrapped model, we call summon_full_params and then do some post-processing.
However, I noticed that the post-processing step does the following:
state_dict[key] = state_dict[key].clone()
Does this mean the memory used by the model states is doubled? For a large model, this could cause a CUDA OOM.
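To make the concern concrete, here is a back-of-the-envelope model of peak GPU memory during that post-processing step. It is a sketch under simplifying assumptions, not a measurement of fairscale. It assumes that all full (unsharded) parameters are materialized on the GPU inside summon_full_params and that each `.clone()` allocates a new GPU tensor of the same size. It ignores allocator overhead, local shards, activations and optimizer state. The strategy names `clone_on_gpu` and `copy_to_cpu` are hypothetical labels, not FSDP options.

```python
from typing import List


def peak_gpu_bytes(param_bytes: List[int], strategy: str) -> int:
    """Estimate peak GPU memory while post-processing a full state_dict.

    Simplifying assumptions: every full (unsharded) parameter is already
    materialized on the GPU inside summon_full_params; allocator overhead,
    local shards, activations and optimizer state are ignored.
    """
    live = sum(param_bytes)  # full params gathered by summon_full_params
    peak = live
    for nbytes in param_bytes:
        if strategy == "clone_on_gpu":
            # state_dict[key].clone() allocates a new GPU tensor of the same
            # size. The clone must survive after the full params are freed,
            # so the clones pile up on top of the full params.
            live += nbytes
        elif strategy == "copy_to_cpu":
            # Hypothetical alternative: copy each tensor into host memory
            # instead, which adds no GPU allocation.
            pass
        else:
            raise ValueError(f"unknown strategy: {strategy}")
        peak = max(peak, live)
    return peak


# Example: ~1.3B fp32 parameters (4 bytes each) split across two tensors.
params = [1_000_000_000 * 4, 300_000_000 * 4]
print(peak_gpu_bytes(params, "clone_on_gpu"))  # 10400000000
print(peak_gpu_bytes(params, "copy_to_cpu"))   # 5200000000
```

Under these assumptions, cloning on the GPU roughly doubles the peak (about 10.4 GB instead of 5.2 GB for this example). Copying to the CPU instead keeps the GPU peak at the size of the full parameters, but it still needs the same amount of host memory.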
What is the suggested workaround?
cc @min-xu-ai , @sshleifer