
[REQUEST] zero.Init and silent skipping of custom _init_weights functions #2650

@stas00

Description


Is your feature request related to a problem? Please describe.

A user recently reported this to me, and two days ago I ran into the issue myself.

The problem is that if the weight init runs outside the sub-module constructor, zero.Init has already sharded the weights and left a size-zero placeholder behind, so the custom init runs on that placeholder while the actual weights remain untouched.

The _init_weights function is a standard override implemented by sub-modules across the HF transformers ecosystem.

Consider:

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                module.bias.data.zero_()

Under zero.Init, this silently does nothing.

A working version of course requires gathering the params first:

    def _init_weights(self, module):
        # gather the partitioned weights so the in-place init has real storage to write to
        with deepspeed.zero.GatheredParameters(module.parameters(), modifier_rank=0):
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=self.config.init_std)
                if module.bias is not None:
                    module.bias.data.zero_()
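To see why the wrapper matters, here is a minimal plain-Python sketch (no deepspeed or torch needed; all names here are hypothetical toys, not DeepSpeed API): the context manager temporarily swaps the full storage back in, so in-place writes land on real data instead of the size-zero placeholder.

```python
from contextlib import contextmanager

class ShardedParam:
    """Toy stand-in for a ZeRO-3 parameter: locally it looks size-0."""
    def __init__(self, full):
        self._full = full      # lives "elsewhere" (on other ranks, in reality)
        self.data = []         # size-0 local placeholder

    def fill_(self, value):
        # in-place write: touches whatever self.data currently holds
        self.data[:] = [value] * len(self.data)

@contextmanager
def gathered(param):
    """Toy analogue of deepspeed.zero.GatheredParameters."""
    param.data = param._full           # "gather": expose the real storage
    try:
        yield
    finally:
        param._full = param.data       # keep the (possibly modified) values
        param.data = []                # back to the size-0 placeholder

p = ShardedParam([0.0, 0.0, 0.0])
p.fill_(1.0)                # silent no-op: the placeholder has zero elements
assert p._full == [0.0, 0.0, 0.0]

with gathered(p):
    p.fill_(1.0)            # now the write lands on the real storage
assert p._full == [1.0, 1.0, 1.0]
```

The ungathered `fill_` succeeds without error, which is exactly the silent-skip behavior this issue is about.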

Describe the solution you'd like

I don't have any ideas yet on how to approach this. What we want is to flag to the user (ideally assert) that their init isn't doing anything under zero.Init without the use of GatheredParameters.

How can one diagnose this problem?

  • Often the symptom of missing _init_weights is a much higher loss in the first few steps.
  • So if one has enough CPU RAM, one can test with and without zero.Init; if the loss isn't the same, they know some init didn't run.
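Another cheap check, sketched below as a hypothetical helper (plain Python duck-typing on the standard `numel()` accessor, not DeepSpeed API): list the parameters whose local storage is empty, since those are exactly the ones on which an in-place init silently no-ops.

```python
def empty_local_params(named_params):
    """Return names of params whose local storage is empty (likely partitioned)."""
    return [name for name, p in named_params if p.numel() == 0]

# tiny duck-typed demo in place of real torch parameters
class P:
    def __init__(self, n):
        self.n = n
    def numel(self):
        return self.n

params = [("encoder.weight", P(0)), ("decoder.weight", P(1024))]
assert empty_local_params(params) == ["encoder.weight"]
```

With a real model one would pass `model.named_parameters()`; any name this returns under zero.Init is a parameter whose custom init needs a GatheredParameters wrapper.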

For some reason I thought that if the codebase used module.weight.data instead of module.weight it would assert on a zero-length tensor, but this appears to be wrong: it's silent with either approach.

Now I remember seeing a torch warning about calling some function on a zero-length tensor. For some reason I can't find it now and will continue looking, but I think it only fires in some init functions, so it's not reliable anyway.


OK, this bites in other areas too. Here is a snippet of a discussion from elsewhere:

Oddly, I'm seeing a small discrepancy between z2 and z3 reports. Any ideas why this might be happening?

z3: Number of trainable parameters = 60492288
z2: Number of trainable parameters = 60506624

and I traced it down to this code:

    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))

It silently resized the embedding under zero3, but not under zero2 or without deepspeed: under zero3, weight.shape[0] reads the size-zero placeholder, so the len(tokenizer) > embedding_size check passes when it shouldn't.

I'm not sure how to approach it, but I think when zero3 hides the tensor, the normal accessors that don't return the truth should assert rather than silently mislead.
E.g. p.requires_grad is fine, but p.shape or p.numel() should assert, so that the user knows to write code that gathers the param first.
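One possible shape for that assert, as a hypothetical guard the user could apply today (not existing DeepSpeed behavior; the `ds_numel` attribute name is an assumption based on how ZeRO-3 tags partitioned params, and the demo class is a duck-typed stand-in):

```python
def checked_shape(param):
    """Return param.shape, but refuse to report the size-0 placeholder's shape."""
    if param.numel() == 0 and getattr(param, "ds_numel", 0) > 0:
        raise RuntimeError(
            "parameter is partitioned by ZeRO-3; gather it "
            "(e.g. with deepspeed.zero.GatheredParameters) before reading its shape"
        )
    return param.shape

# duck-typed stand-ins for a normal and a partitioned parameter
class Fake:
    def __init__(self, shape, ds_numel=0):
        self.shape = shape
        self.ds_numel = ds_numel
    def numel(self):
        n = 1
        for d in self.shape:
            n *= d
        return n

assert checked_shape(Fake((10, 4))) == (10, 4)   # normal param: truthful shape

raised = False
try:
    checked_shape(Fake((0,), ds_numel=40))       # partitioned: refuses to lie
except RuntimeError:
    raised = True
assert raised
```

The same guard applied inside zero3's parameter wrapper would have turned the silent resize_token_embeddings bug above into a loud error.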

Metadata


Labels: bug (Something isn't working), training
