Fix tied weight embeddings fails to load state dict #128076
j316chuck wants to merge 3 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128076
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 3 Unrelated Failures as of commit 996783b with merge base a7c5968.
NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Please seek CI approval before scheduling CIFlow labels
I'm still trying to understand why the test was broken, but I don't think this is an appropriate fix to land. We should fix …
@fegin you can reproduce this test failure on a 2-GPU instance. Make sure to replace what we have here in …
Concretely, I think the difference is: … does not contain tied weight embeddings. However, your function … does contain tied weight embeddings. This yields the error in the test: …
@j316chuck It would be helpful if you could point out how this tied embedding weight is initialized. I'm not sure what "tied" means for the embedding weight. As I mentioned in the issue, …
@j316chuck To rephrase my question another way: you will need a unit test in PyTorch, not in Composer, to get this PR landed.
@fegin can you help me add this unit test? I ran into a lot of lint/PR issues when trying to commit directly into PyTorch. I believe I linked you the Composer unit test that you can pattern match off of, fwiw. Here is how we initialize the model in Composer from HF: … A tied weight embedding layer is one where the input embedding weights are also tied to the output layer of the LLM. This layer is double counted in …. Here's an example of tied weight embeddings: …
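To make the tying concrete, here is a minimal, hypothetical sketch (not the actual Composer/HF model; the class and sizes are invented for illustration) of a language model whose output head shares its weight tensor with the input embedding:

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Toy LM with tied input/output embeddings (hypothetical sizes)."""
    def __init__(self, vocab_size=100, dim=16):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        # Tie the weights: the output projection reuses the embedding matrix.
        self.lm_head.weight = self.tok_emb.weight

    def forward(self, ids):
        return self.lm_head(self.tok_emb(ids))

model = TinyLM()
# Both modules point at the same Parameter object.
print(model.lm_head.weight is model.tok_emb.weight)  # True
```

Because the two modules reference a single `Parameter`, `model.parameters()` and `model.named_parameters()` yield that tensor only once, which is what makes tied models behave differently from untied ones in state_dict handling.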
We definitely want to cherry pick this into 2.4rc as it's a pretty nasty regression (TensorEngine stops working). |
@j316chuck I can reproduce the issue with the model definition you provided. However, #125336 is correct. It just surfaces an issue of shared parameters not being properly handled in distributed state_dict. This PR would hide the issue again, which may become a problem in the future. I'll submit a PR to fix the issue.
@j316chuck Please check if #128685 fixes the issue. |
[DSD] Correctly handle shared parameters for optimizer state_dict (#128685)

* Fixes #128011

See the discussion in #128076. The current implementation of `set_optimizer_state_dict()` assumes that all the fqns returned by `_get_fqns()` must exist in the optimizer state_dict. This is not true if the model has shared parameters. In such a case, only one fqn of the shared parameters will appear in the optimizer state_dict. This PR addresses the issue.

Differential Revision: [D58573487](https://our.internmc.facebook.com/intern/diff/D58573487/)

Pull Request resolved: #128685
Approved by: https://github.com/LucasLLC

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k LucasLLC MeetVadakkanchery mhorowitz [ghstack-poisoned]
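As a minimal illustration of why only one fqn shows up for a shared parameter (a sketch under assumed module names, not the PR's actual test), `named_parameters()` and the optimizer's parameter list both deduplicate a tensor that is registered under two modules:

```python
import torch
import torch.nn as nn

class Shared(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(10, 4)
        self.out = nn.Linear(4, 10, bias=False)
        self.out.weight = self.emb.weight  # shared parameter

model = Shared()
# named_parameters() removes duplicates by default, so the shared tensor
# appears under a single fqn ('emb.weight'); 'out.weight' is absent.
print([name for name, _ in model.named_parameters()])

# The optimizer therefore also tracks the shared tensor once, so its
# state_dict holds state for only one of the two fqns that map to it.
opt = torch.optim.SGD(model.parameters(), lr=0.1)
print(len(opt.param_groups[0]["params"]))  # 1
```

This is why code that expects every fqn returned by fqn resolution to exist in the optimizer state_dict breaks on tied-weight models.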
Closed in favor of #128685
[DSD] Correctly handle shared parameters for optimizer state_dict (#128685) (#129252) (cherry picked from commit 1a52791)
Cherry pick #128685 was picked, hence demilestoning this.
Fixes #128011
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC