
# [inductor] Backward pass 9% slower in 2.11 vs 2.9.1 due to over-fusion of rms_norm_backward #179423

## Summary

The same transformer model compiled with `torch.compile(model, dynamic=False, fullgraph=True)` has a backward pass that is ~9% slower on PyTorch 2.11.0 than on 2.9.1, while forward-pass performance is identical.
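
A minimal sketch of the measurement, assuming a stand-in model (this is NOT the author's 34.4M-param transformer; a stack of RMSNorm + MLP blocks with the reported d=512 and 11 layers is used so RMSNorm's backward appears in the compiled graph; batch/sequence shapes are assumptions):

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Stand-in transformer block: RMSNorm + MLP with a residual."""
    def __init__(self, d=512):
        super().__init__()
        self.norm = nn.RMSNorm(d)  # the op whose backward over-fuses per this report
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x):
        return x + self.ff(self.norm(x))

model = nn.Sequential(*[Block() for _ in range(11)]).cuda()
compiled = torch.compile(model, dynamic=False, fullgraph=True)
x = torch.randn(8, 1024, 512, device="cuda", requires_grad=True)

# Warm up so compilation cost is excluded from the timing.
for _ in range(3):
    compiled(x).sum().backward()
torch.cuda.synchronize()

# Time forward and backward separately with CUDA events.
start, mid, end = (torch.cuda.Event(enable_timing=True) for _ in range(3))
start.record()
loss = compiled(x).sum()
mid.record()
loss.backward()
end.record()
torch.cuda.synchronize()
print(f"forward:  {start.elapsed_time(mid):.2f} ms")
print(f"backward: {mid.elapsed_time(end):.2f} ms")
```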

## Profiler Data

| Component | PyTorch 2.9.1 | PyTorch 2.11.0 | Delta |
|---|---|---|---|
| Backward compiled graph | 67.28 ms | 73.21 ms | +5.93 ms (+8.8%) |
| Forward compiled graph | 34.40 ms | 34.47 ms | +0.07 ms |
| `aten::mm` | 33.17 ms | 33.13 ms | identical |
| FA3 backward | 20.13 ms | 20.11 ms | identical |
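
A sketch of collecting equivalent numbers with the profiler, reusing `compiled` and `x` from the sketch in the Summary (the compiled-graph row names, e.g. `CompiledFunction`, vary across versions):

```python
from torch.profiler import profile, ProfilerActivity

# Aggregate CPU/CUDA time over several steps; the compiled forward and
# backward graphs and aten::mm show up as separate rows in the table.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        compiled(x).sum().backward()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```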

## Root Cause: Inductor Over-Fusion

Inductor generates fewer but larger fused Triton kernels in 2.11:

| Metric | PyTorch 2.9.1 | PyTorch 2.11.0 |
|---|---|---|
| Triton kernel functions | 71 | 65 |
| Largest backward kernel | 11,292 lines | 11,855 lines |

**Key difference:** 2.11 fuses `_fused_rms_norm_backward` into adjacent kernels, while 2.9.1 keeps it in its own kernel. The larger fused kernels run slower.
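
The kernel counts above can be reproduced by dumping Inductor's generated code (a sketch; the debug-directory layout is version-dependent):

```python
import torch

# Print every generated Triton kernel to the log as each graph compiles.
torch._logging.set_logs(output_code=True)

# Alternatively, running with the env var TORCH_COMPILE_DEBUG=1 writes each
# graph's output_code.py under ./torch_compile_debug/; counting @triton.jit
# definitions there gives the 71-vs-65 comparison, e.g.:
#   grep -rc "@triton.jit" torch_compile_debug/
```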

## Isolation

- **Not Triton:** swapping Triton 3.5.1 into the PyTorch 2.11 environment has no effect
- **Not autocast:** the gap persists with autocast disabled (see the sketch after this list)
- **Not cuDNN/cuBLAS:** forcing the backends has no effect
- **Forward is identical:** only the backward compiled graph is slower
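
A sketch of the autocast and backend checks above, reusing `compiled` and `x` from the Summary sketch (the report does not say exactly which backend flags were forced; these are plausible stand-ins):

```python
import torch

# Autocast check: run with autocast explicitly disabled; the backward
# gap persists either way.
with torch.autocast(device_type="cuda", enabled=False):
    compiled(x).sum().backward()

# cuDNN/cuBLAS check (assumed flags): pinning these had no effect either.
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
```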

## Environment

- GPU: NVIDIA H100 80GB HBM3 SXM, driver 570.148.08, CUDA 12.8
- Model: 34.4M-param transformer, 11 layers, d=512, RMSNorm, depth recurrence, parallel residuals
- `max_fusion_size=64` and `aggressive_fusion=False` (the defaults); the over-fusion happens within these limits (see the sketch after this list)
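
The fusion settings can be confirmed, and fusion capped as a bisection aid, through `torch._inductor.config` (a diagnostic sketch, not a recommended fix; the value 16 is an arbitrary example):

```python
import torch._inductor.config as inductor_config

# Confirm the defaults quoted above.
assert inductor_config.max_fusion_size == 64
assert inductor_config.aggressive_fusion is False

# Diagnostic: cap fusion size before compiling and re-measure the backward
# pass to see whether the over-fused rms_norm_backward kernel splits apart.
inductor_config.max_fusion_size = 16
```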

## Impact

For a time-budgeted training run (600 s), the regression costs ~57 training steps (~1% of the total).

cc @ezyang @gchanan @kadeng @msaroufim @jerryzh168 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @muchulee8 @amjames @aakhundov @coconutruben @jataylo

## Labels

`bot-triaged`, `high priority`, `module: inductor`, `module: performance`, `module: regression`, `needs reproduction`, `oncall: pt2`, `triaged`
