[inductor] Backward pass 9% slower in 2.11 vs 2.9.1 due to over-fusion of rms_norm_backward #179423
Closed
Labels
bot-triaged, high priority, module: inductor, module: performance, module: regression, needs reproduction, oncall: pt2, triaged
Summary
The same transformer model, compiled with torch.compile(model, dynamic=False, fullgraph=True), has a backward pass that is ~9% slower on PyTorch 2.11 than on 2.9.1, while forward-pass performance is identical.
Profiler Data
Root Cause: Inductor Over-Fusion
Inductor generates fewer but larger fused Triton kernels in 2.11 than in 2.9.1.
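As a rough illustration of how fusion decisions turn the same graph into fewer, larger kernels, here is a toy greedy fuser. It is a sketch under simplifying assumptions, not Inductor's scheduler: the real scheduler applies its own legality checks and scoring, which this does not model. The node names (mul, rms_norm_bwd, add, matmul) are hypothetical, and the max_fusion_size parameter only loosely mirrors the Inductor config option of the same name by capping the number of nodes per group.

```python
from typing import List, Set, Tuple


def greedy_fuse(
    nodes: List[str],
    fusible_pairs: Set[Tuple[str, str]],
    max_fusion_size: int = 64,
) -> List[List[str]]:
    """Toy fuser (not Inductor's algorithm): walk nodes in program order and
    append each node to the current group when (previous node, node) is marked
    fusible and the group holds fewer than max_fusion_size nodes; otherwise
    start a new group. Each returned group stands in for one generated kernel."""
    groups: List[List[str]] = []
    for node in nodes:
        can_extend = (
            bool(groups)
            and len(groups[-1]) < max_fusion_size
            and (groups[-1][-1], node) in fusible_pairs
        )
        if can_extend:
            groups[-1].append(node)
        else:
            groups.append([node])
    return groups


# Hypothetical backward-graph fragment; the names are illustrative only.
graph = ["mul", "rms_norm_bwd", "add", "matmul"]

# "Separate" scenario: rms_norm_bwd is not allowed to fuse with its neighbors.
separate = greedy_fuse(graph, set())
# "Fused" scenario: rms_norm_bwd may fuse with both neighbors.
fused = greedy_fuse(graph, {("mul", "rms_norm_bwd"), ("rms_norm_bwd", "add")})

print(len(separate), separate)  # 4 kernels
print(len(fused), fused)        # 2 kernels, one of them larger
```

Both scenarios stay far below a cap of 64 nodes, so in this toy model the size cap alone does not decide whether rms_norm_bwd is merged; which pairs are considered fusible does.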
Key difference: 2.11 fuses _fused_rms_norm_backward into adjacent kernels, whereas 2.9.1 keeps them as separate kernels. The larger fused kernels run slower.
Isolation
Environment
Inductor config is at its defaults (max_fusion_size=64 and aggressive_fusion=False); the over-fusion happens within these limits.
Impact
For time-budgeted training (600 s), the regression costs ~57 training steps (~1% of the total).
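The link between a per-step slowdown and steps lost under a fixed wall-clock budget is simple arithmetic, sketched below. The helper names are hypothetical, and the 0.1 s step time and 1% slowdown are illustrative assumptions, not measurements from this issue; the implied-slowdown helper ignores the rounding down to whole steps.

```python
import math


def steps_in_budget(budget_s: float, step_time_s: float) -> int:
    """Whole training steps that fit into a fixed wall-clock budget."""
    return math.floor(budget_s / step_time_s)


def steps_lost(budget_s: float, step_time_s: float, step_slowdown: float) -> int:
    """Steps lost when every step takes (1 + step_slowdown) times as long."""
    baseline = steps_in_budget(budget_s, step_time_s)
    regressed = steps_in_budget(budget_s, step_time_s * (1.0 + step_slowdown))
    return baseline - regressed


def implied_step_slowdown(total_steps: int, lost_steps: int) -> float:
    """Average per-step slowdown implied by losing lost_steps of total_steps.

    From N_new = N_old / (1 + f): lost = N_old * f / (1 + f), so
    f = lost / (N_old - lost)."""
    return lost_steps / (total_steps - lost_steps)


# Illustrative inputs (assumed, not measured): 0.1 s per step, 1% slower steps.
print(steps_lost(600.0, 0.1, 0.01))
# ~57 steps lost at ~1% of the total implies roughly 5,700 steps in the budget.
print(round(implied_step_slowdown(5700, 57), 4))
```

With the figures reported above, losing ~57 of roughly 5,700 steps corresponds to an average step-time increase of about 1%.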
cc @ezyang @gchanan @kadeng @msaroufim @jerryzh168 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @muchulee8 @amjames @aakhundov @coconutruben @jataylo