Skip to content

[Fix] lm_head lora save#515

Merged
danielhanchen merged 2 commits into
unslothai:mainfrom
Datta0:lora_save
Mar 3, 2026
Merged

[Fix] lm_head lora save#515
danielhanchen merged 2 commits into
unslothai:mainfrom
Datta0:lora_save

Conversation

@Datta0

@Datta0 Datta0 commented Feb 25, 2026

Copy link
Copy Markdown
Collaborator

Fixes : #4098
lm_head should ideally be a module_to_save and not target_module
This is confirmed to work by @marcandrelarochelle the OP of the issue

Needs: unslothai/unsloth#4106

@gemini-code-assist

Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @Datta0, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves an issue where LoRA adapters for lm_head were not being correctly saved, particularly in scenarios involving tied word embeddings. By introducing a tie_word_embeddings parameter and implementing specific logic, the changes ensure that lm_head LoRA weights are properly identified and merged, even when they are implicitly linked to embed_tokens.weight, thereby improving the robustness of LoRA model saving.

Highlights

  • lm_head LoRA Saving Fix: Corrected the handling of lm_head LoRA adapters during the saving process, ensuring they are properly treated as module_to_save to resolve a known issue where they were not being saved correctly.
  • Tied Word Embeddings Support: Introduced specific logic to merge lm_head LoRA adapters onto embed_tokens.weight when word embeddings are tied, addressing scenarios where lm_head.weight might be omitted from safetensors due to sharing.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • unsloth_zoo/saving_utils.py
    • Added tie_word_embeddings parameter to the _merge_and_overwrite_lora function signature.
    • Implemented conditional logic within _merge_and_overwrite_lora to check for lm_head LoRA adapters on embed_tokens.weight when tie_word_embeddings is true.
    • Updated the upload_items function to pass the tie_word_embeddings status from the model configuration to the LoRA merging utility.
Activity
  • No review comments or activity have been recorded yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a fix for saving LoRA adapters for lm_head when using tied embeddings. A new tie_word_embeddings flag is added to _merge_and_overwrite_lora and is used to correctly apply lm_head's LoRA adapters to embed_tokens during the merge process. The implementation is sound, but I've suggested a small improvement to make the logic for finding the corresponding lm_head LoRA stats more robust.

Comment on lines +702 to +703
if lora_stats is None and lm_head_key.startswith("model."):
lora_stats = converted_lora_weights.get(lm_head_key[len("model."):], None)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current logic for finding the lm_head LoRA stats only handles stripping the model. prefix. It doesn't handle the case where lm_head_key lacks the prefix but the key in converted_lora_weights has it. This can be made more robust by checking for both possibilities.

Suggested change
if lora_stats is None and lm_head_key.startswith("model."):
lora_stats = converted_lora_weights.get(lm_head_key[len("model."):], None)
if lora_stats is None:
# Also check with/without "model." prefix for robustness.
if lm_head_key.startswith("model."):
lora_stats = converted_lora_weights.get(lm_head_key[len("model."):], None)
else:
lora_stats = converted_lora_weights.get("model." + lm_head_key, None)

Datta0 and others added 2 commits March 3, 2026 13:58
Handle several edge cases in the LoRA merge path when lm_head is a
LoRA target:

1. Vocab resize in _merge_lora: when tokens are added, the LoRA B
   matrix has more rows than the base safetensors weight. Zero-fill
   expand the base weight before applying addmm_.

2. Tied-weight exclusion in assert_same_keys: on models with
   tie_word_embeddings=True, lm_head.weight and embed_tokens.weight
   may be present/absent inconsistently between the original
   safetensors and the built state_dict. Exclude both from the
   key-match check to prevent false positives.

3. Reverse prefix check for lm_head key lookup: if the key does not
   start with "model.", try prepending it (mirrors the existing
   strip-"model." path).

4. Shard rewrite for resized tensors: when in-place mmap write fails
   due to size mismatch, buffer the resized tensor and rewrite the
   entire shard file after the mmap is closed.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ba1d56586c

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +716 to +720
and tie_word_embeddings
and lora_key.endswith("embed_tokens")
):
lm_head_key = lora_key[:-len("embed_tokens")] + "lm_head"
lora_stats = converted_lora_weights.get(lm_head_key, None)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Skip lm_head fallback when lm_head tensor already exists

This tied-embedding fallback remaps embed_tokens to lm_head adapters unconditionally when embed_tokens has no direct LoRA, but it does not check whether lm_head.weight is also present in the shard. In that case the same logical adapter is merged into both tensors and count is incremented twice, which can trip the final effective_loras != n_saved_modules sanity check and fail the save path for tied models that serialize both keys; the fallback should be gated on lm_head.weight being absent (or the counting should be deduplicated).

Useful? React with 👍 / 👎.

@danielhanchen danielhanchen merged commit e601f52 into unslothai:main Mar 3, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants