[inductor] split reduction even if all reads are broadcasted #167894

Open
shunting314 wants to merge 4 commits into gh/shunting314/263/base from gh/shunting314/263/head

Conversation

@shunting314 (Contributor) commented Nov 15, 2025

Stack from ghstack (oldest at bottom):

With split reduction we can speed up the following (extreme) kernel by 48x:

```
# 56ms -> 1.163ms

import torch
from triton.testing import do_bench

def f(x):
    return x.sum(dim=(0, 1))

x = torch.randn(100000000, 1, 2, device="cuda").expand(-1, 2, -1)
opt_f = torch.compile(f)
ref = f(x)
act = opt_f(x)

torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)
ms = do_bench(lambda: opt_f(x))
print(f"ms={ms:.3f}")
```
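
For context on why every read in this kernel counts as "broadcasted": `expand` returns a view whose broadcast dimensions have stride 0, so the reduction re-reads the same physical elements rather than touching new data. A minimal sketch on CPU with small shapes (the shapes here are illustrative, not the benchmark's):

```python
# Sketch: expand() creates a stride-0 view; the broadcast dimension
# aliases the same storage instead of adding new elements.
import torch

x = torch.randn(4, 1, 2)
y = x.expand(-1, 2, -1)  # broadcast dim 1 from size 1 to size 2

print(x.stride())  # (2, 2, 1)
print(y.stride())  # (2, 0, 1) -- stride 0 marks the broadcast dimension
assert y.stride(1) == 0
# Both broadcast "copies" alias the same storage:
assert torch.equal(y[:, 0, :], y[:, 1, :])
```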

I'm not confident whether this change will break things; let's wait for CI.
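For readers unfamiliar with the optimization: split reduction breaks one long reduction into a parallel first stage of partial reductions plus a much smaller second stage. A hand-written sketch of the idea in plain PyTorch (not Inductor's actual codegen; `split_sum` and `num_splits` are illustrative names, and float64 is used only to keep the correctness check tolerance-free):

```python
# Sketch of the split-reduction idea: instead of one pass reducing the whole
# leading dimension per output element, split that dimension into chunks,
# reduce each chunk independently (stage 1), then reduce the partials (stage 2).
import torch

def split_sum(x: torch.Tensor, num_splits: int = 1024) -> torch.Tensor:
    n = x.shape[0]
    pad = (-n) % num_splits  # zero-pad so the dim divides evenly
    if pad:
        x = torch.cat([x, x.new_zeros(pad, *x.shape[1:])])
    partial = x.view(num_splits, -1, *x.shape[1:]).sum(dim=1)  # stage 1
    return partial.sum(dim=0)                                  # stage 2

x = torch.randn(100003, 2, dtype=torch.float64)
assert torch.allclose(split_sum(x), x.sum(dim=0))
```

Stage 1 exposes `num_splits`-way parallelism even when there is only one output element, which is exactly the situation in the benchmark above.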

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @chenyang78

pytorch-bot bot commented Nov 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167894

Note: Links to docs will display an error until the docs builds have been completed.

❌ 13 New Failures, 4 Cancelled Jobs, 4 Unrelated Failures

As of commit 0b28245 with merge base 4b9418a:

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

shunting314 added a commit that referenced this pull request Nov 15, 2025
shunting314 added a commit that referenced this pull request Nov 15, 2025
@shunting314 shunting314 added the topic: not user facing topic category label Nov 18, 2025
@shunting314 (Contributor, Author) commented:
@pytorchbot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 18, 2025
@pytorchmergebot (Collaborator) commented:
Merge failed

Reason: 2 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m2-15), trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m1-14)

Details for Dev Infra team: raised by workflow job.

shunting314 added a commit that referenced this pull request Dec 6, 2025
shunting314 added a commit that referenced this pull request Dec 6, 2025
github-actions bot commented Feb 4, 2026

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Feb 4, 2026
@github-actions github-actions bot closed this Mar 6, 2026
@shunting314 shunting314 reopened this Mar 6, 2026
@shunting314 shunting314 removed the Stale label Mar 6, 2026
@shunting314
Copy link
Contributor Author

This didn't land because of an accuracy failure for 'sam'. I may find some time to debug it further.
