
Stop decomposing native_layer_norm to fix bf16 precision divergence#177163

Draft
aorenste wants to merge 1 commit into main from ptq/168126

Conversation

@aorenste
Contributor

@aorenste aorenste commented Mar 11, 2026

Human Note

The decomposition of aten.native_layer_norm computes mean/variance using torch.var_mean, which uses a different reduction algorithm than the C++ kernel's RowwiseMoments (Welford + cascade summation). This creates fp32 differences of ~7e-7 that cause bf16 rounding divergences at ~0.001% of elements, amplified by subsequent matmuls to errors of 0.018+ in realistic models.

The fix removes the native_layer_norm decomposition from the inductor decomposition table and makes the make_fallback unconditional (not MTIA-only). This ensures the C++ kernel is used directly, matching eager behavior exactly.

The decomposition formula (x - mean) * rstd * weight + bias is bit-exact with the C++ kernel when given the same statistics. The entire error comes from the statistics computation. A follow-up PR could add a direct inductor lowering using var_mean_welford_ to re-enable Triton kernel generation while maintaining numerical accuracy.

Agent Report

Issue #168126: Large numeric divergence for torch compile vs eager in bf16

Summary

The decomposition of aten.native_layer_norm used by torch.compile computes
mean/variance statistics using a different algorithm (torch.var_mean) than the
C++ eager kernel (RowwiseMoments — Welford's algorithm with cascade summation).
This creates small fp32 differences (up to 7e-7) that cause bf16 rounding divergences,
which are then amplified by subsequent operations (linear layers, matmuls).
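The statistics mismatch can be observed directly, because native_layer_norm returns the mean/rstd it computed alongside the output. The following probe is illustrative (tensor sizes and seed are arbitrary; the exact deltas vary by platform):

```python
import torch

torch.manual_seed(0)
x = torch.randn(4096, 128)
w = torch.ones(128)
b = torch.zeros(128)
eps = 1e-5

# Eager path: the C++ kernel (RowwiseMoments) returns its own mean/rstd.
out, mean_k, rstd_k = torch.ops.aten.native_layer_norm(x, [128], w, b, eps)

# Decomposition path: statistics come from torch.var_mean, which uses a
# different reduction order.
var_d, mean_d = torch.var_mean(x, dim=-1, unbiased=False)

# Mathematically identical, but the fp32 results can differ in the last bits.
print((mean_k.flatten() - mean_d.flatten()).abs().max())
```

Any nonzero difference printed here is pure accumulation-order rounding, on the order of 1e-8 per the measurements below.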

Root Cause Analysis

The error chain

  1. Statistics algorithm mismatch — The C++ native_layer_norm kernel computes
    mean/variance using RowwiseMoments (Welford + cascade summation in
    aten/src/ATen/native/cpu/moments_utils.h). The decomposition in
    torch/_refs/__init__.py::_normalize uses torch.var_mean, which uses a
    different reduction algorithm. Mean differs by up to 4.66e-08 across 76% of
    rows; variance by up to 5.5e-07.

  2. Exact formula, wrong stats — When given the same mean/rstd, the
    decomposition's formula (x - mean) * rstd * weight + bias is bit-exact
    with the C++ kernel (diff = 0.0). The entire error comes from the statistics
    computation.

  3. bf16 rounding divergence — The ~7e-7 fp32 error in the norm output causes
    ~0.001% of elements to round differently when cast to bf16. Each mismatched
    element differs by ~1 bf16 ULP (~0.004 near unit-scale values).

  4. Matmul amplification — A bf16 linear layer sums 128 products. A 1-ULP
    input difference in ~7 elements yields max output error of ~0.125. Through
    multiple layers, the error grows further.
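The amplification step can be reproduced in isolation. The sketch below (hypothetical values, not taken from the PR's instrumentation) perturbs a handful of bf16 elements by roughly one ULP and pushes both copies through a single matmul:

```python
import torch

torch.manual_seed(0)
a = torch.randn(128).to(torch.bfloat16)
b = a.clone()

# Bump ~7 elements by roughly one bf16 ULP, mimicking the rounding divergence.
idx = torch.tensor([3, 17, 40, 77, 90, 101, 120])
b[idx] = (b[idx].float() * (1 + torch.finfo(torch.bfloat16).eps)).to(torch.bfloat16)

# A single fp32 matmul turns the tiny per-element differences into a much
# larger output difference, since each output element sums 128 products.
w = torch.randn(128, 128)
diff = (a.float() @ w - b.float() @ w).abs().max()
print(f"max output diff after one matmul: {diff.item():.4f}")
```

Stacking more linear layers compounds this, which is why realistic models see errors far above bf16 tolerances.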

Key evidence (all from runtime, not speculative)

| Test | Result | Significance |
|---|---|---|
| aot_eager vs eager | PASS | No decomposition → no divergence |
| aot_eager_decomp_partition vs eager | FAIL (2/524288) | Decomposition causes divergence |
| aot_eager_decomp_partition WITHOUT layer_norm decomp vs eager | PASS | Removing decomposition fixes it |
| decomp formula with eager stats vs C++ kernel | diff = 0.0 | Formula is bit-exact |
| decomp formula with var_mean stats vs C++ kernel | diff = 7.15e-7 | All error from stats |
| fp32 end-to-end | PASS | Error within fp32 tolerance |
| bf16 end-to-end (with or without autocast) | FAIL | bf16 rounding amplifies fp32 error |

CUDA has the same root cause

The CUDA native_layer_norm kernel uses a custom Welford implementation with
vectorized loads and warp shuffle reduction, while torch.var_mean on CUDA uses
a different Welford through the generic gpu_reduce_kernel. Same algorithm
mismatch, same fix applies.

What the Fix Does

Stop decomposing native_layer_norm in the inductor.

Two changes:

  1. torch/_inductor/decomposition.py: _native_layer_norm now always returns
    NotImplemented instead of delegating to the decomposition. Previously only
    MTIA was excluded.
  2. torch/_inductor/lowering.py: make_fallback(aten.native_layer_norm) is now
    unconditional (previously gated behind torch.mtia._is_compiled()).

The C++ kernel is already well-optimized (vectorized Welford, SIMD on CPU, warp
shuffle reduction on CUDA) and numerically stable. Removing the decomposition
preserves exact eager-compile agreement for LayerNorm.

Regression test added

test_layer_norm_bf16_numerics in test/inductor/test_torchinductor.py — verifies
that LayerNorm + Linear under bf16 autocast produces identical results between
eager and compiled modes.

Test updated

The existing GPU test for buffer reuse with LayerNorm was updated. Since
native_layer_norm is no longer decomposed, there is no fused Triton kernel to
check in_out_ptr against. The test now verifies correctness only.

Test Results

Before fix: test fails

EXPECTED FAIL (before fix): Tensor-likes are not close!

Mismatched elements: 1 / 524288 (0.0%)
Greatest absolute difference: 0.00030517578125 at index (2, 15, 24, 117) (up to 1e-05 allowed)
Greatest relative difference: 0.032958984375 at index (2, 15, 24, 117) (up to 0.016 allowed)

After fix: test passes

PASS: compiled output matches eager output

In-tree test passes

test_layer_norm_bf16_numerics_cpu ... ok
Ran 1 test in 3.892s
OK
Repro Script (minimal, CPU)
import torch
from torch import nn

norm = nn.LayerNorm(128, eps=1e-5)
linear = nn.Linear(128, 128, bias=False)

def f(x):
    return linear(norm(x))

compiled_f = torch.compile(f, fullgraph=True)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    torch.manual_seed(42)
    x = torch.randn(4, 32, 32, 128)
    out_eager = f(x)
    out_compiled = compiled_f(x)
    torch.testing.assert_close(out_eager, out_compiled)

Before fix output:

AssertionError: Tensor-likes are not close!
Mismatched elements: 1 / 524288 (0.0%)
Greatest absolute difference: 0.00030517578125
Greatest relative difference: 0.032958984375

After fix output: (no error)

Remaining Gaps

Inductor fusion precision (separate issue)

The original full repro (TriangleMultiplicationOutgoing from the Boltz model) still
fails with the inductor backend even after this fix. Investigation revealed an
additional, separate source of divergence: inductor's kernel fusion computes
bf16 pointwise operations (sigmoid, multiply) in fp32 without intermediate bf16
truncation.

For example, p_in(x) * g_in(x).sigmoid() where both linears return bf16:

  • Eager: sigmoid produces bf16 output → multiply loads two bf16 values
  • Inductor: sigmoid result stays fp32 → multiply in fp32 → truncate to bf16

This skips one bf16 truncation, causing ~26% of elements to differ by 1 ULP.
The TORCHINDUCTOR_EMULATE_PRECISION_CASTS=1 flag does not fix this case.
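The extra-rounding effect can be mimicked outside inductor. The sketch below is illustrative only (it is not inductor's actual codegen): it compares a product where the sigmoid is truncated to bf16 first against one where the sigmoid stays fp32 until the final cast:

```python
import torch

torch.manual_seed(0)
x = torch.randn(1 << 16)
p = torch.randn(1 << 16).to(torch.bfloat16)

# Eager-style: the sigmoid output is materialized in bf16 before the multiply.
eager = (p.float() * x.sigmoid().to(torch.bfloat16).float()).to(torch.bfloat16)

# Fused-style: the sigmoid stays in fp32 inside the kernel; only the final
# product is truncated to bf16, skipping one rounding step.
fused = (p.float() * x.sigmoid()).to(torch.bfloat16)

mismatch = (eager != fused).float().mean()
print(f"fraction of elements that round differently: {mismatch.item():.3%}")
```

Skipping the intermediate truncation is a double-rounding difference, so a sizable fraction of elements lands one ULP apart.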

After our fix: aot_eager_decomp_partition perfectly matches eager (0
mismatched elements), confirming the decomposition-level fix is complete. The
remaining divergence is purely from inductor's code generation and is a separate,
pre-existing behavior.

Evidence from the full model:

| Backend | max_abs_diff | Mismatched elements |
|---|---|---|
| aot_eager | 0.0 | 0/262144 (0%) |
| aot_eager_decomp_partition | 0.0 | 0/262144 (0%) |
| inductor | 0.0125 | 260734/262144 (99.5%) |

This inductor fusion precision issue should be tracked separately.

Performance

Removing the decomposition means inductor cannot fuse LayerNorm into surrounding
kernels. For a dedicated lowering (using inductor's Welford reduction + pointwise
affine), a follow-up would be needed. The C++ fallback is already well-optimized
for the standalone operation, so the impact is likely small for most workloads.

Other norms

native_group_norm uses the same _normalize → torch.var_mean decomposition path
and may have the same divergence issue. This should be investigated separately.
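A quick probe for group norm (illustrative only; the group-norm divergence was not measured as part of this PR) can compare the kernel's returned statistics against torch.var_mean, just as for layer norm:

```python
import torch

torch.manual_seed(0)
N, C, H, W, groups = 8, 32, 16, 16, 4
x = torch.randn(N, C, H, W)
w = torch.ones(C)
b = torch.zeros(C)

# native_group_norm also returns (out, mean, rstd) computed by the C++ kernel.
out, mean_k, rstd_k = torch.ops.aten.native_group_norm(
    x, w, b, N, C, H * W, groups, 1e-5
)

# Decomposition-style statistics via torch.var_mean over each group.
var_d, mean_d = torch.var_mean(x.view(N, groups, -1), dim=-1, unbiased=False)
print((mean_k.view(N, groups) - mean_d).abs().max())
```

A nonzero print here would indicate the same accumulation-order mismatch in the group-norm path.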

Fixes #168126

Repro Script
import torch
from torch import Tensor, nn


class TriangleMultiplicationOutgoing(nn.Module):
    def __init__(self, dim: int = 128) -> None:
        super().__init__()

        self.norm_in = nn.LayerNorm(dim, eps=1e-5)
        self.p_in = nn.Linear(dim, 2 * dim, bias=False)
        self.g_in = nn.Linear(dim, 2 * dim, bias=False)

        self.norm_out = nn.LayerNorm(dim)
        self.p_out = nn.Linear(dim, dim, bias=False)
        self.g_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        x = self.norm_in(x)
        x_in = x
        x = self.p_in(x) * self.g_in(x).sigmoid()
        x = x * mask.unsqueeze(-1)
        a, b = torch.chunk(x.float(), 2, dim=-1)
        x = torch.einsum("bikd,bjkd->bijd", a, b)
        x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()
        x = x + x_in
        return x


if __name__ == "__main__":
    for dtype in [torch.float32, torch.bfloat16]:
        print(f"Testing with dtype: {dtype}")
        with torch.autocast(device_type="cuda", dtype=dtype):
            torch.manual_seed(42)
            x = torch.randn(16, 128, 128, 128, device="cuda")
            mask = torch.randint(0, 2, (16, 128, 128), device="cuda")

            eager_layer = TriangleMultiplicationOutgoing().cuda()
            compiled_layer = torch.compile(TriangleMultiplicationOutgoing().cuda(), fullgraph=True)

            # Copy weights from reference to optimized to ensure identical parameters
            with torch.no_grad():
                for param, ref_param in zip(compiled_layer.parameters(), eager_layer.parameters()):
                    param.data.copy_(ref_param.data)

            out_eager = eager_layer(x, mask)
            out_compiled = compiled_layer(x, mask)
            torch.testing.assert_close(out_eager, out_compiled)
        print(f"Passed with dtype: {dtype}")
Agent Worklog

Worklog: Issue #168126 - LayerNorm bf16 divergence

Run 1: Diagnosis and Plan

Build setup

  • Job venv uses Python 3.12; built pytorch CPU-only (USE_CUDA=0 USE_MKL=0 BLAS=Eigen)
  • Required PYTHONPATH override due to stale /home/aorenste/local/pytorch in sys.path

Reproduction

  • Confirmed minimal repro (CPU, no CUDA) fails: 2 mismatched elements out of 524288
  • Max abs diff: 0.000275, max rel diff: 0.123 (exceeds bf16 tolerances)
  • Confirmed: aot_eager PASS, aot_eager_decomp_partition FAIL, inductor FAIL
  • Confirmed: fp32 PASS, bf16 (with or without autocast) FAIL

Key diagnostic results

  1. aot_eager_decomp_partition without layer_norm decomp: PASS — removing the
    decomposition of native_layer_norm fixes the issue entirely

  2. Decomp formula with EAGER stats: diff = 0.0 — the decomposition's
    (x - mean) * rstd * weight + bias formula is BIT-EXACT with the C++ kernel
    when given the same mean/rstd

  3. Decomp formula with DECOMP stats: diff = 7.15e-7 — the ENTIRE error comes
    from the mean/variance computation being different

  4. torch.var_mean vs RowwiseMoments: mean differs by up to 4.66e-08,
    variance by up to 5.5e-7 across 76% of rows. The C++ native_layer_norm kernel
    uses Welford's algorithm with cascade summation (RowwiseMoments), while the
    decomposition uses torch.var_mean which uses a different reduction algorithm.

  5. rsqrt vs 1/sqrt: ZERO difference (not a factor)

  6. bf16 rounding amplification: 7 mismatched bf16 elements after casting norm
    output → after matmul amplification: max diff 0.125, 23 mismatched elements

Root cause

The decomposition of aten.native_layer_norm in torch/_refs/__init__.py::_normalize
computes mean/variance using torch.var_mean, while the C++ kernel uses
RowwiseMoments (Welford + cascade sum). These give different fp32 results
(≤7e-7), which cause bf16 rounding divergences at ~0.001% of elements that are
then amplified by subsequent ops (linears, matmuls).
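For reference, Welford's update rule (the scheme RowwiseMoments vectorizes; shown here in scalar form as a sketch, since the real kernel adds cascade summation and SIMD) looks like:

```python
import torch

def welford_var_mean(row):
    """Single-pass Welford mean/variance in scalar form (illustrative only)."""
    mean = 0.0
    m2 = 0.0
    for i, v in enumerate(row.tolist(), start=1):
        delta = v - mean
        mean += delta / i          # running mean update
        m2 += delta * (v - mean)   # running sum of squared deviations
    return m2 / len(row), mean

torch.manual_seed(0)
row = torch.randn(128)
var_w, mean_w = welford_var_mean(row)
var_t, mean_t = torch.var_mean(row, unbiased=False)

# Mathematically the same statistics; only the accumulation order differs,
# which is exactly where the last-bit fp32 discrepancies come from.
print(abs(mean_w - mean_t.item()), abs(var_w - var_t.item()))
```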

Run 2

User: Verify if the cuda implementation has the same problem and would have the same solution

CUDA kernel analysis

Examined aten/src/ATen/native/cuda/layer_norm_kernel.cu. The CUDA native_layer_norm
kernel has two paths:

  1. Vectorized path (line 1114, for float/half/bf16 with aligned data, N < 2^24,
    N % 4 == 0): Uses compute_stats() → cuWelfordOnlineSum() + cuWelfordCombine()
    — a custom Welford implementation with vectorized loads, warp shuffle reduction,
    and shared-memory inter-warp reduction.

  2. Non-vectorized path (line 1117): Uses RowwiseMomentsCUDAKernel →
    WelfordOps with cuda_utils::BlockReduce — Welford's algorithm through the
    standard CUDA block reduction utility.

Both paths use Welford's online algorithm, but with CUDA-specific parallelism
(warp shuffles, shared memory).

CUDA torch.var_mean analysis

Examined aten/src/ATen/native/cuda/ReduceMomentKernel.cu. The CUDA var_mean
kernel (std_var_kernel_impl) also uses WelfordOps, but goes through
gpu_reduce_kernel<scalar_t, out_t, 2> — the generic GPU reduction framework
with unrolling factor 2.

Same problem, different Welford implementations

Both the native_layer_norm CUDA kernel and torch.var_mean on CUDA use
Welford's algorithm, but through different code paths:

  • native_layer_norm: Per-row Welford with vectorized loads (4 elements at a time),
    intra-warp shuffle reduction, inter-warp shared-memory tree reduction. Thread
    layout: blockDim.x * blockDim.y threads per row, one row per block.

  • torch.var_mean (via gpu_reduce_kernel): Generic multi-dimensional reduction
    framework with unrolling factor 2, different thread/block decomposition and
    accumulation order.

The different accumulation orders produce different fp32 rounding, just like the
CPU case. The errors may be slightly different in magnitude due to different
parallelism granularity, but the fundamental issue is identical.

Decomposition guard analysis

In torch/_inductor/decomposition.py (line 206-217):

@register_decomposition(aten.native_layer_norm)
def _native_layer_norm(input, normalized_shape, weight, bias, eps):
    if input.is_mtia:
        return NotImplemented
    return decomp_native_layer_norm(input, normalized_shape, weight, bias, eps)

Only MTIA is excluded from decomposition. CUDA tensors ARE decomposed, so CUDA
hits the same _normalize() → torch.var_mean() path that causes the divergence.

Conclusion

Yes, the CUDA path has the same problem and the same solution would fix it.
The fix (making _native_layer_norm return NotImplemented unconditionally, or
at least for CUDA too) prevents the decomposition and keeps the native CUDA kernel,
which eliminates the statistics mismatch. This is device-agnostic — the root cause
is the decomposition replacing a specialized per-row Welford with a generic
reduction Welford, regardless of whether the device is CPU or CUDA.

Run 3

User: I agree with the reproduction and the proposed solution. Implement and verify it.

Your job:

  • write the smallest fix that explains the reproduced failure
  • add or update a focused regression test
  • demonstrate that the test would fail before the fix
  • demonstrate that the same test passes with the fix
  • keep scope tight to the approved diagnosis and plan

Rules:

  • no speculative cleanup or unrelated refactors
  • if the test does not fail before the fix, stop and report that the evidence is insufficient
  • if the fix requires a broader change than planned, stop and ask before expanding scope
  • prefer the narrowest test that captures the reported failure mode

Output:

  • patch summary
  • regression test added or updated
  • evidence that it fails before the fix
  • evidence that it passes after the fix
  • any remaining risks or gaps

Run 4

User: Continue with the build.

Run 5

User: Rebuild and then continue.

Build

  • Rebuilt PyTorch successfully (CPU-only, incremental build)

Verification: test fails before fix, passes after

  • Monkey-patched old decomposition back: test correctly fails (1/524288 mismatched,
    max_abs_diff=0.000305, max_rel_diff=0.033)
  • With fix applied: test passes (0 mismatched)

In-tree regression test

  • test_layer_norm_bf16_numerics_cpu passes (3.9s)

Full model repro: discovered additional inductor fusion issue

  • The full TriangleMultiplicationOutgoing repro still fails with inductor
  • Investigated: aot_eager_decomp_partition now matches eager perfectly (0/262144)
  • The remaining divergence is purely from inductor code generation
  • Root cause: inductor fuses sigmoid + multiply into one kernel, computing
    intermediates in fp32 without bf16 truncation between ops
  • p_in(x) * g_in(x).sigmoid() with bf16: 138080/524288 (26%) mismatched
  • This is a separate, pre-existing inductor behavior (not LayerNorm-related)
  • TORCHINDUCTOR_EMULATE_PRECISION_CASTS=1 does not fix it

Linting

  • All changed files pass linting (RUFF, PYFMT)

Artifacts generated

  • fix.diff — 3 files changed: decomposition.py, lowering.py, test_torchinductor.py
  • report.md — updated with full analysis including the inductor fusion gap

This PR was generated by ptq with human review.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

@pytorch-bot

pytorch-bot bot commented Mar 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177163

Note: Links to docs will display an error until the docs builds have been completed.

❌ 14 New Failures, 4 Unrelated Failures

As of commit c6e26ca with merge base 4bc9d7f:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot

pytorch-bot bot commented Mar 11, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@vadimkantorov
Contributor

Btw, is torch.var_mean's reduction less performant/accurate? If so - maybe it also should be updated to use the same reduction algo as in native_layer_norm?
