Gemma4: PLE (Per-Layer Embeddings) implementation is underdocumented and config is misleading #45206

@w4nderlust

Description

I was implementing Gemma4 inference from scratch (in Rust), and the Per-Layer Embeddings (PLE) system was by far the hardest part to get right. The config fields are misleading, the embedding type is non-obvious, and the full pipeline involves several undocumented steps. I'm sharing this in case it helps others, and in case you want to improve the docs.

Problem 1: hidden_size_per_layer_input is ambiguous

The config says hidden_size_per_layer_input: 256, which sounds like the dimension of the PLE embedding. But embed_tokens_per_layer.weight has shape [262144, 8960], where 8960 = 35 layers * 256. The actual embedding dimension is num_hidden_layers * hidden_size_per_layer_input, not hidden_size_per_layer_input alone.

This confused me because the __init__ of Gemma4TextModel reads as if it should create nn.Embedding(vocab, 256), in which case loading the pretrained weight of shape [vocab, 8960] would fail. (It doesn't fail, because from_pretrained handles the resize, but that isn't obvious from reading the code.)
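
A minimal sketch of the shape bookkeeping, in dependency-free Python so it ports easily to Rust. It assumes the flat 8960-wide row is laid out layer-major (layer i owns columns [i*256, (i+1)*256)), which is what a row-major reshape to [..., num_layers, ple_dim] implies. The helper names are hypothetical, not transformers APIs.

```python
# Shape bookkeeping for embed_tokens_per_layer, using the numbers quoted above.
NUM_HIDDEN_LAYERS = 35
HIDDEN_SIZE_PER_LAYER_INPUT = 256  # width of ONE layer's slice, not of the table
VOCAB_SIZE = 262144


def ple_table_shape(vocab_size, num_layers, ple_dim):
    """Expected shape of embed_tokens_per_layer.weight: [vocab, num_layers * ple_dim]."""
    return (vocab_size, num_layers * ple_dim)


def split_row_per_layer(row, num_layers, ple_dim):
    """Split one flat embedding row into num_layers slices of ple_dim values.

    Assumes a layer-major layout: layer i owns columns [i * ple_dim, (i + 1) * ple_dim).
    """
    if len(row) != num_layers * ple_dim:
        raise ValueError(f"expected {num_layers * ple_dim} values, got {len(row)}")
    return [row[i * ple_dim:(i + 1) * ple_dim] for i in range(num_layers)]


shape = ple_table_shape(VOCAB_SIZE, NUM_HIDDEN_LAYERS, HIDDEN_SIZE_PER_LAYER_INPUT)
print(shape)  # (262144, 8960)
```

In a from-scratch loader, a check like this against the loaded tensor catches the "256 vs 8960" confusion before any math runs.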

Problem 2: embed_tokens_per_layer is secretly a Gemma4TextScaledWordEmbedding

The PLE embedding isn't a plain nn.Embedding. It's a Gemma4TextScaledWordEmbedding that multiplies the lookup result by sqrt(hidden_size_per_layer_input) = sqrt(256) = 16.0.

This isn't mentioned anywhere in the config, the docstrings, or the model card. I only found it by inspecting type(lm.embed_tokens_per_layer).__name__ after my outputs were 16x too small.
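
A sketch of what the scaled lookup amounts to. ScaledEmbedding here is a hypothetical stand-in, not the transformers class. The point it illustrates is that the scale comes from hidden_size_per_layer_input (sqrt(256) = 16.0), not from the 8960-wide table width.

```python
import math


class ScaledEmbedding:
    """Hypothetical stand-in for a scaled word embedding: row lookup, then multiply."""

    def __init__(self, weight, embed_scale):
        self.weight = weight          # list of rows, one per token id
        self.embed_scale = embed_scale

    def __call__(self, token_ids):
        return [[v * self.embed_scale for v in self.weight[t]] for t in token_ids]


ple_dim = 256                      # hidden_size_per_layer_input
embed_scale = math.sqrt(ple_dim)   # 16.0; NOT sqrt(8960), the full table width
table = [[0.5, -1.0], [0.25, 2.0]]  # toy 2-token table with 2 columns
emb = ScaledEmbedding(table, embed_scale)
print(emb([1]))  # [[4.0, 32.0]]
```

Forgetting the multiply leaves every PLE value exactly 16x too small, which matches the symptom described above.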

Problem 3: The full PLE pipeline has undocumented steps

The actual PLE computation involves:

  1. Token-identity: embed_tokens_per_layer(input_ids) (scaled by sqrt(256)) -> reshape to [B, S, num_layers, ple_dim]
  2. Context-aware projection: per_layer_model_projection(inputs_embeds) (a Linear) -> scale by 1/sqrt(hidden_size) -> reshape to [B, S, num_layers, ple_dim] -> RMSNorm (per_layer_projection_norm)
  3. Combine: (context_projection + token_identity) * (1/sqrt(2))
  4. Each layer i gets per_layer_inputs[:, :, i, :]
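
The four steps above, for a single token position, as a dependency-free Python sketch with hypothetical names. The exact RMSNorm variant (eps, and whether the weight is used as-is or offset by 1) and the absence of a bias in per_layer_model_projection are assumptions to verify against the reference code. inputs_embed is the main input embedding vector for that position. Real code would batch this over [B, S].

```python
import math


def rms_norm(x, weight, eps=1e-6):
    # ASSUMPTION: plain RMSNorm with a multiplicative weight; check the reference
    # implementation for the exact eps and whether the weight is offset by 1.
    ms = sum(v * v for v in x) / len(x)
    inv = 1.0 / math.sqrt(ms + eps)
    return [v * inv * w for v, w in zip(x, weight)]


def linear(x, weight):
    # weight is [out_features][in_features]; no bias (assumed)
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def ple_inputs_for_token(token_id, inputs_embed, ple_table, proj_weight, norm_weight,
                         num_layers, ple_dim, hidden_size):
    """Per-layer inputs for ONE position: num_layers lists of ple_dim floats each."""
    # 1. Token identity: scaled lookup, then split into [num_layers, ple_dim]
    row = [v * math.sqrt(ple_dim) for v in ple_table[token_id]]
    token_identity = [row[i * ple_dim:(i + 1) * ple_dim] for i in range(num_layers)]
    # 2. Context projection: Linear, scale by 1/sqrt(hidden_size), split, RMSNorm each slice
    proj = [v / math.sqrt(hidden_size) for v in linear(inputs_embed, proj_weight)]
    context = [rms_norm(proj[i * ple_dim:(i + 1) * ple_dim], norm_weight)
               for i in range(num_layers)]
    # 3. Combine both paths and rescale by 1/sqrt(2)
    s = 1.0 / math.sqrt(2.0)
    return [[(c + t) * s for c, t in zip(cs, ts)]
            for cs, ts in zip(context, token_identity)]
```

Step 4 is then just indexing: decoder layer i receives the i-th entry of the returned list. Note that the RMSNorm is applied per ple_dim-wide slice, which is why per_layer_projection_norm has dimension ple_dim rather than num_layers * ple_dim.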

The pipeline involves weights and constants that aren't mentioned in the config at all:

  • per_layer_model_projection (Linear, hidden_size -> num_layers * ple_dim)
  • per_layer_projection_norm (RMSNorm, dim=ple_dim)
  • Two hardcoded scale factors: 1/sqrt(hidden_size) and 1/sqrt(2)

The get_per_layer_inputs() and project_per_layer_inputs() methods implement this, but there are no docstrings explaining the overall pipeline or the scale factors.
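
Since none of these tensors show up in the config, a from-scratch loader can at least derive their expected shapes and the scale constants from the config fields and fail loudly on a mismatch. A sketch, assuming PyTorch's [out_features, in_features] layout for the Linear weight and no bias. The dictionary keys are the module names used in this issue; real checkpoint keys likely carry an extra prefix, and the function names are hypothetical.

```python
import math


def expected_ple_shapes(vocab_size, hidden_size, num_layers, ple_dim):
    """Expected tensor shapes for the PLE weights, derived from config fields."""
    total = num_layers * ple_dim
    return {
        "embed_tokens_per_layer.weight": (vocab_size, total),
        # torch.nn.Linear stores its weight as [out_features, in_features]
        "per_layer_model_projection.weight": (total, hidden_size),
        "per_layer_projection_norm.weight": (ple_dim,),
    }


def ple_scale_constants(hidden_size, ple_dim):
    """Embedding scale (Problem 2) plus the two hardcoded pipeline scales."""
    return {
        "embed_scale": math.sqrt(ple_dim),                 # token-identity lookup
        "projection_scale": 1.0 / math.sqrt(hidden_size),  # after the projection
        "combine_scale": 1.0 / math.sqrt(2.0),             # after summing both paths
    }


def find_mismatches(actual_shapes, expected_shapes):
    """Names whose loaded shape is missing or differs from the expectation."""
    return [name for name, shape in expected_shapes.items()
            if tuple(actual_shapes.get(name, ())) != shape]
```

Here actual_shapes would map each loaded tensor name to its shape tuple; an empty result from find_mismatches means the PLE tensors line up with the config.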

Suggestion

Add a docstring to Gemma4TextModel (or the config class) that explains:

  1. That hidden_size_per_layer_input is the per-layer dimension, and the total embedding dim is num_hidden_layers * hidden_size_per_layer_input
  2. That the PLE embedding is scaled by sqrt(hidden_size_per_layer_input)
  3. A brief description of the full PLE pipeline (token lookup + context projection + norm + combine with scale factors)
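
For point 3, one compact way such a docstring could state the pipeline, writing x_t for inputs_embeds at position t, id_t for its token id, E for the embed_tokens_per_layer table, W for the per_layer_model_projection weight, d for hidden_size, d_ple for hidden_size_per_layer_input, L for num_hidden_layers, and [·]_i for the i-th d_ple-wide slice:

$$
p_{t,i} = \frac{1}{\sqrt{2}}\left(\operatorname{RMSNorm}\!\left(\frac{1}{\sqrt{d}}\,[W x_t]_i\right) + \sqrt{d_{\mathrm{ple}}}\,[E_{\mathrm{id}_t}]_i\right), \qquad i = 0, \dots, L-1
$$

Layer i then consumes p_{t,i}. One plausible reading of the 1/sqrt(2) (an assumption, not documented in the code): if the two summed terms were independent with equal variance, their sum would have twice that variance, and the factor brings it back to the original scale.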

This would save a lot of pain for anyone implementing Gemma4 outside of Hugging Face transformers (e.g. llama.cpp, candle, mlx).

Environment

  • transformers 5.5.0
  • Model: google/gemma-4-E2B-it
