FIX: weight tying for LoRA embeddings and lm_head #3711
Gemini Code Assist: Summary of Changes
This pull request resolves an issue where weight tying between token embeddings and the language model head was being disrupted when using LoRA with Unsloth's offloading mechanism, particularly for models like Qwen3-4B. The changes introduce a new configuration option and explicit re-tying logic to preserve the shared storage of these weights, preventing their divergence during training and ensuring the integrity of models that rely on this architectural constraint.
Code Review
This pull request introduces a fix for a weight tying issue with LoRA embeddings and the language model head, which occurs when ensure_weight_tying=True. The changes correctly propagate this flag through get_peft_model and add logic to manually re-tie the weights after the PEFT model is created, addressing the problem of weights diverging due to Unsloth's offloading mechanism. The implementation is sound and effectively resolves the described issue. I have one minor suggestion to enhance logging within an exception handler, which will improve debuggability should any unexpected errors arise.
💡 Codex Review
    loftq_config = loftq_config,
    use_rslora = use_rslora,
    modules_to_save = modules_to_save,
    ensure_weight_tying = ensure_weight_tying,
Guard ensure_weight_tying for older PEFT versions
The new arguments dict always injects ensure_weight_tying and then passes it straight into LoraConfig(**arguments) later in get_peft_model. Our dependency range in pyproject.toml still allows PEFT 0.7.x, whose LoraConfig signature does not accept this keyword; with such a supported version, this call will raise TypeError: __init__() got an unexpected keyword argument 'ensure_weight_tying' even if callers leave the default False. Either filter this argument based on inspect.signature(LoraConfig) (as vision.py already does) or bump the minimum PEFT version to one that supports the parameter.
I think bumping the version is the way to go here, to avoid complicating the logic further.
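For reference, the signature-based alternative mentioned in the review could look roughly like the sketch below. It is only an illustration: `LegacyLoraConfig` is a hypothetical stand-in for an older `LoraConfig` that lacks the keyword, and `filter_supported_kwargs` is an assumed helper name, not code from this PR or from vision.py.

```python
import inspect


class LegacyLoraConfig:
    """Hypothetical stand-in for an older LoraConfig without ensure_weight_tying."""

    def __init__(self, r=8, use_rslora=False, modules_to_save=None):
        self.r = r
        self.use_rslora = use_rslora
        self.modules_to_save = modules_to_save


def filter_supported_kwargs(config_cls, arguments):
    """Keep only the keyword arguments that config_cls.__init__ accepts."""
    params = inspect.signature(config_cls).parameters
    # If the constructor takes **kwargs, every key is acceptable.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(arguments)
    return {k: v for k, v in arguments.items() if k in params}


arguments = {
    "r": 16,
    "use_rslora": False,
    "modules_to_save": ["embed_tokens", "lm_head"],
    "ensure_weight_tying": True,  # unknown to the legacy signature
}
filtered = filter_supported_kwargs(LegacyLoraConfig, arguments)
config = LegacyLoraConfig(**filtered)  # no TypeError
```

The trade-off is that the flag would be dropped silently on older versions, which is one reason a minimum-version bump may be the simpler choice.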
    model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing)

    if ensure_weight_tying:
Can you please explain the need for such an argument
Can't we infer that from the config itself?
Given Unsloth's spirit, having it inferred automatically would be nice. But I'm not sure that all models have the tie_word_embeddings parameter in their config. Looking at Gemma3 models, their config does not declare this parameter, but ALL Gemma3 models actually have tied word embeddings.
EDIT: Gemma3 has tie_word_embeddings in its config, it's just not exposed on Hugging Face model pages. I personally would still have it as an explicit parameter (or check whether it is present in kwargs), since tie_word_embeddings only tells you that the base model ships tied weights, not that the user actually wants PEFT to re-alias the trainable modules_to_save copies.
So I think it is safer to have an explicit argument.
I guess we could check the weights' pointers to see if they are equal and then raise a warning to the user (not enforce tying though, as there might be cases when you want to untie weights).
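A rough sketch of that pointer check, assuming plain PyTorch modules whose `weight` is a regular tensor (quantized or offloaded weights may need a different comparison). The helper names `weights_share_storage` and `warn_if_untied` are made up for illustration and are not part of Unsloth or PEFT:

```python
import warnings

from torch import nn


def weights_share_storage(a: nn.Module, b: nn.Module) -> bool:
    """True if both modules' weights start at the same memory address."""
    return a.weight.data_ptr() == b.weight.data_ptr()


def warn_if_untied(input_emb: nn.Module, output_emb: nn.Module) -> bool:
    """Warn (but do not enforce tying) when the two weights are separate tensors."""
    tied = weights_share_storage(input_emb, output_emb)
    if not tied:
        warnings.warn(
            "embed_tokens and lm_head do not share storage; "
            "they will be trained as independent weights.",
            UserWarning,
            stacklevel=2,
        )
    return tied


embed = nn.Embedding(10, 4)
tied_head = nn.Linear(4, 10, bias=False)
tied_head.weight = embed.weight              # share the same Parameter
untied_head = nn.Linear(4, 10, bias=False)   # independent weight
```

Here `warn_if_untied(embed, tied_head)` stays silent, while `warn_if_untied(embed, untied_head)` emits a `UserWarning` and leaves the modules unchanged.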
    target_module._parameters.pop("weight")
    if hasattr(target_module, "weight"):
        try:
            delattr(target_module, "weight")
Can you elaborate on when each of these cases happens?
target_module.weight vs target_module._parameters["weight"]
The if "weight" in getattr(target_module, "_parameters", {}): case applies to PEFT's trainable copies (modules_to_save.default for embed_tokens/lm_head) and to non-offloaded originals. They still have a proper nn.Parameter registered under "weight", so we have to pop it before re-registering the shared one to avoid the name collision error.
The if hasattr(target_module, "weight") case applies after Unsloth's offload/rebuild of lm_head. In offload_output_embeddings, Unsloth deletes the registered parameter (del new_output_embeddings.weight) and then assigns a plain tensor back (new_output_embeddings.weight = offloaded_W). That leaves a weight attribute that is not registered, so you must delete the attribute first or it shadows the new registration.
Since we retie both the saved modules ("default" in the wrappers) and the original ones (for merge consistency and to avoid any other potential issues), both guards are needed.
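The two guards can be illustrated on toy modules. This is a minimal sketch, not the PR's actual code: `retie_weight` is a hypothetical helper, and the offload step is simulated by hand rather than by calling offload_output_embeddings.

```python
from torch import nn


def retie_weight(target_module: nn.Module, shared: nn.Parameter) -> None:
    """Register `shared` as target_module.weight, handling both cases."""
    # Case 1: a proper nn.Parameter is still registered under "weight".
    if "weight" in getattr(target_module, "_parameters", {}):
        target_module._parameters.pop("weight")
    # Case 2: "weight" is a plain (unregistered) tensor attribute, which
    # would make register_parameter raise KeyError if left in place.
    if hasattr(target_module, "weight"):
        delattr(target_module, "weight")
    target_module.register_parameter("weight", shared)


embed = nn.Embedding(10, 4)

# Case 1: an ordinary head with a registered parameter.
head_registered = nn.Linear(4, 10, bias=False)
retie_weight(head_registered, embed.weight)

# Case 2: simulate the offload/rebuild pattern described above.
head_offloaded = nn.Linear(4, 10, bias=False)
offloaded_W = head_offloaded.weight.data.clone()
del head_offloaded.weight            # removes the registered parameter
head_offloaded.weight = offloaded_W  # plain tensor, not registered
retie_weight(head_offloaded, embed.weight)
```

After both calls, each head's `weight` is the same `nn.Parameter` object as `embed.weight`, so updates to one are visible through the other.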
Here is a simple script to reproduce the issue btw:

import torch
from unsloth import FastLanguageModel

# Assumes unsloth without the post-PEFT retie fix. ensure_weight_tying is passed via kwargs.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen3-4B-Base-unsloth-bnb-4bit",
    max_seq_length=256,
    dtype=None,
    load_in_4bit=True,
)

def ptr(mod):
    # Address of the module's weight storage, or None if the module is absent.
    return None if mod is None else mod.weight.data_ptr()

def check(stage):
    in_emb = model.get_input_embeddings()
    out_emb = model.get_output_embeddings()
    # PEFT's trainable copies live under modules_to_save["default"], if present.
    in_def = getattr(getattr(in_emb, "modules_to_save", {}), "default", None)
    out_def = getattr(getattr(out_emb, "modules_to_save", {}), "default", None)
    print(f"{stage}: base tied? {ptr(in_emb)==ptr(out_emb)}, "
          f"default tied? {ptr(in_def)==ptr(out_def)}")

check("before PEFT")

model = FastLanguageModel.get_peft_model(
    model,
    r=8,
    target_modules=["q_proj","k_proj","v_proj","o_proj"],
    modules_to_save=["embed_tokens","lm_head"],
    ensure_weight_tying=True,  # flag is set, but no fix present
    use_gradient_checkpointing="unsloth",  # triggers offload
    max_seq_length=256,
)

check("after PEFT (expect untied without fix)")
hey @oKatanaaa can you please tell me the motivation to tie the embedding and lm_head weights for a model that didn't have it to begin with?
I'm not sure. This PR is not enabling this functionality though. Even if Some advanced users could enforce weight tying to intentionally reduce model size (usually word embeddings take ~13% of total model size, tying would reduce it by ~6% which could be significant).
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Thanks for the PR and Happy New Year! I tested this on a B200 with the following configurations: Test Results:
Weight tying verification:
The fix correctly re-ties the trainable copies after PEFT wrapping, which preserves weight sharing during Unsloth's offload/rebuild path. I also added a commit with a TODO comment for vision.py since the parameter is added but not yet implemented for vision models. Sidenote: This PR was reviewed automatically by the Unsloth Code Review Bot.
Thanks @oKatanaaa again :) Was trying out our new auto review system as well!
FIX: weight tying for LoRA embeddings and lm_head
Issue
Training of models with tied weights (such as Qwen3-4B) is broken when training token embeddings.
When using LoRA with modules_to_save=["embed_tokens","lm_head"] and ensure_weight_tying=True (param from PEFT config), Unsloth’s offload/duplication of embeddings and lm_head breaks the shared storage, so the trainable copies diverge and PEFT’s tying isn’t applied. This leaves embed_tokens and lm_head untied during training/merging.
Summary