Description
Is your feature request related to a problem? Please describe.
So recently a user reported this to me, and two days ago I ran into this issue myself.
The problem is that if the init-weights code runs outside the sub-module constructor, `zero.Init` will have already sharded the weights, leaving a size-zero placeholder. The custom init code then operates on that placeholder, while the actual weights remain untouched.
The `_init_weights` function is a standard override that sub-modules of the HF transformers eco-system implement.
Consider:
```python
def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=self.config.init_std)
        if module.bias is not None:
            module.bias.data.zero_()
```
this silently does nothing.
and of course the fix requires gathering the partitioned params first:
```python
def _init_weights(self, module):
    with deepspeed.zero.GatheredParameters(list(module.parameters()), modifier_rank=0):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                module.bias.data.zero_()
```
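One possible stopgap until DeepSpeed can flag this itself: a hypothetical guard (the function name and the check are my own, not part of any API) that refuses to initialize a size-zero placeholder instead of silently doing nothing:

```python
import torch
import torch.nn as nn

def init_linear_guarded(module, std=0.02):
    """Hypothetical guarded init: refuse to run on a ZeRO-3 placeholder.

    Under zero.Init a partitioned weight shows up as a tensor with 0
    elements, so normal_()/zero_() on it would be silent no-ops.
    """
    if isinstance(module, nn.Linear):
        if module.weight.numel() == 0:
            raise RuntimeError(
                "weight has 0 elements - likely partitioned by zero.Init; "
                "wrap this init in deepspeed.zero.GatheredParameters"
            )
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
```

The same `numel() == 0` check is the kind of thing that could be turned into a warning or assert inside DeepSpeed itself.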
Describe the solution you'd like
I don't have any ideas yet on how to approach this. What we want is to flag to the user (ideally assert) that their init isn't working under `zero.Init` without the use of `GatheredParameters`.
How can one diagnose this problem:
- Often the symptom of a missing `_init_weights` is a much higher loss in the first few steps.
- So if one has enough CPU RAM, one can test with and without `zero.Init`; if the loss isn't the same, then they know some init didn't run.
For some reason I thought that if the codebase used `module.weight.data` instead of `module.weight` then it would assert on a zero-length tensor, but this appears to be wrong: it's silent with either approach.
Now, I remember seeing a torch warning about calling some function on a zero-length tensor. For some reason I can't find it now; will continue looking. But I think it only happens in some init functions, so it's not reliable.
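To show the silence concretely (plain PyTorch, no DeepSpeed needed): the in-place init ops simply have nothing to do on a zero-element tensor and return without raising or warning:

```python
import torch

# A ZeRO-3 partitioned param looks like a zero-element placeholder:
placeholder = torch.empty(0)

# Both of these complete without an error or a warning - they just
# operate over zero elements, which is why the missing init is silent.
placeholder.normal_(mean=0.0, std=0.02)
placeholder.zero_()

assert placeholder.numel() == 0
```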
ok, this bites in other areas too; here is a snippet of a discussion from elsewhere:
oddly I'm seeing a small discrepancy between z2 and z3 reports. Any ideas why this might be happening?
```
z3: Number of trainable parameters = 60492288
z2: Number of trainable parameters = 60506624
```
and I traced it down to this code:
```python
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
```
it silently reshaped the embedding under zero3, but not under zero2 or without deepspeed — under zero3 the placeholder's `shape[0]` is 0, so the `len(tokenizer) > embedding_size` check always fires.
I'm not sure how to approach it, but I think when zero3 hides the tensor, the normal accessors that don't return the truth should assert, to flag the problem to the user.
e.g. `p.requires_grad` is fine, but `p.shape` or `p.numel()` should assert so that the user will write code to gather the param.
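As a sketch of what a truth-telling accessor could look like (the helper names are mine; the only DeepSpeed fact assumed is that ZeRO-3 stores the full element count on the partitioned param as `ds_numel`):

```python
import torch

def true_numel(p):
    # ZeRO-3 partitioned params carry the unpartitioned element count in
    # the ds_numel attribute; plain tensors just report numel().
    return getattr(p, "ds_numel", p.numel())

def count_trainable(model):
    # Would report the same count under z2 and z3, unlike p.numel(),
    # which sees only the local 0-size placeholder under z3.
    return sum(true_numel(p) for p in model.parameters() if p.requires_grad)
```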