[Fix] lm_head lora save #515
Summary of Changes
This pull request resolves an issue where LoRA adapters for …
Code Review
This pull request introduces a fix for saving LoRA adapters for lm_head when using tied embeddings. A new tie_word_embeddings flag is added to _merge_and_overwrite_lora and is used to correctly apply lm_head's LoRA adapters to embed_tokens during the merge process. The implementation is sound, but I've suggested a small improvement to make the logic for finding the corresponding lm_head LoRA stats more robust.
    if lora_stats is None and lm_head_key.startswith("model."):
        lora_stats = converted_lora_weights.get(lm_head_key[len("model."):], None)
The current logic for finding the lm_head LoRA stats only handles stripping the model. prefix. It doesn't handle the case where lm_head_key lacks the prefix but the key in converted_lora_weights has it. This can be made more robust by checking for both possibilities.
Suggested change:
    - if lora_stats is None and lm_head_key.startswith("model."):
    -     lora_stats = converted_lora_weights.get(lm_head_key[len("model."):], None)
    + if lora_stats is None:
    +     # Also check with/without "model." prefix for robustness.
    +     if lm_head_key.startswith("model."):
    +         lora_stats = converted_lora_weights.get(lm_head_key[len("model."):], None)
    +     else:
    +         lora_stats = converted_lora_weights.get("model." + lm_head_key, None)
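The two-way prefix lookup can be tried in isolation. This is a minimal sketch rather than the project's code: `find_lora_stats` is a hypothetical helper name, and a plain dictionary stands in for `converted_lora_weights`.

```python
def find_lora_stats(converted_lora_weights, key):
    """Look up `key`; on a miss, retry with the "model." prefix stripped or added.

    Hypothetical helper for illustration only.
    """
    stats = converted_lora_weights.get(key)
    if stats is None:
        prefix = "model."
        if key.startswith(prefix):
            alt_key = key[len(prefix):]   # "model.lm_head" -> "lm_head"
        else:
            alt_key = prefix + key        # "lm_head" -> "model.lm_head"
        stats = converted_lora_weights.get(alt_key)
    return stats


# Stand-in data: real entries would hold LoRA A/B matrices and scaling.
weights = {"model.lm_head": "stats-A", "embed": "stats-B"}
print(find_lora_stats(weights, "lm_head"))      # hit after adding the prefix
print(find_lora_stats(weights, "model.embed"))  # hit after stripping the prefix
print(find_lora_stats(weights, "missing"))      # no match in either form
```

Keeping both directions in one helper means the caller does not need to know which naming convention the adapter checkpoint used.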
Handle several edge cases in the LoRA merge path when lm_head is a LoRA target:
1. Vocab resize in _merge_lora: when tokens are added, the LoRA B matrix has more rows than the base safetensors weight. Zero-fill expand the base weight before applying addmm_.
2. Tied-weight exclusion in assert_same_keys: on models with tie_word_embeddings=True, lm_head.weight and embed_tokens.weight may be present/absent inconsistently between the original safetensors and the built state_dict. Exclude both from the key-match check to prevent false positives.
3. Reverse prefix check for lm_head key lookup: if the key does not start with "model.", try prepending it (mirrors the existing strip-"model." path).
4. Shard rewrite for resized tensors: when in-place mmap write fails due to size mismatch, buffer the resized tensor and rewrite the entire shard file after the mmap is closed.
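The tied-weight exclusion in item 2 can be sketched as a key-set comparison. This is an illustration under assumptions, not the actual assert_same_keys: the function names `comparable_keys` and `check_same_keys`, their signatures, and the sample key names are all hypothetical.

```python
# Suffixes of the two tensors that share storage when embeddings are tied.
TIED_SUFFIXES = ("lm_head.weight", "embed_tokens.weight")


def comparable_keys(keys, tie_word_embeddings):
    """Return the keys that should take part in the key-match check."""
    keys = set(keys)
    if tie_word_embeddings:
        # Either tensor may legitimately be missing from one side,
        # so neither is allowed to cause a mismatch.
        keys = {k for k in keys if not k.endswith(TIED_SUFFIXES)}
    return keys


def check_same_keys(original_keys, built_keys, tie_word_embeddings=False):
    """Raise ValueError if the two key sets differ (hypothetical stand-in)."""
    a = comparable_keys(original_keys, tie_word_embeddings)
    b = comparable_keys(built_keys, tie_word_embeddings)
    if a != b:
        raise ValueError(f"key mismatch: {sorted(a ^ b)}")


original = ["model.embed_tokens.weight", "model.layers.0.q_proj.weight"]
built = ["lm_head.weight", "model.embed_tokens.weight", "model.layers.0.q_proj.weight"]

# With tied embeddings, the extra lm_head.weight is ignored and the check passes.
check_same_keys(original, built, tie_word_embeddings=True)
```

Without the flag, the same inputs raise, since the built state_dict has a key the original lacks. That is the false positive the exclusion is meant to prevent.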
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ba1d56586c
        and tie_word_embeddings
        and lora_key.endswith("embed_tokens")
    ):
        lm_head_key = lora_key[:-len("embed_tokens")] + "lm_head"
        lora_stats = converted_lora_weights.get(lm_head_key, None)
Skip lm_head fallback when lm_head tensor already exists
This tied-embedding fallback remaps embed_tokens to the lm_head adapters whenever embed_tokens has no direct LoRA, but it does not check whether lm_head.weight is also present in the shard. When it is, the same logical adapter is merged into both tensors and count is incremented twice. That can trip the final effective_loras != n_saved_modules sanity check and fail the save path for tied models that serialize both keys. The fallback should be gated on lm_head.weight being absent, or the counting should be deduplicated.
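One way to gate the fallback, shown only as a sketch: `resolve_lora_stats` and its `shard_keys` parameter are hypothetical names, and the assumption that the shard's tensor names are available as a set is not confirmed by the PR.

```python
def resolve_lora_stats(lora_key, converted_lora_weights, shard_keys, tie_word_embeddings):
    """Return LoRA stats for `lora_key`, borrowing lm_head's adapter only when safe.

    Hypothetical helper: `shard_keys` is assumed to hold the tensor names
    serialized in the current shard.
    """
    stats = converted_lora_weights.get(lora_key)
    if stats is not None:
        return stats
    if tie_word_embeddings and lora_key.endswith("embed_tokens"):
        # Gate: if lm_head.weight is serialized too, it receives the adapter
        # directly, so merging it into embed_tokens would double count.
        lm_head_serialized = any(k.endswith("lm_head.weight") for k in shard_keys)
        if not lm_head_serialized:
            lm_head_key = lora_key[:-len("embed_tokens")] + "lm_head"
            return converted_lora_weights.get(lm_head_key)
    return None


lora = {"model.lm_head": "lm_head-adapter"}

# Only embed_tokens is serialized: the fallback applies.
only_embed = {"model.embed_tokens.weight"}
print(resolve_lora_stats("model.embed_tokens", lora, only_embed, True))

# Both tensors are serialized: the fallback is skipped for embed_tokens.
both = {"model.embed_tokens.weight", "lm_head.weight"}
print(resolve_lora_stats("model.embed_tokens", lora, both, True))
```

The alternative mentioned in the review, deduplicating the count, would leave the merge into both tensors in place and instead track which adapters have already been counted.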
Fixes: #4098
Ideally, lm_head should be listed in modules_to_save rather than target_modules.
Confirmed to work by @marcandrelarochelle, the OP of the issue.
Needs: unslothai/unsloth#4106