Skip to content

Add support for tie_word_embeddings when loading weights + support for SmolLM#1508

Merged
merrymercy merged 8 commits intosgl-project:mainfrom
TianyiQ:main
Sep 25, 2024
Merged

Add support for tie_word_embeddings when loading weights + support for SmolLM#1508
merrymercy merged 8 commits intosgl-project:mainfrom
TianyiQ:main

Conversation

@TianyiQ
Copy link
Copy Markdown
Contributor

@TianyiQ TianyiQ commented Sep 25, 2024

Motivation

When trying to add support for SmolLM, I noticed that its config dict has "tie_word_embeddings": true which isn't supported by SGLang, and as a result, lm_head.weight is not loaded. This problem likely exists for some other models too (not just SmolLM).

Modifications

Within LlamaForCausalLM, at the end of the weight loading stage, I added the following code:

if hasattr(self.config, "tie_word_embeddings") and self.config.tie_word_embeddings:
    # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
    param = self.lm_head.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, self.model.embed_tokens.weight)

This can probably be copy-pasted to other LM classes other than LlamaForCausalLM, but I'm less familiar with those, and will defer to maintainers to decide whether to do the copying.

The effect can be tested by running python3 -m sglang.bench_latency --correct --model HuggingFaceTB/SmolLM-135M-Instruct. This would output gibberish before the change, but output coherent answers after it.

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

Comment thread python/sglang/srt/models/llama.py Outdated
Comment thread python/sglang/srt/models/llama.py
Comment thread python/sglang/srt/models/llama.py Outdated
@TianyiQ
Copy link
Copy Markdown
Contributor Author

TianyiQ commented Sep 25, 2024

Oops sorry about the failed test; I thought that was due to my local env (it's a bit messed up atm), but looks like it's not.

@merrymercy merrymercy merged commit 3c93187 into sgl-project:main Sep 25, 2024
@merrymercy
Copy link
Copy Markdown
Contributor

@TianyiQ Thanks. It is merged and it passes all tests

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants