
large model, low memory: need torch.load that loads one submodule at a time #75242

@stas00

Description

🚀 The feature, motivation and pitch

Here is a puzzle for you:

Say you have 24GB of GPU RAM, 32GB of CPU RAM, and a pretrained fp32 PyTorch model checkpoint that is 40GB in size.

Say you want to run inference or finetune in fp16 or bf16, and you have enough CPU and GPU memory to handle a 20GB model.

It should be possible, but you can't load the model in half precision because torch.load is inflexible and requires that all 40GB be loaded first.

The main issue here is that the user has enough memory to start training in half precision, but once they save a checkpoint, they can't resume, since they won't have enough memory to both allocate the model and load the full checkpoint.

Are there any plans to make torch.load more flexible, so that instead of loading the whole checkpoint at once it loads one submodule (or even one parameter) at a time, with bonus points for converting to the target torch.dtype on the fly?

In other words, the hardware requirements should be close to the final model size, not 2x or 3x. Currently it's 3x when the original checkpoint is in the higher dtype. Here is the breakdown (in multiples of the fp16 model size):

1x allocate the model in fp16
2x load the fp32 model
-------------------------------
3x Total peak memory

With flexible torch.load:

1x allocate the model in fp16
0.05x allocate the largest layer in fp32
-------------------------------
1.05x Total peak memory

(the 0.05x is just for demonstration, e.g. a model with 20 equal-sized layers)
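The flexible scheme above can be sketched with the stdlib alone (no torch; the flat file format and helper names here are made up for illustration): write each layer's fp32 weights to disk, then stream them back one layer at a time, converting to half precision before reading the next layer, so peak extra memory is a single fp32 layer.

```python
import os
import struct
import tempfile

def save_layers_fp32(layers, path):
    # Toy on-disk format (invented for this sketch): for each layer,
    # a little-endian uint64 count followed by that many fp32 values.
    with open(path, "wb") as f:
        for values in layers:
            f.write(struct.pack("<Q", len(values)))
            f.write(struct.pack(f"<{len(values)}f", *values))

def load_layers_fp16(path):
    # Stream one layer at a time: only one layer's fp32 values are ever
    # in memory, and they are converted to half precision ('e' format)
    # before the next layer is read.
    halves = []
    with open(path, "rb") as f:
        while header := f.read(8):
            (n,) = struct.unpack("<Q", header)
            fp32 = struct.unpack(f"<{n}f", f.read(4 * n))  # one fp32 layer
            halves.append(struct.pack(f"<{n}e", *fp32))    # keep fp16 bytes only
    return halves

# Usage: two tiny "layers" round-tripped through the fp32 file into fp16.
path = os.path.join(tempfile.mkdtemp(), "ckpt.bin")
save_layers_fp32([[0.5, -1.25, 3.0], [2.0] * 4], path)
halves = load_layers_fp16(path)
print([len(h) for h in halves])  # fp16 is 2 bytes per value -> [6, 8]
```

The same idea applied to a real checkpoint would need per-tensor (or per-shard) boundaries in the file format, which is precisely what the monolithic pickle produced by torch.save lacks.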

At HF Transformers we already use the hack of allocating the model on the meta device and then materializing it from the state_dict, but this is still not enough, as the full fp32 model has to be loaded first.
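That hack looks roughly like this (a minimal sketch with a stand-in two-layer architecture, not the actual Transformers code; note that the torch.load call still materializes the entire fp32 state_dict in CPU RAM, which is exactly the 2x spike described above):

```python
import os
import tempfile

import torch
from torch import nn

def load_half(checkpoint_path, build_model):
    # Allocate only the model structure: meta tensors take no real memory.
    with torch.device("meta"):
        model = build_model()
    # The pain point: the FULL fp32 state_dict is loaded into CPU RAM here.
    state = torch.load(checkpoint_path, map_location="cpu")
    # Give the model real fp16 storage, then fill it one parameter at a
    # time, dropping each fp32 tensor as soon as it has been downcast.
    model = model.to_empty(device="cpu").half()
    with torch.no_grad():
        for name, param in model.named_parameters():
            param.copy_(state[name])   # fp32 -> fp16 cast happens in copy_
            del state[name]            # release the fp32 copy immediately
    return model

# Usage: save a small fp32 checkpoint, then reload it in half precision.
def build():
    return nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

ref = build()
path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
torch.save(ref.state_dict(), path)
model = load_half(path, build)
print(model[0].weight.dtype)  # torch.float16
```

With a flexible torch.load, the `state = torch.load(...)` line would instead yield tensors lazily, and the per-parameter copy loop would be the only place fp32 data ever exists.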

Thank you!

cc @mruberry

    Labels

    module: serialization: Issues related to serialization (e.g., via pickle, or otherwise) of PyTorch objects
    triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
