Stop decomposing native_layer_norm to fix bf16 precision divergence#177163
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177163
Note: Links to docs will display an error until the docs builds have been completed. ❌ 14 New Failures, 4 Unrelated Failures as of commit c6e26ca with merge base 4bc9d7f.
NEW FAILURES — the following jobs have failed:
FLAKY — the following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK — the following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE — the following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Human Note
The decomposition of aten.native_layer_norm computes mean/variance using torch.var_mean, which uses a different reduction algorithm than the C++ kernel's RowwiseMoments (Welford + cascade summation). This creates fp32 differences of ~7e-7 that cause bf16 rounding divergences at ~0.001% of elements, amplified by subsequent matmuls to errors of 0.018+ in realistic models.
The fix removes the native_layer_norm decomposition from the inductor decomposition table and makes the make_fallback unconditional (not MTIA-only). This ensures the C++ kernel is used directly, matching eager behavior exactly.
The decomposition formula `(x - mean) * rstd * weight + bias` is bit-exact with the C++ kernel when given the same statistics; the entire error comes from the statistics computation. A follow-up PR could add a direct inductor lowering using var_mean_welford_ to re-enable Triton kernel generation while maintaining numerical accuracy.
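To make the mechanism concrete, here is a toy, pure-Python sketch (not PyTorch's actual kernels): a naive running-sum reduction (standing in for a generic `var_mean`-style reduction) and Welford's online update (the accumulation style of `RowwiseMoments`, minus cascade summation) are each rounded to fp32 after every operation, so their different accumulation orders leave different last-bit residues.

```python
import random
import struct

def f32(x: float) -> float:
    """Round a Python double to the nearest float32, simulating fp32 arithmetic."""
    return struct.unpack("f", struct.pack("f", x))[0]

def var_mean_naive(xs):
    # Running sums of x and x^2, rounded to fp32 after each step
    # (a stand-in for a generic reduction; not torch.var_mean's real code).
    s = sq = 0.0
    for x in xs:
        s = f32(s + x)
        sq = f32(sq + f32(x * x))
    n = len(xs)
    mean = f32(s / n)
    return f32(f32(sq / n) - f32(mean * mean)), mean

def var_mean_welford(xs):
    # Welford's online update, rounded to fp32 after each step.
    mean = m2 = 0.0
    for i, x in enumerate(xs, 1):
        delta = f32(x - mean)
        mean = f32(mean + f32(delta / i))
        m2 = f32(m2 + f32(delta * f32(x - mean)))
    return f32(m2 / len(xs)), mean

rng = random.Random(0)
xs = [f32(rng.gauss(0.0, 1.0)) for _ in range(4096)]
v_naive, m_naive = var_mean_naive(xs)
v_welf, m_welf = var_mean_welford(xs)
print(f"mean diff = {abs(m_naive - m_welf):.2e}, var diff = {abs(v_naive - v_welf):.2e}")
```

In exact arithmetic both functions return identical values; under fp32 rounding they generally disagree in the last few bits, the same order of magnitude as the ≤7e-7 statistics divergence measured in this issue.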
Agent Report
Issue #168126: Large numeric divergence for torch compile vs eager in bf16
Summary
The decomposition of `aten.native_layer_norm` used by `torch.compile` computes mean/variance statistics using a different algorithm (`torch.var_mean`) than the C++ eager kernel (`RowwiseMoments` — Welford's algorithm with cascade summation). This creates small fp32 differences (up to 7e-7) that cause bf16 rounding divergences, which are then amplified by subsequent operations (linear layers, matmuls).
Root Cause Analysis
The error chain
1. Statistics algorithm mismatch — The C++ `native_layer_norm` kernel computes mean/variance using `RowwiseMoments` (Welford + cascade summation, in `aten/src/ATen/native/cpu/moments_utils.h`). The decomposition in `torch/_refs/__init__.py::_normalize` uses `torch.var_mean`, which uses a different reduction algorithm. Mean differs by up to 4.66e-08 across 76% of rows; variance by up to 5.5e-07.
2. Exact formula, wrong stats — When given the same mean/rstd, the decomposition's formula `(x - mean) * rstd * weight + bias` is bit-exact with the C++ kernel (diff = 0.0). The entire error comes from the statistics computation.
3. bf16 rounding divergence — The ~7e-7 fp32 error in the norm output causes ~0.001% of elements to round differently when cast to bf16. Each mismatched element differs by ~1 bf16 ULP (~0.004 near unit-scale values).
4. Matmul amplification — A bf16 linear layer sums 128 products. A 1-ULP input difference in ~7 elements yields a max output error of ~0.125. Through multiple layers, the error grows further.
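Steps 3–4 can be checked with a few lines of arithmetic. The sketch below is a toy model (assuming round-to-nearest-even bfloat16, i.e. float32 with the low 16 mantissa bits rounded away): it computes the bf16 ULP near unit scale and shows how a handful of 1-ULP input flips move a 128-element dot product.

```python
import struct

def bf16(x: float) -> float:
    """Round a double to bfloat16 precision: float32 with the low 16
    mantissa bits rounded away (round-to-nearest-even). Fine for the
    small positive values used here."""
    (bits,) = struct.unpack("I", struct.pack("f", x))
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("f", struct.pack("I", bits))[0]

# bf16 ULP near unit scale: values in [0.5, 1) are spaced 2**-8
ulp = bf16(0.5 + 2**-8) - 0.5
print(ulp)  # 0.00390625, i.e. the "~0.004" quoted above

# Amplification through a 128-wide dot product (toy linear layer):
n, nflip = 128, 7
x  = [bf16(0.8)] * n                 # activations near unit scale
w  = [1.0] * n                       # unit-magnitude weights (worst case: same sign)
x2 = list(x)
for i in range(nflip):               # perturb 7 inputs by exactly 1 bf16 ULP
    x2[i] = bf16(x2[i] + ulp)
diff = sum(a * b for a, b in zip(x2, w)) - sum(a * b for a, b in zip(x, w))
print(diff)  # 7 * 0.00390625 = 0.02734375 with these unit weights
```

The 0.027 here assumes unit weights; with realistic bf16 weights and larger-magnitude activations (and errors compounding across layers), the ~0.125 maximum reported above is plausible.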
Key evidence (all from runtime, not speculative)
- `aot_eager` vs eager
- `aot_eager_decomp_partition` vs eager
- `aot_eager_decomp_partition` WITHOUT the layer_norm decomp vs eager
- `var_mean` stats vs the C++ kernel
CUDA has the same root cause
The CUDA `native_layer_norm` kernel uses a custom Welford implementation with vectorized loads and warp shuffle reduction, while `torch.var_mean` on CUDA uses a different Welford through the generic `gpu_reduce_kernel`. Same algorithm mismatch, same fix applies.
What the Fix Does
Stop decomposing `native_layer_norm` in inductor. Two changes:
1. `torch/_inductor/decomposition.py` — `_native_layer_norm` now always returns `NotImplemented` instead of delegating to the decomposition. Previously only MTIA was excluded.
2. `torch/_inductor/lowering.py` — `make_fallback(aten.native_layer_norm)` is now unconditional (previously gated behind `torch.mtia._is_compiled()`).
The C++ kernel is already well-optimized (vectorized Welford, SIMD on CPU, warp shuffle reduction on CUDA) and numerically stable. Removing the decomposition preserves exact eager-compile agreement for LayerNorm.
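The control flow of the two changes can be illustrated with a standalone toy registry (all names here are illustrative, not inductor's real API): a decomposition that returns `NotImplemented` makes the lowering take the registered fallback, i.e. the native kernel.

```python
# Toy model of "decomposition returns NotImplemented => use the fallback kernel".
decomp_table = {}
fallbacks = {}

def register_decomposition(op):
    def wrap(fn):
        decomp_table[op] = fn
        return fn
    return wrap

def make_fallback(op, kernel):
    # Unconditional registration, mirroring the now-unconditional
    # make_fallback(aten.native_layer_norm) in this PR.
    fallbacks[op] = kernel

@register_decomposition("native_layer_norm")
def _native_layer_norm(*args):
    # Always decline to decompose, so the native (eager) kernel runs.
    return NotImplemented

def lower(op, *args):
    decomp = decomp_table.get(op)
    if decomp is not None:
        result = decomp(*args)
        if result is not NotImplemented:
            return result
    return fallbacks[op](*args)

make_fallback("native_layer_norm", lambda x: ("native_kernel", x))
print(lower("native_layer_norm", 42))  # ('native_kernel', 42)
```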
Regression test added
`test_layer_norm_bf16_numerics` in `test/inductor/test_torchinductor.py` — verifies that LayerNorm + Linear under bf16 autocast produces identical results between eager and compiled modes.
Test updated
The existing GPU test for buffer reuse with LayerNorm was updated. Since `native_layer_norm` is no longer decomposed, there is no fused Triton kernel to check `in_out_ptr` against. The test now verifies correctness only.
Test Results
Before fix: test fails
After fix: test passes
In-tree test passes
Repro Script (minimal, CPU)
Before fix output:
After fix output: (no error)
Remaining Gaps
Inductor fusion precision (separate issue)
The original full repro (TriangleMultiplicationOutgoing from the Boltz model) still
fails with the inductor backend even after this fix. Investigation revealed an
additional, separate source of divergence: inductor's kernel fusion computes
bf16 pointwise operations (sigmoid, multiply) in fp32 without intermediate bf16
truncation.
For example, `p_in(x) * g_in(x).sigmoid()` where both linears return bf16: the fused kernel skips one bf16 truncation, causing ~26% of elements to differ by 1 ULP.
The `TORCHINDUCTOR_EMULATE_PRECISION_CASTS=1` flag does not fix this case.
After our fix, `aot_eager_decomp_partition` perfectly matches eager (0 mismatched elements), confirming the decomposition-level fix is complete. The remaining divergence is purely from inductor's code generation and is a separate, pre-existing behavior.
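The missing-truncation effect is easy to reproduce in isolation. In this toy sketch (again simulating bfloat16 in pure Python, not inductor's actual codegen), an fp32 value `v` stands in for the sigmoid intermediate; the "eager" path truncates it to bf16 before the multiply, while the "fused" path keeps it in fp32:

```python
import struct

def bf16(x: float) -> float:
    """Round a double to bfloat16 (round-to-nearest-even on the low 16 bits)."""
    (bits,) = struct.unpack("I", struct.pack("f", x))
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("f", struct.pack("I", bits))[0]

p = 3.0  # exactly representable in bf16; stands in for a p_in(x) element
mismatches = 0
for k in range(256):
    v = 0.5 + k * 2**-16          # fp32-precision intermediate (e.g. sigmoid output)
    eager = bf16(p * bf16(v))     # eager: truncate the intermediate to bf16 first
    fused = bf16(p * v)           # fused: intermediate kept in fp32, one cast skipped
    mismatches += eager != fused
print(f"{mismatches}/256 scalar products differ")
```

The ~26% mismatch rate in the report is the tensor-level analogue of this scalar-level effect.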
Evidence from the full model:
- `aot_eager`
- `aot_eager_decomp_partition`
- `inductor`
This inductor fusion precision issue should be tracked separately.
Performance
Removing the decomposition means inductor cannot fuse LayerNorm into surrounding kernels. A dedicated lowering (using inductor's Welford reduction plus a pointwise affine) would require a follow-up PR. The C++ fallback is already well-optimized for the standalone operation, so the impact is likely small for most workloads.
Other norms
`native_group_norm` uses the same `_normalize` → `var_mean` decomposition path and may have the same divergence issue. This should be investigated separately.
Fixes #168126
Repro Script
Agent Worklog
Worklog: Issue #168126 - LayerNorm bf16 divergence
Run 1: Diagnosis and Plan
Build setup
Reproduction
Key diagnostic results
aot_eager_decomp_partition without layer_norm decomp: PASS — removing the
decomposition of native_layer_norm fixes the issue entirely
Decomp formula with EAGER stats: diff = 0.0 — the decomposition's `(x - mean) * rstd * weight + bias` formula is BIT-EXACT with the C++ kernel when given the same mean/rstd
Decomp formula with DECOMP stats: diff = 7.15e-7 — the ENTIRE error comes
from the mean/variance computation being different
`torch.var_mean` vs `RowwiseMoments`: mean differs by up to 4.66e-08, variance by up to 5.5e-7 across 76% of rows. The C++ native_layer_norm kernel uses Welford's algorithm with cascade summation (RowwiseMoments), while the decomposition uses torch.var_mean, which uses a different reduction algorithm.
rsqrt vs 1/sqrt: ZERO difference (not a factor)
bf16 rounding amplification: 7 mismatched bf16 elements after casting norm
output → after matmul amplification: max diff 0.125, 23 mismatched elements
Root cause
The decomposition of `aten.native_layer_norm` in `torch/_refs/__init__.py::_normalize` computes mean/variance using `torch.var_mean`, while the C++ kernel uses `RowwiseMoments` (Welford + cascade sum). These give different fp32 results (≤7e-7), which cause bf16 rounding divergences at ~0.001% of elements, which are then amplified by subsequent ops (linears, matmuls).
Run 2
CUDA kernel analysis
Examined `aten/src/ATen/native/cuda/layer_norm_kernel.cu`. The CUDA `native_layer_norm` kernel has two paths:
1. Vectorized path (line 1114, for float/half/bf16 with aligned data, N < 2^24, N % 4 == 0): uses `compute_stats()` → `cuWelfordOnlineSum()` + `cuWelfordCombine()` — a custom Welford implementation with vectorized loads, warp shuffle reduction, and shared-memory inter-warp reduction.
2. Non-vectorized path (line 1117): uses `RowwiseMomentsCUDAKernel` → `WelfordOps` with `cuda_utils::BlockReduce` — Welford's algorithm through the standard CUDA block reduction utility.
Both paths use Welford's online algorithm, but with CUDA-specific parallelism
(warp shuffles, shared memory).
CUDA `torch.var_mean` analysis
Examined `aten/src/ATen/native/cuda/ReduceMomentKernel.cu`. The CUDA `var_mean` kernel (`std_var_kernel_impl`) also uses `WelfordOps`, but goes through `gpu_reduce_kernel<scalar_t, out_t, 2>` — the generic GPU reduction framework with unrolling factor 2.
Same problem, different Welford implementations
Both the `native_layer_norm` CUDA kernel and `torch.var_mean` on CUDA use Welford's algorithm, but through different code paths:
- `native_layer_norm`: per-row Welford with vectorized loads (4 elements at a time), intra-warp shuffle reduction, inter-warp shared-memory tree reduction. Thread layout: `blockDim.x * blockDim.y` threads per row, one row per block.
- `torch.var_mean` (via `gpu_reduce_kernel`): generic multi-dimensional reduction framework with unrolling factor 2, different thread/block decomposition and accumulation order.
The different accumulation orders produce different fp32 rounding, just like the
CPU case. The errors may be slightly different in magnitude due to different
parallelism granularity, but the fundamental issue is identical.
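The order dependence described here (a warp-shuffle tree reduction versus a generic serial accumulation) can be mimicked in pure Python with simulated fp32. This is a toy model only; the real kernels are CUDA.

```python
import math
import struct

def f32(x: float) -> float:
    """Round a double to the nearest float32, simulating fp32 arithmetic."""
    return struct.unpack("f", struct.pack("f", x))[0]

def running_sum(xs):
    # One long serial accumulation, like a generic unrolled reduction loop.
    s = 0.0
    for x in xs:
        s = f32(s + x)
    return s

def tree_sum(xs):
    # Pairwise/tree accumulation, like a warp-shuffle butterfly reduction.
    if len(xs) == 1:
        return xs[0]
    mid = len(xs) // 2
    return f32(tree_sum(xs[:mid]) + tree_sum(xs[mid:]))

xs = [f32(0.1 * (i % 7 + 1)) for i in range(1024)]
ref = math.fsum(xs)  # correctly rounded double-precision reference
print(running_sum(xs) - ref, tree_sum(xs) - ref)
```

Both results are valid fp32 sums of the same data, but they typically differ in the last few bits because the rounding errors accumulate in a different order.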
Decomposition guard analysis
In `torch/_inductor/decomposition.py` (lines 206-217): only MTIA is excluded from decomposition. CUDA tensors ARE decomposed, so CUDA hits the same `_normalize()` → `torch.var_mean()` path that causes the divergence.
Conclusion
Yes, the CUDA path has the same problem and the same solution would fix it.
The fix (making `_native_layer_norm` return `NotImplemented` unconditionally, or at least for CUDA too) prevents the decomposition and keeps the native CUDA kernel,
which eliminates the statistics mismatch. This is device-agnostic — the root cause
is the decomposition replacing a specialized per-row Welford with a generic
reduction Welford, regardless of whether the device is CPU or CUDA.
Run 3
Run 4
Run 5
Build
Verification: test fails before fix (max_abs_diff=0.000305, max_rel_diff=0.033), passes after
In-tree regression test
`test_layer_norm_bf16_numerics_cpu` passes (3.9s)
Full model repro: discovered additional inductor fusion issue
- `aot_eager_decomp_partition` now matches eager perfectly (0/262144)
- inductor fuses `sigmoid + multiply` into one kernel, computing intermediates in fp32 without bf16 truncation between ops
- `p_in(x) * g_in(x).sigmoid()` with bf16: 138080/524288 (26%) mismatched
- `TORCHINDUCTOR_EMULATE_PRECISION_CASTS=1` does not fix it
Linting
Artifacts generated
`fix.diff` — 3 files changed: decomposition.py, lowering.py, test_torchinductor.py
`report.md` — updated with full analysis including the inductor fusion gap
This PR was generated by ptq with human review.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo