
[inductor] Fix mix_order_reduction over-fusion via load count check #179494

Closed
abaybektursun wants to merge 1 commit into pytorch:main from abaybektursun:fix/disable-mix-order-reduction-default

Conversation

@abaybektursun (Contributor) commented Apr 6, 2026

[inductor] Fix mix_order_reduction over-fusion via load count check

Fixes #179423

Problem

FusedMixOrderReductions.sub_node_can_fuse() absorbs additional nodes into mixed-order reduction kernels without checking the resulting load count. This creates Triton kernels with excessive tl.load() calls in the RSPLIT loop, causing register spills and a +6.3ms/step regression on H100.

Model

The regression was found training a small transformer for the Parameter Golf competition. The exact model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def forward(self, x):
        return F.rms_norm(x, (x.size(-1),))

class MLP(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        hidden = int(dim * mult)
        self.fc = nn.Linear(dim, hidden, bias=False)
        self.proj = nn.Linear(hidden, dim, bias=False)
    def forward(self, x):
        return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square())

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, num_kv_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        self.c_q = nn.Linear(dim, dim, bias=False)
        self.c_k = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False)
        self.c_v = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
    def forward(self, x):
        B, T, D = x.shape
        q = self.c_q(x).reshape(B, T, self.num_heads, self.head_dim)
        k = self.c_k(x).reshape(B, T, self.num_kv_heads, self.head_dim)
        v = self.c_v(x).reshape(B, T, self.num_kv_heads, self.head_dim)
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        q = q.transpose(1, 2)
        k = k.transpose(1, 2).repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
        v = v.transpose(1, 2).repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(y.transpose(1, 2).reshape(B, T, D))

class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = Attention(dim)
        self.mlp = MLP(dim)
    def forward(self, x):
        x = x + self.attn(self.attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x

class Model(nn.Module):
    def __init__(self, vocab_size=4096, dim=512, num_layers=11):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList([Block(dim) for _ in range(num_layers)])
        self.norm = RMSNorm()
        self.head = nn.Linear(dim, vocab_size, bias=False)
    def forward(self, x, y):
        h = self.tok_emb(x)
        h = F.rms_norm(h, (h.size(-1),))
        for block in self.blocks:
            h = block(h)
        h = self.norm(h)
        logits = self.head(h)
        return F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

model = Model().cuda().bfloat16()
compiled = torch.compile(model, dynamic=False, fullgraph=True)
x = torch.randint(0, 4096, (32, 2048), device='cuda')
y = torch.randint(0, 4096, (32, 2048), device='cuda')
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    compiled(x, y).backward()

Key properties: dim=512, 11 transformer blocks, GQA attention with QK-norm, squared leaky-relu MLP, bf16 autocast, fullgraph=True.

Root Cause

During the backward pass, each block produces:

  • rms_norm backward: inner reduction over ncol=512, xnumel=98304 (batch×seq = 48×2048)
  • weight gradient sums: outer reduction over xnumel=98304, keeping ncol=512

mix_order_reduction fuses these two reductions (different iteration orders, same data) into a single kernel. Then sub_node_can_fuse absorbs surrounding pointwise ops (residual connections, dtype casts, scaling) without any check on the resulting read count.
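
For intuition, a minimal standalone sketch of the two reduction orders involved (plain PyTorch over the same [98304, 512] data; the actual rms_norm backward math is more involved, this only illustrates the iteration-order difference):

import torch

x = torch.randn(98304, 512, device="cuda")    # activations feeding the norm
g = torch.randn_like(x)                        # upstream gradient

inner = (x * g).sum(dim=-1, keepdim=True)      # per-row reduction over ncol=512 (rms_norm backward style)
outer = (x * g).sum(dim=0)                     # per-column reduction over xnumel=98304 (weight-grad style)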

The fused kernel uses persistent reduction with R0_BLOCK = ncol = 512 threads per block and an RSPLIT loop that iterates over chunks of the x-dimension. On H100 (65536 regs/SM), 512 threads/block gives 128 regs/thread. That is the register budget.

Each external read buffer becomes a tl.load() inside the RSPLIT loop. Every additional load adds register pressure. The unfused kernel (7 reads) barely fits in 128 regs. The over-fused kernel (11+ reads, plus persistent accumulator arrays) overflows and spills to local memory.

The spill penalty (~100 cycles per access vs 0 for register) is paid every RSPLIT loop iteration (64 iterations per block, 1536 blocks total), producing the 6.3ms regression.
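
As a quick sanity check on those figures (plain arithmetic, nothing inductor-specific):

regs_per_sm = 65536            # H100
threads_per_block = 512        # persistent R0_BLOCK = ncol
print(regs_per_sm // threads_per_block)      # 128 registers/thread budget

rsplit_iters_per_block = 64
num_blocks = 1536
print(rsplit_iters_per_block * num_blocks)   # 98304 loop iterations, each paying the spill penalty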

Kernel comparison

2.9.1 — unfused rms_norm backward (kernel_8, Grid1D, 7 loads, 1 reduction):

# No loop, no accumulators, no workspace — one thread block per row
def triton_per_fused__fused_rms_norm_backward_8(
    in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5,
    out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK):
    # xnumel=98304, r0_numel=512, persistent R0_BLOCK=512
    # 7 loads → ~120 regs/thread, fits in 128 budget
    tmp0 = tl.load(in_ptr0 + ...)   # [98304, 512] rsqrt * Hessian
    tmp1 = tl.load(in_out_ptr0 + ...)  # [98304, 512] upstream grad
    tmp9 = tl.load(in_ptr1 + ...)   # [98304, 512] residual
    tmp10 = tl.load(in_ptr2 + ...)  # scalar: mix coefficient
    tmp20 = tl.load(in_ptr3 + ...)  # [98304, 1] rsqrt
    tmp25 = tl.load(in_ptr4 + ...)  # [512] norm weight 1
    tmp28 = tl.load(in_ptr5 + ...)  # [512] norm weight 2
    # ... 1 inner reduction (sum over 512), 3 stores

2.11 — over-fused rms_norm backward + weight grad sums (kernel_3, MixOrderReductionGrid, 11 loads, 3 reductions):

# RSPLIT loop with persistent accumulators and workspace memory
def triton_per_fused__fused_rms_norm_backward__to_copy_..._3(
    in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4,
    in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9,
    out_ptr0, out_ptr1, out_ptr3, out_ptr4, ws_ptr,
    xnumel, r0_numel, XBLOCK, RSPLIT_SIZE, NUM_STAGES):
    # 11 loads + 2 accumulators → ~200 regs/thread, EXCEEDS 128 budget
    accum0 = tl.full([R0_BLOCK], 0, tl.float32)  # [512] persists across loop
    accum1 = tl.full([R0_BLOCK], 0, tl.float32)  # [512] persists across loop
    for _ in tl.range(0, split_size, XBLOCK):
        tmp0 = tl.load(in_ptr0 + ...)     # [98304, 512]
        tmp1 = tl.load(in_ptr1 + ...)     # [98304, 512]
        tmp7 = tl.load(in_ptr2 + ...)     # [98304, 512]
        tmp13 = tl.load(in_ptr3 + ...)    # [98304, 512]
        tmp14 = tl.load(in_out_ptr0 + ...)# [98304, 512]
        tmp22 = tl.load(in_ptr4 + ...)    # scalar
        tmp32 = tl.load(in_ptr5 + ...)    # [98304, 1]
        tmp37 = tl.load(in_ptr6 + ...)    # [512]
        tmp40 = tl.load(in_ptr7 + ...)    # [512]
        tmp43 = tl.load(in_ptr8 + ...)    # [98304, 512]
        tmp46 = tl.load(in_ptr9 + ...)    # [98304, 512]
        # 3 inner reductions + 5 stores + 2 accumulator updates per iter
        # Spilled regs hit local memory EVERY iteration
    tl.store(ws_ptr + ..., accum0, ...)  # workspace for inter-block reduction
    tl.store(ws_ptr + ..., accum1, ...)

This same over-fusion pattern repeats across the backward pass, producing 9 MixOrderReductionGrid kernels with 6-19 loads each. The worst (kernel_34) has 19 loads.

Profiler data

H100 80GB SXM, torch.profiler:

Config                         Triton kernels   Self CUDA Time   Delta
2.11, mix_order=1 (default)    65               105.764ms        +6.3ms
2.11, mix_order=0              71               99.471ms         baseline
2.9.1 (no mix_order)           71               99.5ms           baseline
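
A sketch of how the Self CUDA time per config can be collected (illustrative; the exact profiling script is not part of this PR). It assumes compiled, x, and y from the reproducer above:

from torch.profiler import profile, ProfilerActivity

for _ in range(5):   # warm-up, includes compilation
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        compiled(x, y).backward()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        compiled(x, y).backward()
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20))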

Fix

Count unique read buffers across all subnodes in FusedMixOrderReductions.can_fuse_with. If the count exceeds mix_order_reduction_max_reads (default 10), reject the fusion:

all_reads = {dep.name for node in subnodes
             for dep in node.read_writes.reads if isinstance(dep, MemoryDep)}
if len(all_reads) > max_reads:
    return False

Uses all_reads rather than all_reads - all_writes because mutated buffers (in_out_ptr) are both read and written — they are still tl.load() calls. Each unique read maps 1:1 to a tl.load() in the generated RSPLIT loop. The check runs at scheduling time with zero compilation cost — it just counts buffer names from the existing read_writes dependency data.
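
A standalone sketch of the decision itself (not the inductor code; subnode_reads stands in for the per-subnode read-buffer name sets the scheduler already has in read_writes):

def fusion_within_read_budget(subnode_reads, max_reads=10):
    # Unique read buffers across the whole candidate fusion group;
    # each unique buffer becomes one tl.load() in the RSPLIT loop.
    all_reads = set().union(*subnode_reads) if subnode_reads else set()
    return len(all_reads) <= max_reads

# kernel_8-like case: 7 unique reads -> fuse
assert fusion_within_read_budget([{"buf0", "buf1", "buf2", "buf3"}, {"buf3", "buf4", "buf5", "buf6"}])
# kernel_3-like case: 11 unique reads -> reject
assert not fusion_within_read_budget([{f"buf{i}" for i in range(9)}, {"buf9", "buf10"}])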

Test Plan

  • OverFusionTest.test_max_reads_limits_fusion — 3-block transformer backward, verifies correctness and that mix_order_reduction still fires (not fully disabled); a rough sketch of its shape follows this list
  • Existing MixOrderReductionTest and NoMixOrderReductionTest suites unaffected
  • Verified on H100: step time with this fix matches 2.9.1 / mix_order=0 baseline
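
A rough, hedged sketch of the shape of that test (the real test lives in this PR; the helper name, tolerance, and scaffolding here are illustrative assumptions). It reuses the Model class from the reproducer above with fewer blocks:

import copy
import torch
from torch._inductor import metrics

def check_fusion_limit(x, y, atol=5e-2):
    model = Model(num_layers=3).cuda().bfloat16()
    ref = copy.deepcopy(model)
    metrics.reset()
    compiled = torch.compile(model, dynamic=False, fullgraph=True)
    compiled(x, y).backward()           # compiled forward + backward
    ref(x, y).backward()                # eager reference
    for p_c, p_r in zip(model.parameters(), ref.parameters()):
        torch.testing.assert_close(p_c.grad, p_r.grad, atol=atol, rtol=atol)
    # mix_order_reduction should still fire; max_reads only limits over-fusion
    assert metrics.codegen_mix_order_reduction > 0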

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

@pytorch-bot Bot commented Apr 6, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/179494

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 50a6517 with merge base a11cc39:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot Bot commented Apr 6, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@abaybektursun (Contributor, Author)

@pytorchbot label "release notes: composability"

@shunting314 (Contributor)

We should not disable it. The feature is helpful in general. We should debug the outlier and fix it.

@abaybektursun force-pushed the fix/disable-mix-order-reduction-default branch from 2c7fa3e to bde32d4 on April 6, 2026 19:39
@abaybektursun changed the title from "[inductor] Disable mix_order_reduction by default (9% backward regression)" to "[inductor] Limit mix_order_reduction fusion size to prevent over-fusion regression" on Apr 6, 2026
@abaybektursun (Contributor, Author)

@shunting314 Absolutely, on it. I will provide the model details and isolate the kernels.

@abaybektursun force-pushed the fix/disable-mix-order-reduction-default branch from bde32d4 to e88cfab on April 6, 2026 20:30
@abaybektursun changed the title from "[inductor] Limit mix_order_reduction fusion size to prevent over-fusion regression" to "[inductor] Fix mix_order_reduction over-fusion via external read count limit" on Apr 6, 2026
@linux-foundation-easycla Bot commented Apr 6, 2026

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: abaybektursun / name: Abay Bektursun (50a6517)

@abaybektursun force-pushed the fix/disable-mix-order-reduction-default branch 5 times, most recently from 8638e8c to ba35c56, on April 6, 2026 21:05
@abaybektursun changed the title from "[inductor] Fix mix_order_reduction over-fusion via external read count limit" to "[inductor] Fix mix_order_reduction over-fusion via post-compilation spill check" on Apr 6, 2026
@abaybektursun force-pushed the fix/disable-mix-order-reduction-default branch 2 times, most recently from 6d072c3 to fb79e34, on April 6, 2026 22:12
@shunting314 (Contributor)

A few comments

  1. How much compilation time increase does this PR introduce? Now every time we codegen a mix-order reduction, we would need to compile the Triton kernel, which can be slow.
  2. Rather than cancel the mix-order reduction after we have made the fusion decision, can we decide not to fuse in the first place?
  3. Does it work by checking the number of loads in the kernel? If it exceeds a threshold, we don't fuse?

Comment on lines +1137 to +1142
# mix_order_reduction should still be used (not fully disabled)
self.assertGreater(
    metrics.codegen_mix_order_reduction,
    0,
    "Mix order reduction should still be triggered with spill check",
)
Contributor

The comment/message say the opposite of what the code checks.

Contributor Author

Simplified to self.assertGreater(metrics.codegen_mix_order_reduction, 0) with a comment that says what the code checks: max_reads should limit over-fusion, not disable mix_order entirely.

@abaybektursun force-pushed the fix/disable-mix-order-reduction-default branch from fb79e34 to efcd9ad on April 6, 2026 23:23
@abaybektursun changed the title from "[inductor] Fix mix_order_reduction over-fusion via post-compilation spill check" to "[inductor] Fix mix_order_reduction over-fusion via load count check" on Apr 6, 2026
@abaybektursun (Contributor, Author)

@shunting314
▎ 1. How much compilation time increase does this PR introduce?
The post-compilation approach added ~200-600ms per mix_order instance (Triton JIT compilation to check n_spills), totaling ~2-5s extra per torch.compile for our 11-layer model. That's too expensive. I've reverted to a scheduler-level check that has zero compilation cost.

▎ 2. Rather than cancel the mix-order reduction after we have made the fusion decision, can we decide not to fuse in the first place?
The updated PR moves the check to sub_node_can_fuse, so it prevents the over-fusion before any codegen happens. Is this alright?

▎ 3. Does it work by checking the number of loads in the kernel? If it exceeds a threshold, we don't fuse?
Yes. Each external read buffer counted at scheduler time maps 1:1 to a tl.load() in the RSPLIT loop. Profiled on our model (dim=512, bf16, H100): 7 external reads = clean, 11 = +6.3ms/step regression from register spills. A threshold of 10 catches the over-fused kernels (11-19 reads) while preserving the moderate fusions (6-10 reads).

@liangel-02 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label on Apr 7, 2026
@pytorch-bot Bot commented Apr 14, 2026

The following ciflow label(s) have been added but CI has not been triggered yet because the workflows are awaiting approval:

  • ciflow/inductor
  • ciflow/torchtitan

Once a maintainer approves the workflows (scroll to the bottom of the PR page), the corresponding CI jobs will be triggered automatically. Please ping one of the reviewers if you do not have access to approve and run workflows.

@abaybektursun force-pushed the fix/disable-mix-order-reduction-default branch from 0c7e530 to da0dcb9 on April 17, 2026 01:17
@jansel (Contributor) commented Apr 17, 2026

@pytorchbot merge

@pytorch-bot Bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Apr 17, 2026
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@jansel (Contributor) commented Apr 20, 2026

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@shunting314 (Contributor)

@abaybektursun (Contributor, Author) commented Apr 20, 2026

@shunting314

  1. ROCm tolerance: Our test OverFusionTest.test_max_reads_limits_fusion fails on ROCm MI355 because tol=1e-2 is too tight for bf16 gradients through a 3-block transformer backward on AMD GPUs. Raised to tol=5e-2. Pushed.

  2. inductor_huggingface: DebertaV2ForMaskedLM fails with ModuleNotFoundError: Could not import module 'DebertaV2ForMaskedLM'; the eager model itself can't load. Not related to this PR.

@abaybektursun force-pushed the fix/disable-mix-order-reduction-default branch from a0fa9f7 to 50a6517 on April 21, 2026 13:49
@jansel (Contributor) commented Apr 21, 2026

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Stonepia added a commit to chuanqi129/pytorch that referenced this pull request Apr 30, 2026
Drop the inaccurate 'mirroring the ROCm tolerance bump (1e-2 -> 5e-2)'
justification for the XPU tolerance: no such ROCm-specific bump exists
in this test (the 5e-2 baseline has been the only value since the test
was introduced in pytorch#179494). Replace it with an accurate explanation of
why the XPU eager-vs-compiled bf16 backward drift exceeds the CUDA
baseline.

Comment-only change; no behavioral difference.

intel/torch-xpu-ops#3509

Co-authored-by: Claude <noreply@anthropic.com>
Stonepia added a commit to chuanqi129/pytorch that referenced this pull request Apr 30, 2026
Address @Stonepia's review of #5:

- Review 2 (ROCm tolerance reference): the previous comment claimed the
  XPU bump was 'mirroring the ROCm tolerance bump (1e-2 -> 5e-2) applied
  for the same reason'. There is no such ROCm-specific bump in this test
  -- the 5e-2 baseline has been the only value since pytorch#179494 introduced
  the test. The misleading reference is dropped.

- Review 1 (root cause is unverified): the reviewer's empirical run on
  Intel Data Center GPU Max 1550 (PVC) shows the test passes at the
  5e-2 baseline (rejected_mix_order_reduction_fusion = 15, far above 0),
  contradicting the original PR's claim that the rejection counter
  assertion is the failing one. The XPU CI disable issue (pytorch#3509) lacks
  a traceback, so the actual failing assertion remains unknown.

  The hard rules forbid adding @skipIfXpu, so the next-most-defensible
  change is kept: the XPU-only tolerance bump on the same(grad_ref,
  grad_act) check, which targets the most likely remaining culprit
  (different XPU SKU on linux.idc.xpu vs PVC producing larger bf16
  drift) without weakening regression coverage:

  * CUDA and ROCm tolerances are unchanged (no behavioral change off
    XPU).
  * Both metric assertions (codegen_mix_order_reduction > 0 and
    rejected_mix_order_reduction_fusion > 0) remain unchanged on every
    backend, so the pytorch#179423 over-fusion regression is still gated.
  * The synthetic >10-reads helper added in the original PR is already
    gone (removed in iteration 1) -- the transformer pattern alone
    drives the rejection counter, exactly as the reviewer noted.

  The comment is rewritten to honestly reflect what is and is not
  known: it documents that the failing assertion was never identified,
  records the PVC empirical result, and states why the bump is scoped
  to XPU only.

Comment-only behavioral change relative to iteration 2; no logic
change.

intel/torch-xpu-ops#3509

Co-authored-by: Claude <noreply@anthropic.com>

Labels

  • ciflow/inductor
  • ciflow/torchtitan (Run TorchTitan integration tests)
  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • module: inductor
  • open source
  • release notes: composability (release notes category)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[inductor] Backward pass 9% slower in 2.11 vs 2.9.1 due to over-fusion of rms_norm_backward

6 participants