[vLLM IR][RMSNorm] Port RMSNormGated to vLLM IR Ops#38798
Conversation
Signed-off-by: Icey <1790571317@qq.com>
There was a problem hiding this comment.
Code Review
This pull request introduces the rms_norm_gated operation to the vLLM IR, including both native and Triton implementations, and updates the configuration to prioritize these implementations. The review identified that the native implementation of rms_norm_gated fails to utilize the bias parameter and ignores the activation parameter, and also noted that the bias parameter requires a more accurate type hint in both the native and Triton implementations.
| return torch.empty_like(x) | ||
|
|
||
|
|
||
| direct_register_custom_op( |
There was a problem hiding this comment.
I tried using wrap_triton instead of the custom op, but I'm having some issues.
wrap_triton is a torch.compile adapter layer, which requires all tensor parameters to be real Tensor objects so that Dynamo can correctly track their shape, stride, and dtype. However, the bias, z, and mean in layer_norm_fwd_kernel might be None in some cases.
When I tried assigning empty tensors to these parameters instead of None, I found that the output was always !
There was a problem hiding this comment.
cc @ProExpertProg Do you know the root cause of this problem?
There was a problem hiding this comment.
Can you create a smaller repro so we can ask the pytorch team? Would definitely be good to avoid wrapping into another torch op if possible!
There was a problem hiding this comment.
example:
"""Minimal layernorm-style repro for wrap_triton.
Compares three paths:
1. direct Triton launch
2. eager wrap_triton launch
3. make_fx-traced wrap_triton launch
"""
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.library import wrap_triton
import triton
import triton.language as tl
@triton.jit
def kernel(
X,
Y,
W,
B,
Z,
Mean,
Rstd,
stride_x_row,
stride_y_row,
stride_z_row,
M,
N: tl.constexpr,
eps,
BLOCK_N: tl.constexpr,
ROWS_PER_BLOCK: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_Z: tl.constexpr,
STORE_MEAN: tl.constexpr,
NORM_BEFORE_GATE: tl.constexpr,
):
row_start = tl.program_id(0) * ROWS_PER_BLOCK
rows = row_start + tl.arange(0, ROWS_PER_BLOCK)
cols = tl.arange(0, BLOCK_N)
row_mask = rows[:, None] < M
col_mask = cols[None, :] < N
mask = row_mask & col_mask
x = tl.load(
X + rows[:, None] * stride_x_row + cols[None, :],
mask=mask,
other=0.0,
).to(tl.float32)
if HAS_Z and not NORM_BEFORE_GATE:
z = tl.load(
Z + rows[:, None] * stride_z_row + cols[None, :],
mask=mask,
other=0.0,
).to(tl.float32)
x = x * (z * tl.sigmoid(z))
mean = tl.sum(x, axis=1) / N
centered = tl.where(mask, x - mean[:, None], 0.0)
var = tl.sum(centered * centered, axis=1) / N
rstd = tl.rsqrt(var + eps)
if STORE_MEAN:
tl.store(Mean + rows, mean, mask=rows < M)
tl.store(Rstd + rows, rstd, mask=rows < M)
w = tl.load(W + cols, mask=cols < N, other=0.0).to(tl.float32)
y = centered * rstd[:, None] * w[None, :]
if HAS_BIAS:
b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32)
y = y + b[None, :]
if HAS_Z and NORM_BEFORE_GATE:
z = tl.load(
Z + rows[:, None] * stride_z_row + cols[None, :],
mask=mask,
other=0.0,
).to(tl.float32)
y = y * (z * tl.sigmoid(z))
tl.store(Y + rows[:, None] * stride_y_row + cols[None, :], y, mask=mask)
def run_kernel(
x: torch.Tensor,
w: torch.Tensor,
bias: torch.Tensor | None,
z: torch.Tensor | None,
*,
use_wrap: bool,
rows_per_block: int = 2,
eps: float = 1e-5,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
m, n = x.shape
out = torch.empty_like(x)
mean = torch.empty((m,), dtype=torch.float32, device=x.device)
rstd = torch.empty((m,), dtype=torch.float32, device=x.device)
has_bias = bias is not None
has_z = z is not None
if bias is None:
bias = torch.empty((n,), dtype=w.dtype, device=x.device)
if z is None:
z = torch.empty_like(x)
launcher = wrap_triton(kernel) if use_wrap else kernel
launcher[(triton.cdiv(m, rows_per_block),)](
x,
out,
w,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0),
m,
n,
eps,
BLOCK_N=triton.next_power_of_2(n),
ROWS_PER_BLOCK=rows_per_block,
HAS_BIAS=has_bias,
HAS_Z=has_z,
STORE_MEAN=True,
NORM_BEFORE_GATE=True,
)
return out, mean, rstd
def wrapped_fn(
x: torch.Tensor,
w: torch.Tensor,
bias: torch.Tensor | None,
z: torch.Tensor | None,
eps: float,
rows_per_block: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return run_kernel(
x,
w,
bias,
z,
use_wrap=True,
rows_per_block=rows_per_block,
eps=eps,
)
def run_case(has_bias: bool, has_z: bool) -> None:
torch.manual_seed(0)
x = torch.randn((8, 128), device="cuda", dtype=torch.float16)
w = torch.randn((128,), device="cuda", dtype=torch.float16)
bias = torch.randn((128,), device="cuda", dtype=torch.float16) if has_bias else None
z = torch.randn((8, 128), device="cuda", dtype=torch.float16) if has_z else None
eps = 1e-5
rows_per_block = 2
direct_out, direct_mean, direct_rstd = run_kernel(
x, w, bias, z, use_wrap=False, rows_per_block=rows_per_block, eps=eps
)
eager_out, eager_mean, eager_rstd = run_kernel(
x, w, bias, z, use_wrap=True, rows_per_block=rows_per_block, eps=eps
)
traced = make_fx(wrapped_fn)(x, w, bias, z, eps, rows_per_block)
traced_out, traced_mean, traced_rstd = traced(
x, w, bias, z, eps, rows_per_block
)
print(f"case has_bias={has_bias} has_z={has_z}")
print("max|direct - eager out| =", (direct_out - eager_out).abs().max().item())
print("max|direct - traced out| =", (direct_out - traced_out).abs().max().item())
print("max|direct - eager mean| =", (direct_mean - eager_mean).abs().max().item())
print(
"max|direct - traced mean| =",
(direct_mean - traced_mean).abs().max().item(),
)
print("max|direct - eager rstd| =", (direct_rstd - eager_rstd).abs().max().item())
print(
"max|direct - traced rstd| =",
(direct_rstd - traced_rstd).abs().max().item(),
)
print()
def main() -> None:
if not torch.cuda.is_available():
raise RuntimeError("CUDA is required for this repro.")
print("torch:", torch.__version__)
print("triton:", triton.__version__)
print("device:", torch.cuda.get_device_name(0))
print()
run_case(has_bias=False, has_z=False)
run_case(has_bias=True, has_z=True)
if __name__ == "__main__":
main()
outputs:
(vllm-env) root@coder-a100-01:/shared/wxs/vllm# python ./examples/repro_wrap_triton_layernorm_flags.py
torch: 2.10.0+cu128
triton: 3.6.0
device: NVIDIA A100-SXM4-80GB
case has_bias=False has_z=False
max|direct - eager out| = 0.0
max|direct - traced out| = 0.0
max|direct - eager mean| = 0.0
max|direct - traced mean| = 0.0
max|direct - eager rstd| = 0.0
max|direct - traced rstd| = 0.0
case has_bias=True has_z=True
max|direct - eager out| = 0.0
max|direct - traced out| = 0.0
max|direct - eager mean| = 0.0
max|direct - traced mean| = 0.0
max|direct - eager rstd| = 0.0
max|direct - traced rstd| = 0.0
Using the reduced layernorm-style Triton kernel, I could verify that direct Triton launch, eager wrap_triton, and make_fx wrap_triton all produce identical outputs.
Therefore, the problem cannot be attributed to wrap_triton or replacing optional tensors with real tensors as placeholders. I have already modified the pr to use wrap_trion, but I still don't understand why the output is !.
Signed-off-by: Icey <1790571317@qq.com>
|
I think we can reduce the complexity here by porting rms_norm_gated after we allow calling IR ops from other IR ops (should be done soon). That way we can do something like this: |
Signed-off-by: Icey <1790571317@qq.com>
Thanks. Good idea. |
|
This pull request has merge conflicts that must be resolved before it can be |
| if z is None: | ||
| z = torch.empty_like(x) | ||
|
|
||
| wrap_triton(layer_norm_fwd_kernel)[grid]( |
There was a problem hiding this comment.
I think it's cleaner to apply this to the kernel directly as a decorator, like this:
Signed-off-by: chaojun-zhang <chaojun.zhang@intel.com>
|
@wxsIcey I added XPU support for this op . wxsIcey#15, please help review, thanks |
Thanks for providing it, merged. |
ProExpertProg
left a comment
There was a problem hiding this comment.
not sure if you're still working on this
| ROWS_PER_BLOCK: tl.constexpr, | ||
| HAS_BIAS: tl.constexpr, | ||
| HAS_Z: tl.constexpr, | ||
| STORE_MEAN: tl.constexpr, |
There was a problem hiding this comment.
I assume this is an outdated diff?
| rms_norm = ["oink"] + default | ||
|
|
||
| return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm) | ||
| rms_norm_gated = ["triton", "native"] |
There was a problem hiding this comment.
Should be native first if we're compiling, no?
| norm_before_gate: bool = False, | ||
| activation: str = "swish", | ||
| ) -> Tensor: | ||
| return _rms_norm_gated_triton_impl( |
There was a problem hiding this comment.
Why the extra function? Can we just inline?
| args = (x, weight, bias, z, epsilon, group_size, norm_before_gate, activation) | ||
|
|
||
| out = rms_norm_gated_native(*args) | ||
| ref = rms_norm_gated_ref(*args) |
There was a problem hiding this comment.
No real point in having a copy of the native implementation. It makes more sense to check expected properties of the output (see rms_norm and fused_add_rms_norm), and compare to the non-gated rms-norm, if anything
| return x, weight | ||
|
|
||
|
|
||
| def rms_norm_gated_inputs( |
There was a problem hiding this comment.
Use input generator?
Purpose
[vLLM IR][RMSNorm] Port RMSNormGated to vLLM IR Ops
Test Plan
Qwen3-Next functional testing has been conducted on the A100.
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.