Skip to content

[vLLM IR][RMSNorm] Port RMSNormGated to vLLM IR Ops#38798

Open
wxsIcey wants to merge 5 commits into
vllm-project:mainfrom
wxsIcey:wxs/vllm-ir-rms-norm-gated
Open

[vLLM IR][RMSNorm] Port RMSNormGated to vLLM IR Ops#38798
wxsIcey wants to merge 5 commits into
vllm-project:mainfrom
wxsIcey:wxs/vllm-ir-rms-norm-gated

Conversation

@wxsIcey

@wxsIcey wxsIcey commented Apr 2, 2026

Copy link
Copy Markdown
Contributor

Purpose

[vLLM IR][RMSNorm] Port RMSNormGated to vLLM IR Ops

Test Plan

Qwen3-Next functional testing has been conducted on the A100.

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Icey <1790571317@qq.com>
@mergify mergify Bot added nvidia rocm Related to AMD ROCm labels Apr 2, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 2, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the rms_norm_gated operation to the vLLM IR, including both native and Triton implementations, and updates the configuration to prioritize these implementations. The review identified that the native implementation of rms_norm_gated fails to utilize the bias parameter and ignores the activation parameter, and also noted that the bias parameter requires a more accurate type hint in both the native and Triton implementations.

Comment thread vllm/ir/ops/layernorm.py
Comment thread vllm/kernels/triton/ops/layernorm.py Outdated
Comment thread vllm/kernels/triton/ops/layernorm.py Outdated
return torch.empty_like(x)


direct_register_custom_op(

@wxsIcey wxsIcey Apr 2, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried using wrap_triton instead of the custom op, but I'm having some issues.

wrap_triton is a torch.compile adapter layer, which requires all tensor parameters to be real Tensor objects so that Dynamo can correctly track their shape, stride, and dtype. However, the bias, z, and mean in layer_norm_fwd_kernel might be None in some cases.

When I tried assigning empty tensors to these parameters instead of None, I found that the output was always !

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @ProExpertProg Do you know the root cause of this problem?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you create a smaller repro so we can ask the pytorch team? Would definitely be good to avoid wrapping into another torch op if possible!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

example:

"""Minimal layernorm-style repro for wrap_triton.

Compares three paths:
1. direct Triton launch
2. eager wrap_triton launch
3. make_fx-traced wrap_triton launch
"""

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.library import wrap_triton

import triton
import triton.language as tl


@triton.jit
def kernel(
    X,
    Y,
    W,
    B,
    Z,
    Mean,
    Rstd,
    stride_x_row,
    stride_y_row,
    stride_z_row,
    M,
    N: tl.constexpr,
    eps,
    BLOCK_N: tl.constexpr,
    ROWS_PER_BLOCK: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    STORE_MEAN: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
):
    row_start = tl.program_id(0) * ROWS_PER_BLOCK
    rows = row_start + tl.arange(0, ROWS_PER_BLOCK)
    cols = tl.arange(0, BLOCK_N)

    row_mask = rows[:, None] < M
    col_mask = cols[None, :] < N
    mask = row_mask & col_mask

    x = tl.load(
        X + rows[:, None] * stride_x_row + cols[None, :],
        mask=mask,
        other=0.0,
    ).to(tl.float32)

    if HAS_Z and not NORM_BEFORE_GATE:
        z = tl.load(
            Z + rows[:, None] * stride_z_row + cols[None, :],
            mask=mask,
            other=0.0,
        ).to(tl.float32)
        x = x * (z * tl.sigmoid(z))

    mean = tl.sum(x, axis=1) / N
    centered = tl.where(mask, x - mean[:, None], 0.0)
    var = tl.sum(centered * centered, axis=1) / N
    rstd = tl.rsqrt(var + eps)

    if STORE_MEAN:
        tl.store(Mean + rows, mean, mask=rows < M)
    tl.store(Rstd + rows, rstd, mask=rows < M)

    w = tl.load(W + cols, mask=cols < N, other=0.0).to(tl.float32)
    y = centered * rstd[:, None] * w[None, :]

    if HAS_BIAS:
        b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32)
        y = y + b[None, :]

    if HAS_Z and NORM_BEFORE_GATE:
        z = tl.load(
            Z + rows[:, None] * stride_z_row + cols[None, :],
            mask=mask,
            other=0.0,
        ).to(tl.float32)
        y = y * (z * tl.sigmoid(z))

    tl.store(Y + rows[:, None] * stride_y_row + cols[None, :], y, mask=mask)


