Warn if using tied target module with tie_word_embeddings #2025
BenjaminBossan
left a comment
Thanks for creating this PR. I think we need to rethink the approach here, as the current one will not work in all situations.
- get_peft_model is a very generic function and is also used for prompt tuning methods, for instance. Therefore, we cannot assume that peft_config.target_modules exists.
- Not all methods allow merging the weights, thus we should not warn in those cases (false warnings should be avoided as much as possible).
- Even if peft_config.target_modules does exist, it could be a string, so looping over it will not always be correct.
- As we already observed, it will not work for custom models with tied weights, but let's consider this out of scope for now.
So how can we correctly identify when a warning is needed? My proposal is that this needs to be solved on a different level:
The check if there is a tied target layer needs to live on the corresponding method's model level (e.g. LoraModel), as only there can we really know which layers are targeted. Thankfully, the models that support merging all inherit from BaseTuner. There, we have the inject_adapter method. If you look at this line, you can see that all modules that are actually targeted are stored in self.targeted_module_names. Therefore, after exiting the loop, we can add a new method that takes this list and checks if any of the keys are tied weights using the logic you proposed.
This new check should be implemented as a new method on the BaseTuner class, so that subclasses such as LoraModel may choose to override the method if there ever is a need.
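A minimal sketch of the check described above, written as a standalone function rather than as the actual BaseTuner method. The names find_tied_target_modules and EMBEDDING_LAYER_NAMES, and the layer names listed, are assumptions for illustration and are not taken from the PEFT code base.

```python
from typing import Iterable, List

# Assumed names of layers that share weights when embeddings are tied.
# These are illustrative; real models may use different names.
EMBEDDING_LAYER_NAMES = ("embed_tokens", "lm_head")


def find_tied_target_modules(model_config: dict, targeted_module_names: Iterable[str]) -> List[str]:
    """Return the targeted modules that are tied embedding layers (hypothetical helper)."""
    if not model_config.get("tie_word_embeddings", False):
        return []
    tied = []
    for name in targeted_module_names:
        # Targeted names are full paths such as "model.embed_tokens"; compare the last part.
        if name.split(".")[-1] in EMBEDDING_LAYER_NAMES:
            tied.append(name)
    return tied
```

Since the function only looks at the names that were actually matched during injection, it should not matter whether target_modules was given as a list or as a string.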
Additionally, I wonder if there should be a warning when the user attempts to merge. One could argue that this is too late, but even at this point, there are workarounds: If the user clones the tied weights, they can merge without affecting the other weight (at the cost of extra memory).
This additional warning could be added to the _check_merge_allowed method and it could re-use the same method as mentioned above to perform the check. However, the warning message should be a bit different.
I know this is all a bit more complicated than initially thought and not necessarily what you "signed up for". So let me know if you still want to work on this or not, in which case I'll put this on my backlog.
Not at all thanks, sounds really good, I'll have a go! |
Thanks a lot. |
c236129 to
44a02de
|
@BenjaminBossan I made a version addressing your suggestions. Also, I refactored getting the model config in the code base.
I feel like the new message can be the same. Let me know. (I can't run the whole test suite as I do not have a cuda-compatible gpu.) |
BenjaminBossan
left a comment
Thanks for updating the PR.
I feel like the new message can be the same. Let me know.
I think the error message is good as is for when the model is being initialized. When merging, I think we could show a different warning, where we mention that if the weight is cloned beforehand, merging should work, at the cost of higher memory usage.
To implement this, I would change the _warn_if_tied_embeddings_in_target_modules method from warning to just performing the check and returning a bool (renaming the method accordingly). Then during injection, if the check returns True, the current warning is given, and during merging, if the check returns True, the adapted warning is given. WDYT?
(I can't run the whole test suite as I do not have a cuda-compatible gpu.)
This is fine.
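The split proposed above, one check and two call sites with different messages, could look roughly like this. The class and method names (TunerSketch, _has_tied_target_modules), the layer names, and the message texts are made up for illustration and are not the PEFT implementation.

```python
import warnings


class TunerSketch:
    """Hypothetical stand-in for a tuner class; not the PEFT BaseTuner."""

    def __init__(self, tie_word_embeddings, targeted_module_names):
        self.tie_word_embeddings = tie_word_embeddings
        self.targeted_module_names = list(targeted_module_names)

    def _has_tied_target_modules(self):
        # Pure check: returns a bool and never warns itself.
        embedding_names = ("embed_tokens", "lm_head")  # assumed layer names
        targeted = any(name.split(".")[-1] in embedding_names for name in self.targeted_module_names)
        return bool(self.tie_word_embeddings) and targeted

    def inject_adapter(self):
        # Warning shown when the adapter is created.
        if self._has_tied_target_modules():
            warnings.warn("Tied embeddings are targeted; this can lead to complications when merging.")

    def merge(self):
        # Adapted warning shown at merge time, mentioning the cloning workaround.
        if self._has_tied_target_modules():
            warnings.warn(
                "Tied embeddings are targeted; merging may work if the weight is cloned beforehand, "
                "at the cost of higher memory usage."
            )
```

Keeping the check free of side effects means each caller decides what to tell the user, and a subclass can override the check without touching the messages.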
BenjaminBossan
left a comment
Thanks for the changes. I have a few small suggestions for improvements, please check them out.
It would also be great to add unit tests for this, but probably this will be a bit more complicated. I leave it up to you if you want to give this a try, otherwise I'll work on it in a subsequent PR.
516fc3c to
3a51e67
|
Sure, very happy to write tests! I'll put them in
Just one question: to mock models with tied embeddings, should I use the test model "HuggingFaceH4/tiny-random-LlamaForCausalLM" but loaded with:
model = AutoModelForCausalLM.from_pretrained(model_id, tie_word_embeddings=True) |
cd3e830 to
cf4bf3e
|
@BenjaminBossan I added the test here |
BenjaminBossan
left a comment
Thanks so much for making the updates, using DUMMY_MODEL_CONFIG consistently and extending the tests. This looks quite good already, but I have some suggestions for improvements, please check.
I just did this as it's a bit unclear if in this case the model_config needs to default to None or if it can be the DUMMY one, let me know!
The change you made looks good as is.
Just one question: to mock models with tied embeddings, should I use the test model "HuggingFaceH4/tiny-random-LlamaForCausalLM" but loaded with:
I didn't know that this was an option. Yes, looks like the right choice.
warnings.warn(
    f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
    "This can lead to complications when merging the adapter. "
    "You can opt to merge the adapter after cloning the weights (to untie the embeddings), "
Honestly, I didn't know about the option to pass tie_word_embeddings=False. Is there even a need to clone the weights in that case?
Looks like it works, I added in the warning code to create the untied model.
        config = BaseTuner.get_model_config(ModelWithNoConfig())
        assert config == DUMMY_MODEL_CONFIG

    def test_warn_for_tied_embeddings_inject_and_merge(self):
Thanks a lot for adding these tests. They are already looking quite good. I think, however, that this last test can be simplified a bit.
As you correctly observed, there are 6 scenarios to test:
- Warning for get_peft_model and warning for merging.
- Valid warning vs no tied embeddings vs tied embeddings but not targeted.
Instead of cramming those into a single test, let's make this 6 separate tests. It should also be fine to make it 3 tests, where get_peft_model and merging are checked together. Hopefully, this should make the assert_warning_triggered function unnecessary.
You probably also had a bit of an issue that unrelated warnings could be recorded. Maybe this can be made simpler by using the recwarn fixture. Then you can just check that any warning has been recorded with the corresponding message, something like:
assert any(str(warning.message).startswith(msg) for warning in recwarn.list)
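For illustration, the check can be wrapped in a small helper that takes the list of recorded warnings. The helper name warning_recorded, the function emit_tied_warning, and the warning text are assumptions; recwarn is pytest's built-in fixture, so the test function below only runs under pytest.

```python
import warnings


def warning_recorded(warning_list, msg):
    """Return True if any recorded warning message starts with msg (hypothetical helper)."""
    return any(str(w.message).startswith(msg) for w in warning_list)


def emit_tied_warning():
    # Placeholder for the code under test that is expected to warn.
    warnings.warn("Model with `tie_word_embeddings=True` is targeted by the adapter.")


def test_warning_is_recorded(recwarn):
    # pytest injects recwarn; unrelated warnings in recwarn.list do not break the check.
    warnings.warn("some unrelated warning")
    emit_tied_warning()
    assert warning_recorded(recwarn.list, "Model with `tie_word_embeddings=True`")
```

Matching on the message prefix rather than on the number of warnings keeps the test robust when other libraries emit their own warnings.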
        pass


class TestBaseTunerMethods(unittest.TestCase):
Let's split this test class into 2: One for get_model_config and one for the tied embeddings.
a2f7354 to
7926888
|
@BenjaminBossan I think I've addressed the comments 👍 |
|
Thanks for the latest updates. I only have one more question, namely when it comes to how to untie the weights. In the script you provide, you clone the weights, but is that even necessary if the model is loaded with tie_word_embeddings=False?
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", tie_word_embeddings=False)
>>> model.model.embed_tokens.weight.data_ptr()
126062054867008
>>> model.lm_head.weight.data_ptr() # <= different data ptr
126051845931072
>>> model.model.embed_tokens.weight.sum()
tensor(952564.6250, grad_fn=<SumBackward0>)
>>> model.lm_head.weight.sum()
tensor(255.3427, grad_fn=<SumBackward0>)
>>> from peft import LoraConfig, get_peft_model
>>> config = LoraConfig(init_lora_weights=False, target_modules=["embed_tokens"])
>>> model = get_peft_model(model, config)
>>> unloaded = model.merge_and_unload()
>>> unloaded.model.embed_tokens.weight.sum() # <= embed weights changed
tensor(985655.8125)
>>> unloaded.lm_head.weight.sum() # <= lm head stayed the same
tensor(255.3427) |
Yes, I agree with your script, but the user wants to fix
This cloning also seems to allow saving it correctly. If you do not clone (besides actually re-tying the embeddings), then when you load the saved untied model the last assertion below will fail; if you clone, it will pass:
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", tie_word_embeddings=False)
# Set the randomly initialized lm_head to the previously tied embeddings
model.lm_head.weight.data = model.model.embed_tokens.weight.data
assert torch.equal(model.lm_head.weight.data, model.model.embed_tokens.weight.data)
# Save the untied model
untied_model_dir = "tmp_model"
model.save_pretrained(untied_model_dir)
model.config.save_pretrained(untied_model_dir)
# Now use the original model but in untied format
model = AutoModelForCausalLM.from_pretrained(untied_model_dir)
assert model.model.embed_tokens.weight.data.data_ptr() != model.lm_head.weight.data.data_ptr()
assert torch.equal(model.lm_head.weight.data, model.model.embed_tokens.weight.data) |
Oh wow, I did not know that the LM head will be randomly initialized, that's quite surprising IMO. I would have expected to get the same parameter values, just not tied. Thanks for making me aware of that. |
|
|
Not sure how to reproduce the error in the GitHub Actions run:
ruff check src tests examples docs scripts docker
All checks passed!
ruff format --check src tests examples docs scripts docker
189 files already formatted
doc-builder style src/peft tests docs/source --max_len 119 --check_only
Traceback (most recent call last):
File "/opt/hostedtoolcache/Python/3.8.18/x64/bin/doc-builder", line 8, in <module>
sys.exit(main())
File "/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/doc_builder/commands/doc_builder_cli.py", line 47, in main
args.func(args)
File "/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/doc_builder/commands/style.py", line 28, in style_command
raise ValueError(f"{len(changed)} files should be restyled!")
ValueError: 1 files should be restyled!
make: *** [Makefile:11: quality] Error 1
Error: Process completed with exit code 2. |
|
@ltoniazzi could you please run |
|
I ran @@ -530,8 +530,8 @@ model = AutoModelForCausalLM.from_pretrained(untied_model_dir)
@staticmethod
def get_model_config(model: nn.Module) -> dict:
"""
- This method gets the config from a model in dictionary form.
- If model has not attribute config, then this method returns a default config.
+ This method gets the config from a model in dictionary form. If model has not attribute config, then this
+ method returns a default config. |
Done! |
BenjaminBossan
left a comment
Thanks, very nicely done PR. I just have two tiny comments for cosmetic reasons, otherwise this can be merged.
        )
        return model

    def _is_warn_triggered(self, rrecwarn, endswith):
Did you call it rrecwarn to avoid naming conflicts? If yes, how about just passing the recwarn.list, which is all we need, and call it warning_list or so.
# Now use the original model but in untied format
model = AutoModelForCausalLM.from_pretrained(untied_model_dir)
```
"""
I see why you left-aligned the code snippet so that it is nicely printed. But this is really an eyesore to read in code. Here is a trick that lets us use the correct indentation but still get a nice warning message by using textwrap.dedent:
example_code = textwrap.dedent(
    """
    ```python
    from transformers import AutoModelForCausalLM
    # Load original tied model
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", tie_word_embeddings=False)
    # Set the randomly initialized lm_head to the previously tied embeddings
    model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()
    # Save the untied model
    untied_model_dir = "dir/for/untied/model"
    model.save_pretrained(untied_model_dir)
    model.config.save_pretrained(untied_model_dir)
    # Now use the original model but in untied format
    model = AutoModelForCausalLM.from_pretrained(untied_model_dir)
    ```
    """
)
warnings.warn(
    f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
    "This can lead to complications. "
    "You can opt to merge the adapter after cloning the weights (to untie the embeddings). "
    "You can untie the embeddings by loading the model with `tie_word_embeddings=False`. For example:"
    + example_code
)
The textwrap module is from the standard library and needs to be imported.
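As a quick, self-contained check of what textwrap.dedent does here: it removes the whitespace common to all lines, so an indented triple-quoted string comes out flush left. The snippet text inside the string is a shortened placeholder, not the actual warning from the PR, and build_message is a hypothetical name.

```python
import textwrap


def build_message():
    # The triple-quoted block is indented to match the surrounding code;
    # dedent strips that common indentation from every line.
    example_code = textwrap.dedent(
        """
        model = load_model()
        model.save("some/dir")
        """
    )
    return "For example:" + example_code


message = build_message()
print(message)
```

Because the string starts with a newline, the example lands on its own lines directly after the "For example:" prefix.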
Addressed both thanks!
BenjaminBossan
left a comment
Thanks so much, great work, hopefully this will help users in the future to avoid this potential pitfall.
|
@BenjaminBossan Thanks so much for your help! ❤️ Btw, a test on main failed, do you think it's related to this PR? |
Don't worry, this is a known issue with X-LoRA that came about with a recent change in transformers. |
When users are targeting tied weights (e.g. embedding and LM head), merging the adapter will lead to errors. Now users are warned about the possibility when they create such a PEFT model and also when they try to merge.
Context
Solving issue #2018.
target_module when the embeddings are tied, because this could lead to errors, for example when merging the adapter.
Todo
Check whether loading with tie_word_embeddings=False is an actual option. Load Gemma2 with finetuned different lm_weights and check that the lm_head is not replaced with the embedding (even if cloned). If it works, try to merge an adapter to lm_weight and then load it to check if embed and lm_head are kept separate. (The main concern is that the loading model's architecture might ignore any lm_head weight present in safetensors, as happens in llama.cpp, for example.)