Description
I was implementing Gemma4 inference from scratch (in Rust) and the Per-Layer Embeddings (PLE) system was by far the hardest part to get right. The config fields are misleading, the embedding type is non-obvious, and the full pipeline involves several undocumented steps. Sharing this in case it helps others and in case you want to improve the docs.
Problem 1: hidden_size_per_layer_input is ambiguous
The config says hidden_size_per_layer_input: 256, which sounds like it's the embedding dimension. But embed_tokens_per_layer.weight has shape [262144, 8960] where 8960 = 35 layers * 256. The actual embedding dimension is num_hidden_layers * hidden_size_per_layer_input, not hidden_size_per_layer_input alone.
This confused me because the __init__ in Gemma4TextModel seems like it should create nn.Embedding(vocab, 256) but then loading the pretrained weight of shape [vocab, 8960] would fail. (It doesn't fail because from_pretrained handles the resize, but it's not obvious from reading the code.)
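For anyone checking their own loader, here is a minimal sketch of the shape arithmetic. The numbers are the ones quoted in this issue; the variable names and the `layer_slice` helper are mine, and the slice layout assumes the row is reshaped contiguously to [num_layers, ple_dim], as the reshape in Problem 3 suggests.

```python
# Values quoted in this issue for google/gemma-4-E2B-it.
num_hidden_layers = 35
hidden_size_per_layer_input = 256
vocab_size = 262144  # first dim of the embed_tokens_per_layer.weight shape

# Each row of the PLE table packs one 256-wide slice per layer.
ple_embedding_dim = num_hidden_layers * hidden_size_per_layer_input
expected_weight_shape = (vocab_size, ple_embedding_dim)


def layer_slice(layer_idx):
    """Column range of a row belonging to `layer_idx` (hypothetical helper).

    Assumes a contiguous reshape of the row to [num_layers, ple_dim].
    """
    start = layer_idx * hidden_size_per_layer_input
    return start, start + hidden_size_per_layer_input


print(expected_weight_shape)
print(layer_slice(num_hidden_layers - 1))
```

If the checkpoint's second dimension doesn't match `ple_embedding_dim`, the config values were probably read for a different model size.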
Problem 2: embed_tokens_per_layer is secretly a Gemma4TextScaledWordEmbedding
The PLE embedding isn't a plain nn.Embedding. It's a Gemma4TextScaledWordEmbedding that multiplies the lookup result by sqrt(hidden_size_per_layer_input) = sqrt(256) = 16.0.
This isn't mentioned anywhere in the config, the docstrings, or the model card. I only found it by inspecting type(lm.embed_tokens_per_layer).__name__ after my outputs were 16x too small.
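To make the scaling concrete, here is a toy stand-in for the behaviour described above. It is not the transformers class: `ScaledEmbedding` and the two-row table are made up for illustration, and only the sqrt(hidden_size_per_layer_input) factor comes from what I observed.

```python
import math


class ScaledEmbedding:
    """Illustrative stand-in: plain table lookup followed by a constant scale."""

    def __init__(self, weight, embed_scale):
        self.weight = weight            # list of rows, one per token id
        self.embed_scale = embed_scale  # constant multiplier applied after lookup

    def __call__(self, token_ids):
        return [[x * self.embed_scale for x in self.weight[t]] for t in token_ids]


hidden_size_per_layer_input = 256
scale = math.sqrt(hidden_size_per_layer_input)

toy_table = [[0.0, 0.5], [1.0, -2.0]]  # toy 2-token, 2-dim table
embed = ScaledEmbedding(toy_table, scale)

unscaled = toy_table[1]
scaled = embed([1])[0]
print(unscaled, scaled)
```

A port that does the plain lookup only ends up with the `unscaled` values, which is exactly the 16x discrepancy I hit.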
Problem 3: The full PLE pipeline has undocumented steps
The actual PLE computation involves:
- Token-identity: embed_tokens_per_layer(input_ids) (scaled by sqrt(256)) -> reshape to [B, S, num_layers, ple_dim]
- Context-aware projection: per_layer_model_projection(inputs_embeds) (a Linear) -> scale by 1/sqrt(hidden_size) -> reshape to [B, S, num_layers, ple_dim] -> RMSNorm (per_layer_projection_norm)
- Combine: (context_projection + token_identity) * (1/sqrt(2))
- Each layer i gets per_layer_inputs[:, :, i, :]
This involves weights and constants that aren't mentioned in the config at all:
- per_layer_model_projection (Linear, hidden_size -> num_layers * ple_dim)
- per_layer_projection_norm (RMSNorm, dim=ple_dim)
- Two hardcoded scale factors: 1/sqrt(hidden_size) and 1/sqrt(2)
The get_per_layer_inputs() and project_per_layer_inputs() methods implement this, but there are no docstrings explaining the overall pipeline or the scale factors.
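The steps above can be sketched for a single token in plain Python. This is my reading of the pipeline, not the transformers code. The function name, the toy sizes, the bias-free Linear, the `eps` value, and the plain `x / rms * weight` form of RMSNorm are all assumptions; the real norm may parameterize its weight differently.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Assumed RMSNorm form: x / sqrt(mean(x^2) + eps) * weight."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def per_layer_inputs_for_token(token_id, input_embed, ple_table, proj_weight,
                               norm_weight, num_layers, ple_dim):
    """Return a [num_layers][ple_dim] nested list for one token (hypothetical helper)."""
    hidden_size = len(input_embed)

    # 1. Token identity: scaled lookup, then split the row into per-layer slices.
    row = ple_table[token_id]
    token_identity = [
        [row[l * ple_dim + d] * math.sqrt(ple_dim) for d in range(ple_dim)]
        for l in range(num_layers)
    ]

    # 2. Context projection: Linear (weight is [out, in], no bias assumed),
    #    scaled by 1/sqrt(hidden_size), split per layer, then RMSNorm per slice.
    projected = [
        sum(w * x for w, x in zip(w_row, input_embed)) / math.sqrt(hidden_size)
        for w_row in proj_weight
    ]
    context = [
        rms_norm(projected[l * ple_dim:(l + 1) * ple_dim], norm_weight)
        for l in range(num_layers)
    ]

    # 3. Combine with the hardcoded 1/sqrt(2) factor.
    inv_sqrt2 = 1.0 / math.sqrt(2.0)
    return [
        [(c + t) * inv_sqrt2 for c, t in zip(context[l], token_identity[l])]
        for l in range(num_layers)
    ]


# Toy sizes: 2 layers, ple_dim 2, hidden_size 4, 2-token vocabulary.
ple_table = [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0]]
proj_weight = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
per_layer_inputs = per_layer_inputs_for_token(
    token_id=1, input_embed=[2.0, 4.0, 6.0, 8.0], ple_table=ple_table,
    proj_weight=proj_weight, norm_weight=[1.0, 1.0], num_layers=2, ple_dim=2,
)
layer_0_input = per_layer_inputs[0]  # what decoder layer 0 would receive
```

In a real port the same steps run over [B, S, ...] tensors; the per-token version just keeps the order of the scale factors easy to follow.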
Suggestion
Adding a docstring to Gemma4TextModel (or the config class) explaining:
- That hidden_size_per_layer_input is the per-layer dimension, and the total embedding dim is num_hidden_layers * hidden_size_per_layer_input
- That the PLE embedding is scaled by sqrt(hidden_size_per_layer_input)
- A brief description of the full PLE pipeline (token lookup + context projection + norm + combine with scale factors)
This would save a lot of pain for anyone implementing Gemma4 outside of HuggingFace transformers (e.g. llama.cpp, candle, mlx, etc.).
Environment
- transformers 5.5.0
- Model: google/gemma-4-E2B-it