[FSDP] Question about getting model state dict #658

@shuyingsunshine21

❓ Questions and Help

When collecting the state_dict of an FSDP model, we call summon_full_params and then post-process the result.

However, I noticed that the post-processing step contains:

state_dict[key] = state_dict[key].clone()

Does this mean we double the memory used for the model states? For a large model, this would cause a CUDA OOM.
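To illustrate the concern, here is a minimal sketch in plain PyTorch on CPU. It is not FSDP itself: a small torch.nn.Linear stands in for the model, and state_dict_nbytes is a hypothetical helper written only for this example. It shows that cloning every entry allocates fresh storage of the same size as the original tensors; whether the peak inside FSDP really doubles depends on when the summoned full parameters are freed, which this sketch does not model.

```python
import torch


def state_dict_nbytes(state_dict):
    """Sum the bytes of all tensors, counting each underlying buffer once."""
    seen = set()
    total = 0
    for tensor in state_dict.values():
        ptr = tensor.data_ptr()
        if ptr not in seen:
            seen.add(ptr)
            total += tensor.numel() * tensor.element_size()
    return total


# Stand-in for the (unsharded) model; an assumption for illustration only.
model = torch.nn.Linear(1024, 1024)

# state_dict() returns tensors that alias the parameters' storage.
state = model.state_dict()
before = state_dict_nbytes(state)

# The post-processing pattern from the question, applied to every key.
cloned = {key: value.clone() for key, value in state.items()}

# No clone shares memory with its source, so the clones are extra allocations.
shares_storage = any(
    cloned[key].data_ptr() == state[key].data_ptr() for key in state
)
extra = state_dict_nbytes(cloned)

print(f"original: {before} bytes, additional after clone: {extra} bytes")
print(f"any storage shared: {shares_storage}")
```

While both dictionaries are alive, the process holds two full copies of the weights, which is the situation the question is asking about.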

What is the suggested workaround?

cc @min-xu-ai, @sshleifer

Labels: FSDP - FullyShardedDataParallel (zero-3)