[inductor] split reduction even if all reads are broadcasted #167894
shunting314 wants to merge 4 commits into gh/shunting314/263/base
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167894
Note: Links to docs will display an error until the docs builds have been completed. ❌ 13 New Failures, 4 Cancelled Jobs, 4 Unrelated Failures as of commit 0b28245 with merge base 4b9418a.
NEW FAILURES - The following jobs have failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
With split reduction we can speedup the following (extreme) kernel by 48x
```
# 56ms -> 1.163ms
import torch
from triton.testing import do_bench

def f(x):
    return x.sum(dim=(0, 1))

x = torch.randn(100000000, 1, 2, device="cuda").expand(-1, 2, -1)
opt_f = torch.compile(f)
ref = f(x)
act = opt_f(x)
torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)
ms = do_bench(lambda: opt_f(x))
print(f"ms={ms:.3f}")
```
I'm not confident whether this change will break things; let's wait for CI.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben
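To illustrate what the benchmark above exercises: `expand()` gives the input a stride-0 dimension, so every read of the tensor is a broadcasted read, and the reduction spans an enormous dim 0. Below is a minimal, hedged sketch of the split-reduction idea in plain eager PyTorch (not Inductor's actual codegen): partition the large reduction dim into chunks, reduce each chunk to a partial result, then reduce the partials. On GPU each chunk can map to its own thread block, which is where the parallelism (and the speedup) comes from. The chunk count `num_splits` here is an arbitrary illustrative choice, and the example runs on CPU with a smaller tensor so it is easy to try.

```python
import torch

# A broadcasted input: expand() materializes no new data; the expanded
# dim 1 has stride 0, so every read of x along it is a broadcasted read.
x = torch.randn(1_000_000, 1, 2).expand(-1, 2, -1)
assert x.stride(1) == 0  # the broadcasted dimension

# Single-pass reduction over dims (0, 1) -> shape (2,)
ref = x.sum(dim=(0, 1))

# Split-reduction sketch: reduce each chunk of the large dim to a
# partial result, then reduce the partials. num_splits is illustrative.
num_splits = 8
partials = torch.stack(
    [chunk.sum(dim=(0, 1)) for chunk in x.chunk(num_splits, dim=0)]
)  # shape (num_splits, 2)
act = partials.sum(dim=0)  # final result, shape (2,)

# The two strategies agree up to floating-point reassociation error.
torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)
```

Note that split reduction reassociates the floating-point sum, which is why the comparison uses a tolerance rather than exact equality.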
[ghstack-poisoned]
@pytorchbot merge -i |
Merge started. Your change will be merged while ignoring the following 4 checks: inductor / inductor-cpu-test / test (cpu_inductor_torchbench, 2, 2, linux.2xlarge.amx), inductor / inductor-cpu-test / test (dynamic_cpu_inductor_torchbench, 2, 2, linux.2xlarge.amx), inductor / unit-test / inductor-pallas-test / test (inductor-pallas, 1, 1, linux.g5.12xlarge.nvidia.gpu), inductor / inductor-test / test (inductor_torchbench, 2, 2, linux.g5.4xlarge.nvidia.gpu). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m2-15), trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m1-14). Details for Dev Infra team: raised by workflow job.
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as stale.
Didn't land because of an accuracy failure for 'sam'. I may find some time to debug this further.
Stack from ghstack (oldest at bottom):
With split reduction we can speedup the following (extreme) kernel by 48x
I'm not confident whether this change will break things; let's wait for CI.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @chenyang78