def run_kernel(
    x: torch.Tensor,
    w: torch.Tensor,
    bias: torch.Tensor | None,
    z: torch.Tensor | None,
    *,
    use_wrap: bool,
    rows_per_block: int = 2,
    eps: float = 1e-5,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    m, n = x.shape
    out = torch.empty_like(x)
    mean = torch.empty((m,), dtype=torch.float32, device=x.device)
    rstd = torch.empty((m,), dtype=torch.float32, device=x.device)
    has_bias = bias is not None
    has_z = z is not None

    if bias is None:
        bias = torch.empty((n,), dtype=w.dtype, device=x.device)
    if z is None:
        z = torch.empty_like(x)

    launcher = wrap_triton(kernel) if use_wrap else kernel
    launcher[(triton.cdiv(m, rows_per_block),)](
        x,
        out,
        w,
        bias,
        z,
        mean,
        rstd,
        x.stride(0),
        out.stride(0),
        z.stride(0),
        m,
        n,
        eps,
        BLOCK_N=triton.next_power_of_2(n),
        ROWS_PER_BLOCK=rows_per_block,
        HAS_BIAS=has_bias,
        HAS_Z=has_z,
        STORE_MEAN=True,
        NORM_BEFORE_GATE=True,
    )
    return out, mean, rstd


def wrapped_fn(
    x: torch.Tensor,
    w: torch.Tensor,
    bias: torch.Tensor | None,
    z: torch.Tensor | None,
    eps: float,
    rows_per_block: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    return run_kernel(
        x,
        w,
        bias,
        z,
        use_wrap=True,
        rows_per_block=rows_per_block,
        eps=eps,
    )


def run_case(has_bias: bool, has_z: bool) -> None:
    torch.manual_seed(0)
    x = torch.randn((8, 128), device="cuda", dtype=torch.float16)
    w = torch.randn((128,), device="cuda", dtype=torch.float16)
    bias = torch.randn((128,), device="cuda", dtype=torch.float16) if has_bias else None
    z = torch.randn((8, 128), device="cuda", dtype=torch.float16) if has_z else None
    eps = 1e-5
    rows_per_block = 2

    direct_out, direct_mean, direct_rstd = run_kernel(
        x, w, bias, z, use_wrap=False, rows_per_block=rows_per_block, eps=eps
    )
    eager_out, eager_mean, eager_rstd = run_kernel(
        x, w, bias, z, use_wrap=True, rows_per_block=rows_per_block, eps=eps
    )
    traced = make_fx(wrapped_fn)(x, w, bias, z, eps, rows_per_block)
    traced_out, traced_mean, traced_rstd = traced(
        x, w, bias, z, eps, rows_per_block
    )

    print(f"case has_bias={has_bias} has_z={has_z}")
    print("max|direct - eager out| =", (direct_out - eager_out).abs().max().item())
    print("max|direct - traced out| =", (direct_out - traced_out).abs().max().item())
    print("max|direct - eager mean| =", (direct_mean - eager_mean).abs().max().item())
    print(
        "max|direct - traced mean| =",
        (direct_mean - traced_mean).abs().max().item(),
    )
    print("max|direct - eager rstd| =", (direct_rstd - eager_rstd).abs().max().item())
    print(
        "max|direct - traced rstd| =",
        (direct_rstd - traced_rstd).abs().max().item(),
    )
    print()


def main() -> None:
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this repro.")

    print("torch:", torch.__version__)
    print("triton:", triton.__version__)
    print("device:", torch.cuda.get_device_name(0))
    print()

    run_case(has_bias=False, has_z=False)
    run_case(has_bias=True, has_z=True)


if __name__ == "__main__":
    main()

outputs:

(vllm-env) root@coder-a100-01:/shared/wxs/vllm# python ./examples/repro_wrap_triton_layernorm_flags.py
torch: 2.10.0+cu128
triton: 3.6.0
device: NVIDIA A100-SXM4-80GB

case has_bias=False has_z=False
max|direct - eager out| = 0.0
max|direct - traced out| = 0.0
max|direct - eager mean| = 0.0
max|direct - traced mean| = 0.0
max|direct - eager rstd| = 0.0
max|direct - traced rstd| = 0.0

case has_bias=True has_z=True
max|direct - eager out| = 0.0
max|direct - traced out| = 0.0
max|direct - eager mean| = 0.0
max|direct - traced mean| = 0.0
max|direct - eager rstd| = 0.0
max|direct - traced rstd| = 0.0

Using the reduced layernorm-style Triton kernel, I could verify that direct Triton launch, eager wrap_triton, and make_fx wrap_triton all produce identical outputs.

Therefore, the problem cannot be attributed to wrap_triton or replacing optional tensors with real tensors as placeholders. I have already modified the pr to use wrap_trion, but I still don't understand why the output is !.

Signed-off-by: Icey <1790571317@qq.com>
@ProExpertProg

ProExpertProg commented Apr 3, 2026

Copy link
Copy Markdown
Collaborator

I think we can reduce the complexity here by porting rms_norm_gated after we allow calling IR ops from other IR ops (should be done soon). That way we can do something like this:

@ir.register_op
def rms_norm_silu_mul(x, weight, epsilon)
    x = ir.ops.rms_norm(x, weight, epsilon)
    return ir.ops.silu_mul(x)

@ProExpertProg ProExpertProg added the vllm-ir vLLM IR: intermediate representation and kernel registration label Apr 3, 2026
@github-project-automation github-project-automation Bot moved this to Todo in vLLM IR Apr 3, 2026
@ProExpertProg ProExpertProg moved this from Todo to In Progress in vLLM IR Apr 3, 2026
Signed-off-by: Icey <1790571317@qq.com>
@wxsIcey

wxsIcey commented Apr 7, 2026

Copy link
Copy Markdown
Contributor Author

I think we can reduce the complexity here by porting rms_norm_gated after we allow calling IR ops from other IR ops (should be done soon). That way we can do something like this:

@ir.register_op
def rms_norm_silu_mul(x, weight, epsilon)
    x = ir.ops.rms_norm(x, weight, epsilon)
    return ir.ops.silu_mul(x)

Thanks. Good idea.

@mergify

mergify Bot commented Apr 7, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wxsIcey.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 7, 2026
if z is None:
z = torch.empty_like(x)

wrap_triton(layer_norm_fwd_kernel)[grid](

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's cleaner to apply this to the kernel directly as a decorator, like this:

@torch.library.wrap_triton # needed for make_fx lowering to work

@wxsIcey wxsIcey changed the title [vLLM IR] rms_norm_gated [vLLM IR] Port RMSNormGated to vLLM IR Ops Apr 8, 2026
@wxsIcey wxsIcey changed the title [vLLM IR] Port RMSNormGated to vLLM IR Ops [vLLM IR][RMSNorm] Port RMSNormGated to vLLM IR Ops Apr 8, 2026
Signed-off-by: chaojun-zhang <chaojun.zhang@intel.com>
@chaojun-zhang

Copy link
Copy Markdown
Contributor

@wxsIcey I added XPU support for this op . wxsIcey#15, please help review, thanks

@wxsIcey

wxsIcey commented Apr 24, 2026

Copy link
Copy Markdown
Contributor Author

@wxsIcey I added XPU support for this op . wxsIcey#15, please help review, thanks

Thanks for providing it, merged.

@mergify mergify Bot added the intel-gpu Related to Intel GPU label Apr 24, 2026

@ProExpertProg ProExpertProg left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if you're still working on this

ROWS_PER_BLOCK: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_Z: tl.constexpr,
STORE_MEAN: tl.constexpr,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this is an outdated diff?

Comment thread vllm/platforms/cuda.py
rms_norm = ["oink"] + default

return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)
rms_norm_gated = ["triton", "native"]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be native first if we're compiling, no?

norm_before_gate: bool = False,
activation: str = "swish",
) -> Tensor:
return _rms_norm_gated_triton_impl(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the extra function? Can we just inline?

args = (x, weight, bias, z, epsilon, group_size, norm_before_gate, activation)

out = rms_norm_gated_native(*args)
ref = rms_norm_gated_ref(*args)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No real point in having a copy of the native implementation. It makes more sense to check expected properties of the output (see rms_norm and fused_add_rms_norm), and compare to the non-gated rms-norm, if anything

return x, weight


def rms_norm_gated_inputs(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use input generator?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

intel-gpu Related to Intel GPU needs-rebase nvidia rocm Related to AMD ROCm vllm-ir vLLM IR: intermediate representation and kernel registration

Projects

Status: Todo
Status: No status
Status: In Progress

Development

Successfully merging this pull request may close these issues.

3 participants