🚀 The feature, motivation and pitch
Here is a puzzle for you:
Say you have 24GB of GPU RAM, 32GB of CPU RAM, and a pretrained fp32 PyTorch model checkpoint that is 40GB.
Say you want to run inference or finetune in fp16 or bf16, and you have enough CPU and GPU memory to handle a 20GB model.
It should be possible, but you can't load the model in half precision because `torch.load` is inflexible and requires that all 40GB be loaded first.
The main issue here is that the user has enough memory to start training in half precision, but once they have saved a checkpoint they can't resume, since they won't have enough memory to allocate the model and load the checkpoint at the same time.
Are there any plans to make `torch.load` more flexible, so that instead of loading the whole thing at once it loads one sub-module (or even one parameter) at a time — with a bonus for converting to the target `torch.dtype` on the fly?
In other words, the hardware requirements should be close to the final model size, not 2x or 3x. Currently it's 3x when the original checkpoint is in the higher dtype. Here is the breakdown:

```
1x    allocate the model in fp16
2x    load the fp32 checkpoint
-------------------------------
3x    total peak memory
```

With a flexible `torch.load`:

```
1x    allocate the model in fp16
0.1x  hold the largest fp32 layer while converting
-------------------------------
1.1x  total peak memory
```

(0.1x assumes a model with 20 equal-sized layers for demonstration: one fp32 layer is 2/20 the size of the fp16 model.)
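The accounting can be written out explicitly as a sanity check (a sketch, measuring everything in units of the final fp16 model's size and assuming 20 equal-sized layers, so one fp32 layer is 2/20 = 0.1x in those units):

```python
# Peak-memory accounting for the two loading strategies, in units of
# the final fp16 model's size, assuming 20 equal-sized layers.
N_LAYERS = 20
fp16_model = 1.0          # target model, allocated up front
fp32_checkpoint = 2.0     # same weights at twice the bytes

# Today's torch.load: the entire fp32 state_dict is resident at once.
peak_current = fp16_model + fp32_checkpoint      # 3.0x

# Flexible loading: only one fp32 layer is resident at a time.
fp32_layer = fp32_checkpoint / N_LAYERS          # 2/20 = 0.1x
peak_streaming = fp16_model + fp32_layer         # 1.1x

print(peak_current, peak_streaming)
```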
At HF Transformers we already use the hack of allocating the model on the meta device and then materializing it from the state_dict, but this is still not enough, as the full fp32 state_dict has to be loaded first.
Thank you!
cc @mruberry