[Discussion] Loading and initialising large models with Flax #15766

@patil-suraj

Description

Opening up this issue to discuss how to initialise and load large models with Flax.

Goals:

  1. Be able to initialise and pjit really large models that don’t fit on a single device.
  2. Avoid memory fragmentation when loading pre-trained weights in from_pretrained.

Issues:

  • With the current API it’s not possible to initialise a model that does not fit on a single device, since the model
    weights are always initialised on the available device when we instantiate the model class.
model = FlaxPretrainedModel(config)  # weights are always initialised in this call; OOMs when the model can't fit on the device

It’s also not easy to use with pjit, or to initialise on CPU, as described by @borisdayma in this colab.

  • The from_pretrained method causes memory fragmentation and also fails to load large models. In from_pretrained the model is first initialised randomly to get the names and shapes of the parameters, and only then are the pre-trained weights loaded. Since both the random weights and the pre-trained weights live on the device, two copies of the weights sit in device memory at the same time.
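For context, jax.eval_shape can recover the parameter names and shapes without allocating any device memory, which is the property the proposals below rely on. A minimal sketch, where init_params is a hypothetical stand-in for a real module init call:

```python
import jax
import jax.numpy as jnp

# Stand-in for what a FlaxPretrainedModel does internally: in a real model
# this would be module.init(rng, dummy_input).
def init_params(rng):
    k1, k2 = jax.random.split(rng)
    return {
        "dense": {"kernel": jax.random.normal(k1, (1024, 1024)),
                  "bias": jnp.zeros((1024,))},
        "head": {"kernel": jax.random.normal(k2, (1024, 2048))},
    }

# eval_shape traces init_params abstractly: no arrays are materialised,
# we only get back a pytree of jax.ShapeDtypeStruct leaves.
rng = jax.random.PRNGKey(0)
shape_tree = jax.eval_shape(init_params, rng)

print(shape_tree["dense"]["kernel"].shape)  # (1024, 1024)
```

This gives from_pretrained everything it needs (names, shapes, dtypes) to match checkpoint weights against the expected parameter tree, without ever holding a second copy in device memory.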

Possible solutions:

Posting the two options proposed by @borisdayma

Option 1: params returned separately

model, params = FlaxPretrainedModel(…)
predictions = model(params, input)

Cons:

  • Would be a big breaking change and doesn’t really fit well with the rest of the Transformers API.
  • Still doesn’t solve the problem of initialising really large models.

Option 2: params created separately

model = FlaxPretrainedModel(…)
params = model.init_weights(…)
  • Here we can pjit the init_weights method, or jit it with the CPU backend, to initialise the weights on CPU.
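A rough sketch of CPU-backed initialisation, using jax.default_device (a real JAX context manager that controls where uncommitted arrays are placed); init_weights here is a toy stand-in for the proposed method:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for model.init_weights; a real call would build the full
# parameter pytree of the model.
def init_weights(rng):
    return {"kernel": jax.random.normal(rng, (512, 512))}

cpu = jax.devices("cpu")[0]

# Run the (jitted) initialisation on the CPU backend, so the weights land
# in host memory instead of an accelerator that may be too small.
with jax.default_device(cpu):
    params = jax.jit(init_weights)(jax.random.PRNGKey(0))
```

The same init function could instead be wrapped in pjit with an appropriate partitioning spec to initialise the weights directly sharded across devices.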

How would we implement this?

  • Introduce a new flag do_init (a better name is welcome) which defaults to True, so the API stays the same as it is now.
  • If do_init is False, don’t init the params. In that case model.params will always be None and users will always have to pass
    params to the forward call. We use jax.eval_shape to get the params tree with shapes and save it as model.params_shape_tree (better name?).

It would look something like this:

# existing API
model = FlaxPretrainedModel(…)
model.params # this will store the model params

# with do_init=False
model = FlaxPretrainedModel(do_init=False)
model.params # this will be None

model(inputs) # this will raise an error since params are not initialised

params = model.init_weights()
# always pass params to forward
model(inputs, params=params)
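To make the proposed semantics concrete, here is a self-contained toy sketch of the do_init flag; ToyFlaxModel, its _init_fn, and the error message are all hypothetical, but params_shape_tree and init_weights follow the naming proposed above:

```python
import jax
import jax.numpy as jnp

class ToyFlaxModel:
    def __init__(self, do_init=True, rng=None):
        self._rng = rng if rng is not None else jax.random.PRNGKey(0)
        # Shapes are always available via abstract evaluation: no allocation.
        self.params_shape_tree = jax.eval_shape(self._init_fn, self._rng)
        # Only materialise the weights when do_init=True.
        self.params = self._init_fn(self._rng) if do_init else None

    def _init_fn(self, rng):
        # Stand-in for the real module init.
        return {"kernel": jax.random.normal(rng, (8, 8))}

    def init_weights(self):
        return self._init_fn(self._rng)

    def __call__(self, inputs, params=None):
        if params is None:
            params = self.params
        if params is None:
            raise ValueError("pass `params`: model was created with do_init=False")
        return inputs @ params["kernel"]

model = ToyFlaxModel(do_init=False)
print(model.params)  # None
print(model.params_shape_tree["kernel"].shape)  # (8, 8)

params = model.init_weights()
out = model(jnp.ones((1, 8)), params=params)
```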

How to handle from_pretrained in this case (do_init=False)?

This is also related to the second goal: to avoid fragmentation we could use jax.eval_shape here to get the params shape info that is required when loading pre-trained weights.

But there are a few issues with this:

1. How to load a model with a head with pre-trained weights from the base model?

In this case, we need to randomly initialise only the head weights and load the rest from the pre-trained checkpoint. But there seems to be no way to initialise only some parameters and not others.
One possible solution is to add a method called init_head to every module with a head, and call it when the head weights are missing. See #15584 for what that would look like.
But this adds a lot of boilerplate code, since we’d need to do it for every head module in every Flax model.

It’s also not possible to do this when some other intermediate weights are missing and need to be initialised randomly.

@marcvanzee @jheek @avital is there any way in Flax where we could just initialise certain parameters without initialising the whole model?
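In the meantime, one possible workaround is to combine the eval_shape tree with per-leaf initialisation: materialise only the leaves that are missing from the checkpoint. The fill_missing helper and the toy init_fn below are hypothetical sketches, not part of any existing API:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a full model init: base weights plus a head.
def init_fn(rng):
    k1, k2 = jax.random.split(rng)
    return {"base": {"kernel": jax.random.normal(k1, (4, 4))},
            "head": {"kernel": jax.random.normal(k2, (4, 2))}}

# Pretend checkpoint: only the base weights are present, the head is missing.
pretrained = {"base": {"kernel": jnp.ones((4, 4))}}

# Abstract shapes for the *full* model -- no allocation happens here.
shape_tree = jax.eval_shape(init_fn, jax.random.PRNGKey(0))

def fill_missing(loaded, shapes, rng):
    """Keep loaded leaves; randomly initialise only the missing ones."""
    out = {}
    for name, sub in shapes.items():
        rng, key = jax.random.split(rng)
        if isinstance(sub, dict):
            out[name] = fill_missing(loaded.get(name, {}), sub, key)
        elif name in loaded:
            out[name] = loaded[name]
        else:
            # Only this leaf is materialised with random values.
            out[name] = jax.random.normal(key, sub.shape, sub.dtype)
    return out

params = fill_missing(pretrained, shape_tree, jax.random.PRNGKey(1))
```

This avoids a full random init, but note it assumes a simple normal initialiser for the missing leaves; it does not reproduce whatever custom initialiser the head module actually defines, which is exactly why native Flax support for partial init would be preferable.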

2. What if the pre-trained weights don’t fit on a single device?

In this case, we’ll need to load the weights on the CPU and then let the user shard them across devices.
Should we introduce an argument called device or backend to specify where to load the weights? (cc @patrickvonplaten)
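A minimal sketch of the CPU-loading half, assuming checkpoint weights arrive as host numpy arrays (host_weights below is hypothetical); jax.device_put with an explicit CPU device keeps them off the accelerator until the user shards them:

```python
import numpy as np
import jax

# Hypothetical checkpoint contents, deserialised as host numpy arrays.
host_weights = {"kernel": np.ones((256, 256), dtype=np.float32)}

cpu = jax.devices("cpu")[0]

# Commit every leaf to CPU explicitly; the user can later re-shard these
# across accelerators with jax.device_put and a sharding spec, or via pjit.
params = jax.tree_util.tree_map(lambda x: jax.device_put(x, cpu), host_weights)
```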

3. Should we return the loaded weights or assign them to model.params ?

When do_init is False we always set model.params to None and keep the params external as described above. But then the from_pretrained API will look something like below.

model, params = FlaxPretrainedModel.from_pretrained(do_init=False)

which again is not well aligned with the rest of the API. What do we think about this? Looking forward to your suggestions @patrickvonplaten @borisdayma @LysandreJik

@jheek @marcvanzee @avital It would be awesome if you could take a look at this and share how large model loading is handled in Flax.
