Opening up this issue to discuss how to initialise and load large models with Flax.
Goals:
- Be able to initialise and pjit really large models that don’t fit on a single device.
- Avoid memory fragmentation when loading pre-trained weights in `from_pretrained`.
Issues:
- With the current API it’s not possible to initialise a model which does not fit on a single device, since the model weights are always initialised on the available device when we init the model class.

```python
model = FlaxPretrainedModel(config)  # weights are always initialised in this call; OOMs when the model can't fit on the device
```

It’s also not easy to use with pjit, or to init on CPU, as described by @borisdayma in this colab.
- The `from_pretrained` method creates memory fragmentation and also fails to load large models. In `from_pretrained` the model is first initialised randomly to get the names and shapes of the parameters, and then the pre-trained weights are loaded. At that point both the random weights and the pre-trained weights are on the device, so we have two copies of the weights in device memory.
Possible solutions:
Posting the two options proposed by @borisdayma
Option 1: params returned separately

```python
model, params = FlaxPretrainedModel(…)
predictions = model(params, input)
```

Cons:
- Would be a big breaking change and doesn’t really fit well with the rest of the Transformers API.
- Still doesn’t solve the problem of initialising really large models.
Option 2: params created separately

```python
model = FlaxPretrainedModel(…)
params = model.init_weights(…)
```

Here we can pjit the `init_weights` method, or jit it with the CPU backend to initialise the weights on CPU.
How would we implement this?
- Introduce a new flag `do_init` (need a better name) which will default to `True`, and the API will stay the same as it is now.
- If it’s `False`, don’t init the params. In this case `model.params` will always be `None` and the users will always have to pass params to the forward call. We use `eval_shape` to get the params tree with shapes and save it on `model.params_shape_tree` (better name?).
It would look something like this:
```python
# existing API
model = FlaxPretrainedModel(…)
model.params  # this will store the model params

# with do_init=False
model = FlaxPretrainedModel(do_init=False)
model.params  # this will be None
model(inputs)  # this will raise an error since params are not initialised
params = model.init_weights()

# always pass params to forward
model(inputs, params=params)
```

How to handle from_pretrained in this case (do_init=False)?
This is also related to the second goal. To avoid fragmentation we could use jax.eval_shape here to get the params shape info that is required when loading pre-trained weights.
But there are a few issues with this:
1. How to load a model with a head with pre-trained weights from the base model?
In this case, we need to randomly initialise only the head weights and load the rest from the pre-trained checkpoint. But there seems to be no way of initialising only some parameters and not others.
One possible solution is to add a method called init_head to every module with a head and when head weights are missing we call that method. See #15584 for what that would look like.
But this adds a lot of boilerplate code since we’ll need to do this for every head module in every flax model.
It’s also not possible to do this when some other intermediate weights are missing and need to be initialised randomly.
@marcvanzee @jheek @avital is there any way in Flax where we could just initialise certain parameters without initialising the whole model?
2. What if the pre-trained weights don’t fit on a single device?
In this case, we’ll need to load the weights on the CPU and then let the user shard those across devices.
Should we introduce an argument called device or backend which will specify where to load the weights? (cc @patrickvonplaten )
3. Should we return the loaded weights or assign them to model.params ?
When do_init is False we always set model.params to None and keep the params external as described above. But then the from_pretrained API will look something like below.
```python
model, params = FlaxPretrainedModel.from_pretrained(do_init=False)
```

which again is not well aligned with the rest of the API. What do we think about this? Looking forward to your suggestions @patrickvonplaten @borisdayma @LysandreJik
@jheek @marcvanzee @avital It would be awesome if you could take a look at this and share how large model loading is handled in Flax.