[Feature] Support mrope_section with rope_type: "yarn" #13219

@JustinTong0323

Description

background

When configuring rope_scaling, if rope_type is set to "yarn", the parameters mrope_section and mrope_interleaved are ignored. The current implementation instantiates the YaRNScalingRotaryEmbedding class, which does not handle these MRoPE (Multimodal Rotary Positional Embedding) parameters. This prevents using YaRN scaling in combination with the multimodal rope configurations needed for certain models.

solution

I would like the ability to use mrope_section and mrope_interleaved in conjunction with "rope_type": "yarn". This would allow for simultaneous application of YaRN context extension and multimodal positional embeddings.

This could be achieved in either of two ways:

  1. Modifying YaRNScalingRotaryEmbedding to incorporate the logic from MRotaryEmbedding.
  2. Creating a new rotary embedding class that combines the features of both.
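
A rough sketch of why the two features should compose, assuming that YaRN only changes the per-slot inverse frequencies while MRoPE only decides which position axis (temporal, height, width) drives each frequency slot. All helper names here are hypothetical and are not sglang APIs. The YaRN blend follows the commonly used linear-ramp formulation with beta_fast=32 and beta_slow=1 assumed; base=10000 and rotary_dim=128 (chosen so that rotary_dim equals 2 * sum(mrope_section)) are also assumptions, and YaRN's attention magnitude scaling is left out.

```python
import math


def yarn_inv_freq(rotary_dim, base, factor, original_max_position,
                  beta_fast=32, beta_slow=1):
    """Blend interpolated and extrapolated inverse frequencies, YaRN-style."""

    def correction_dim(num_rotations):
        # Dimension index at which a full context spans `num_rotations` turns.
        return (rotary_dim
                * math.log(original_max_position / (num_rotations * 2 * math.pi))
                ) / (2 * math.log(base))

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), rotary_dim - 1)
    if low == high:
        high += 0.001  # avoid division by zero in the ramp

    inv_freq = []
    for i in range(rotary_dim // 2):
        extrapolated = 1.0 / (base ** (2 * i / rotary_dim))  # unscaled RoPE
        interpolated = extrapolated / factor                 # position interpolation
        ramp = min(max((i - low) / (high - low), 0.0), 1.0)
        keep = 1.0 - ramp  # 1.0 keeps the original frequency, 0.0 fully interpolates
        inv_freq.append(interpolated * (1.0 - keep) + extrapolated * keep)
    return inv_freq


def mrope_axis_per_slot(mrope_section, interleaved):
    """Return, per frequency slot, the axis it reads: 0 = t, 1 = h, 2 = w."""
    if not interleaved:
        # Chunked layout: all t slots, then all h slots, then all w slots.
        return [axis for axis, n in enumerate(mrope_section) for _ in range(n)]
    # Interleaved layout (assumed): h and w take every third slot, t keeps the rest.
    axes = [0] * sum(mrope_section)
    for axis in (1, 2):
        for i in range(axis, mrope_section[axis] * 3, 3):
            axes[i] = axis
    return axes


def rotation_angles(position_thw, inv_freq, axes):
    """Angle per slot: the position along the slot's axis times its frequency."""
    return [position_thw[axis] * freq for axis, freq in zip(axes, inv_freq)]


# Values taken from the rope_scaling override in this issue; base is assumed.
inv_freq = yarn_inv_freq(rotary_dim=128, base=10000.0, factor=3.0,
                         original_max_position=262144)
axes = mrope_axis_per_slot([24, 20, 20], interleaved=True)
angles = rotation_angles((7, 2, 5), inv_freq, axes)
```

In this sketch the YaRN step and the MRoPE step never touch the same data, which is the intuition behind both options: either one can reuse the other's logic without changing it. For text-only tokens, where t, h, and w are equal, the angles reduce to those of ordinary YaRN.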

additional context

I encountered this issue when trying to launch a server for a Qwen multimodal model. The goal was to apply both YaRN scaling for a large context window and the specific MRoPE configuration for the model.

Here is the command I attempted to use:

python -m sglang.launch_server \
  --model-path Qwen/Qwen3-VL-235B-A22B-Instruct \
  --tp 8 \
  --mm-attention-backend triton_attn \
  --json-model-override-args '{"rope_scaling":{"rope_type":"yarn","factor":3.0,"original_max_position_embeddings":262144,"mrope_section":[24,20,20],"mrope_interleaved":true}}' \
  --context-length 1000000

The command runs, but the mrope_section and mrope_interleaved settings are not applied, so the server uses plain YaRN rotary embeddings without the MRoPE configuration the model needs. Supporting this combination would make long-context serving of multimodal models possible.

The relevant code snippets from sglang/srt/layers/rotary_embedding.py illustrate the current behavior:

Current handling for rope_type: "yarn":

elif scaling_type == "yarn":
    scaling_factor = rope_scaling["factor"]
    original_max_position = rope_scaling["original_max_position_embeddings"]
    extra_kwargs = {
        k: v
        for k, v in rope_scaling.items()
        if k
        in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
    }
    rotary_emb = YaRNScalingRotaryEmbedding(
        head_size,
        rotary_dim,
        original_max_position,
        base,
        is_neox_style,
        scaling_factor,
        dtype,
        **extra_kwargs,
    )

As shown, mrope_section and mrope_interleaved are not passed to YaRNScalingRotaryEmbedding.
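
To make the drop concrete, here is a small standalone reproduction of the key filter from the yarn branch, applied to the rope_scaling override from the command in this issue. The `consumed` set is my own bookkeeping for illustration (keys the yarn branch reads directly or forwards) and is not part of sglang.

```python
# rope_scaling override from the launch command in this issue.
rope_scaling = {
    "rope_type": "yarn",
    "factor": 3.0,
    "original_max_position_embeddings": 262144,
    "mrope_section": [24, 20, 20],
    "mrope_interleaved": True,
}

# Same allow-list the yarn branch uses when building extra_kwargs.
YARN_EXTRA_KEYS = ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
extra_kwargs = {k: v for k, v in rope_scaling.items() if k in YARN_EXTRA_KEYS}

# Keys the yarn branch reads directly or forwards (illustrative bookkeeping).
consumed = {"rope_type", "factor", "original_max_position_embeddings", *YARN_EXTRA_KEYS}
ignored = sorted(k for k in rope_scaling if k not in consumed)

print(extra_kwargs)  # {}
print(ignored)       # ['mrope_interleaved', 'mrope_section']
```

Nothing in the yarn branch ever looks at the two MRoPE keys, so they never reach the embedding class.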

Current handling for rope_type: "default" with mrope_section:

elif scaling_type == "default":
    if "mrope_section" in rope_scaling:
        rotary_emb = MRotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            dtype,
            mrope_section=rope_scaling["mrope_section"],
            mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
        )
    else:
        rotary_emb = RotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            dtype,
        )

This block demonstrates that mrope_section and mrope_interleaved are specifically handled when rope_type is "default" by instantiating MRotaryEmbedding.
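
One possible shape for the fix, sketched as a standalone function that returns class names instead of building embeddings. It mirrors the two branches quoted above and adds the same mrope_section check to the yarn path. `select_rotary_embedding` and `YaRNScalingMRotaryEmbedding` are hypothetical names used only for illustration; neither exists in sglang as far as I know.

```python
def select_rotary_embedding(rope_scaling):
    """Return the name of the class a dispatcher could pick (illustration only)."""
    scaling_type = rope_scaling.get("rope_type", "default")
    has_mrope = "mrope_section" in rope_scaling

    if scaling_type == "default":
        # Existing behavior: MRoPE is honored on the default path.
        return "MRotaryEmbedding" if has_mrope else "RotaryEmbedding"

    if scaling_type == "yarn":
        # Proposed: check for mrope_section here as well.
        # "YaRNScalingMRotaryEmbedding" is a hypothetical combined class.
        if has_mrope:
            return "YaRNScalingMRotaryEmbedding"
        return "YaRNScalingRotaryEmbedding"

    raise ValueError(f"Unknown RoPE scaling type {scaling_type}")


choice = select_rotary_embedding({
    "rope_type": "yarn",
    "factor": 3.0,
    "original_max_position_embeddings": 262144,
    "mrope_section": [24, 20, 20],
    "mrope_interleaved": True,
})
print(choice)  # YaRNScalingMRotaryEmbedding
```

Whether the combined behavior lives in a new class or inside YaRNScalingRotaryEmbedding, the dispatch only needs this one extra check.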
