Fix tied weight embeddings fails to load state dict #128076

Closed
j316chuck wants to merge 3 commits into pytorch:main from j316chuck:patch-1

Conversation


@j316chuck j316chuck commented Jun 5, 2024

@pytorch-bot pytorch-bot bot added module: distributed_checkpoint oncall: distributed Add this issue/PR to distributed oncall triage queue labels Jun 5, 2024

pytorch-bot bot commented Jun 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128076

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 3 Unrelated Failures

As of commit 996783b with merge base a7c5968:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented Jun 5, 2024

CLA Not Signed

@j316chuck j316chuck changed the title from "Update state_dict.py" to "Fix Tied Weight Embeddings Fail to Load on Torch 2.3" Jun 5, 2024
@j316chuck j316chuck changed the title from "Fix Tied Weight Embeddings Fail to Load on Torch 2.3" to "Fix Tied Weight Embeddings Fail to Load State Dict on Torch 2.3" Jun 5, 2024
@j316chuck j316chuck changed the title from "Fix Tied Weight Embeddings Fail to Load State Dict on Torch 2.3" to "Fix Tied Weight Embeddings Fail to Load State Dict" Jun 5, 2024
@j316chuck j316chuck changed the title from "Fix Tied Weight Embeddings Fail to Load State Dict" to "Fix tied weight embeddings fails to load state dict" Jun 5, 2024
Collaborator

awgu commented Jun 6, 2024

@fegin @wz337 @LucasLLC could one of you guys help review? thanks!

@fegin fegin added the ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR label Jun 6, 2024

pytorch-bot bot commented Jun 6, 2024

Please seek CI approval before scheduling CIFlow labels

@pytorch-bot pytorch-bot bot removed the ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR label Jun 6, 2024
Contributor

fegin commented Jun 6, 2024

I'm still trying to understand why the test was broken, but I don't think this is an appropriate fix to land. We should fix _iterate_valid_model_state if it does not correctly traverse the modules.

@janeyx99 janeyx99 requested a review from fegin June 6, 2024 17:15
@janeyx99 janeyx99 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jun 6, 2024
Author

j316chuck commented Jun 7, 2024

@fegin you can reproduce this test failure on a 2-gpu instance:

git clone https://github.com/mosaicml/composer.git
cd composer 
pip install -e .[all]
LOCAL_WORLD_SIZE=1 python3   -m coverage run -m pytest -v --durations=20 -m 'not daily and not remote and gpu and (doctest or not doctest)' -o tmp_path_retention_policy=none --codeblocks -k 

Make sure to replace what we have here in _patch_pytorch.py with your original logic.

Author

j316chuck commented Jun 7, 2024

Concretely, I think the difference is that:

from itertools import chain

for name, param in chain(model.named_parameters(), model.named_buffers()):
    print(name)

does not contain the tied weight embeddings.

However, your function _iterate_valid_model_state:

for name, _ in _iterate_valid_model_state(model):
    print(name)

does contain the tied weight embeddings.

This yields the error in the test: KeyError: 'model.cls.predictions.decoder.weight'
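The divergence can be reproduced in isolation. Here is a minimal sketch (the model and names are illustrative, not from the PR): `named_parameters()` deduplicates shared tensors by default, while `remove_duplicate=False` surfaces both FQNs for the same tensor, which mirrors what a per-module recursion can produce.

```python
import torch.nn as nn

# Hypothetical tiny model with one tied (shared) weight.
class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(4, 2)
        self.dec = nn.Linear(2, 4, bias=False)
        self.dec.weight = self.emb.weight  # tie the weights

m = Tiny()

# Default traversal deduplicates shared tensors: only one FQN appears.
dedup = [n for n, _ in m.named_parameters()]

# Disabling deduplication surfaces both FQNs for the same tensor,
# analogous to what a manual per-module recursion can yield.
full = [n for n, _ in m.named_parameters(remove_duplicate=False)]
```

So whether the second FQN ('dec.weight' here) exists at all depends on which traversal is used, which matches the KeyError above.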

@fegin fegin added ciflow/trunk Trigger trunk jobs on your pull request ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR labels Jun 7, 2024
Contributor

fegin commented Jun 7, 2024

@j316chuck It would be helpful if you could point out how this tied embedding weight is initialized; I'm not sure what "tied" means for an embedding weight. As I mentioned in the issue, _iterate_valid_model_state also uses named_buffers and named_parameters internally. The only difference is that _iterate_valid_model_state uses the non-recursive versions and performs the recursion by itself. It would help me to understand what this tied embedding weight means.

Contributor

fegin commented Jun 7, 2024

@j316chuck Another way to phrase my question: you will need a unit test in PyTorch, not in Composer, to verify this PR and get it landed.

Author

j316chuck commented Jun 8, 2024

@fegin can you help me add this unit test? I ran into a lot of lint/PR issues when trying to commit directly to pytorch. I believe I linked you the composer unit test that you can pattern match off of, FWIW. Here is how we initialize the model in composer from HF.

A tied weight embedding layer is one where the input embedding weights are also tied to the output layer of the LLM. This layer is double counted in _iterate_valid_model_state but not in chain(model.named_parameters(), model.named_buffers()), which creates the state dict loading issues we saw in our composer unit tests.

Here's an example of tied weight embeddings:

import torch
import torch.nn as nn

class TiedEmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.decoder = nn.Linear(embedding_dim, vocab_size)
        self.decoder.weight = self.embedding.weight  # Tying weights

    def forward(self, input):
        embedded = self.embedding(input)
        output = self.decoder(embedded)
        return output

# Example usage
vocab_size = 10000
embedding_dim = 300
model = TiedEmbeddingModel(vocab_size, embedding_dim)

# Save model state_dict
torch.save(model.state_dict(), 'tied_embedding_model.pth')

# Load model state_dict
loaded_model = TiedEmbeddingModel(vocab_size, embedding_dim)
loaded_model.load_state_dict(torch.load('tied_embedding_model.pth'))
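As a side note, a plain (non-distributed) state_dict round trip handles the tie fine. A sketch of that, assuming standard PyTorch semantics (this is a smaller stand-in for the model above, with the decoder bias dropped for brevity): both FQNs appear as state_dict keys even though they alias one tensor, and loading preserves the tie because load_state_dict copies values in place.

```python
import torch
import torch.nn as nn

# Illustrative stand-in model, not the one from the PR.
class TiedTiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 4)
        self.decoder = nn.Linear(4, 10, bias=False)
        self.decoder.weight = self.embedding.weight  # tie

m = TiedTiny()

# The state_dict carries one entry per attribute path, even for aliases.
keys = sorted(m.state_dict().keys())

# Loading copies into the existing tensors in place, so the two
# attributes still reference the same tensor object afterwards.
fresh = TiedTiny()
fresh.load_state_dict(m.state_dict())
still_tied = fresh.decoder.weight is fresh.embedding.weight
```

The problem in this PR is specific to the distributed state_dict path, not to this basic mechanism.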

@Skylion007
Collaborator

We definitely want to cherry pick this into 2.4rc as it's a pretty nasty regression (TensorEngine stops working).

@Skylion007 Skylion007 added this to the 2.4.0 milestone Jun 11, 2024
@fegin
Copy link
Contributor

fegin commented Jun 13, 2024

@j316chuck I can reproduce the issue with the model definition you provided. However, #125336 is correct. It just surfaces an issue of shared parameters not being properly handled in distributed state_dict. This PR would hide the issue again, which may become a problem in the future. I'll submit a PR to fix the issue.

Contributor

fegin commented Jun 14, 2024

@j316chuck Please check if #128685 fixes the issue.

fegin added a commit that referenced this pull request Jun 14, 2024
…or optimizer state_dict"

* 
Fixes #128011

See the discussion in #128076

Current implementation of `set_optimizer_state_dict()` assumes that all the fqns returned by `_get_fqns()` must exist in the optimizer state_dict. This is not true if the model has shared parameters. In such a case, only one fqn of the shared parameters will appear in the optimizer state_dict. This PR addresses the issue.

Differential Revision: [D58573487](https://our.internmc.facebook.com/intern/diff/D58573487/)

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k LucasLLC MeetVadakkanchery mhorowitz

[ghstack-poisoned]
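The commit message's point about shared parameters can be sketched concretely (illustrative code, not from the patch): with tied weights, model.parameters() yields the shared tensor once, so the optimizer holds a single parameter and a single slot of state, even though multiple FQNs resolve to that tensor.

```python
import torch
import torch.nn as nn

# Hypothetical setup: tie a linear decoder's weight to an embedding.
emb = nn.Embedding(8, 4)
dec = nn.Linear(4, 8, bias=False)
dec.weight = emb.weight  # shared parameter

model = nn.ModuleDict({"emb": emb, "dec": dec})
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer materializes its state.
out = dec(emb(torch.tensor([0, 1])))
out.sum().backward()
opt.step()

# parameters() deduplicates the shared tensor, so the optimizer sees
# one parameter and keeps one state entry, even though both
# 'emb.weight' and 'dec.weight' name this tensor in the model.
num_params = len(opt.param_groups[0]["params"])
```

This is why assuming every FQN from _get_fqns() exists in the optimizer state_dict breaks down for shared parameters.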
fegin added a commit that referenced this pull request Jun 14, 2024
…te_dict"

fegin added a commit that referenced this pull request Jun 17, 2024
…or optimizer state_dict"

fegin added a commit that referenced this pull request Jun 17, 2024
…te_dict"
pytorchmergebot pushed a commit that referenced this pull request Jun 18, 2024
…28685)

Pull Request resolved: #128685
Approved by: https://github.com/LucasLLC
Contributor

fegin commented Jun 18, 2024

Close in favor of #128685

@fegin fegin closed this Jun 18, 2024
fegin added a commit that referenced this pull request Jun 21, 2024
…28685)

(cherry picked from commit 1a52791)
atalman pushed a commit that referenced this pull request Jun 26, 2024
#129252)

[DSD] Correctly handle shared parameters for optimizer state_dict (#128685)

Pull Request resolved: #128685
Approved by: https://github.com/LucasLLC

(cherry picked from commit 1a52791)
@atalman atalman removed this from the 2.4.0 milestone Jun 27, 2024
Contributor

atalman commented Jun 27, 2024

Cherry pick of #128685 has been picked, hence removing this from the milestone.


Labels

ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR ciflow/trunk Trigger trunk jobs on your pull request oncall: distributed Add this issue/PR to distributed oncall triage queue open source triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Tied Weight Embeddings Models Fail to Load on Torch 2.4 Nightly

9 participants