Skip to content

Commit 26807dc

Browse files
Revert "[PT2][fusion] ban fusions with large accumulated reads (#157563)"
This reverts commit c062550. Reverted #157563 on behalf of https://github.com/clee2000 because it broke test_linear_and_cel on main https://hud.pytorch.org/pytorch/pytorch/commit/c062550a3598d27c2d6572db7c0f4ff90a84cc84, possibly by causing an OOM. It is also broken on the PR itself; Dr. CI's classification is wrong (it claims the test is disabled by an issue, but that issue is for a different test). Also, I'm pretty sure the expected-results JSON is supposed to have a ton of empty lines — it's to prevent merge conflicts — and I will add that check to the linter ([comment](#157563 (comment)))
1 parent 4f36743 commit 26807dc

9 files changed

Lines changed: 106 additions & 118 deletions

File tree

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,89 @@
1-
add_loop_eager,compile_time_instruction_count,2996000000,0.015
1+
add_loop_eager,compile_time_instruction_count,3017000000,0.015
2+
3+
4+
25
add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025
3-
add_loop_inductor,compile_time_instruction_count,33090000000,0.015
4-
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42660000000,0.025
5-
add_loop_inductor_gpu,compile_time_instruction_count,29690000000,0.015
6+
7+
8+
9+
add_loop_inductor,compile_time_instruction_count,29490000000,0.015
10+
11+
12+
13+
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38760000000,0.025
14+
15+
16+
17+
add_loop_inductor_gpu,compile_time_instruction_count,26000000000,0.015
18+
19+
20+
621
basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015
7-
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18830000000,0.015
8-
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17460000000,0.015
9-
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11020000000,0.2
22+
23+
24+
25+
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18490000000,0.015
26+
27+
28+
29+
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015
30+
31+
32+
33+
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10297683772,0.2
34+
35+
36+
1037
update_hint_regression,compile_time_instruction_count,1673000000,0.02
38+
39+
40+
1141
sum_floordiv_regression,compile_time_instruction_count,986800000,0.015
12-
symint_sum,compile_time_instruction_count,3184000000,0.015
42+
43+
44+
45+
symint_sum,compile_time_instruction_count,3166000000,0.015
46+
47+
48+
1349
symint_sum_loop,compile_time_instruction_count,4202000000,0.015
50+
51+
52+
1453
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2103000000,0.015
54+
55+
56+
1557
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6004000000,0.015
58+
59+
60+
1661
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8783000000,0.015
62+
63+
64+
1765
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1940000000,0.015
66+
67+
68+
1869
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3885000000,0.015
70+
71+
72+
1973
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0.015
20-
mm_loop_inductor_gpu,compile_time_instruction_count,4365000000,0.015
21-
mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8184000000,0.015
74+
75+
76+
77+
mm_loop_inductor_gpu,compile_time_instruction_count,4324000000,0.015
78+
79+
80+
81+
mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8116000000,0.015
82+
83+
84+
2285
basic_NestedModule_eager,compile_time_instruction_count,8152524390,0.015
86+
87+
88+
2389
basic_InlineMod_eager,compile_time_instruction_count,7255000000,0.015

test/inductor/test_memory.py

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -306,57 +306,6 @@ def f(a, b, c):
306306
expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2
307307
self.assertLess(peak_mem, expected_bound)
308308

309-
def test_fusion_acc_large_reads(self):
310-
def f(x, y, z):
311-
res = torch.zeros_like(x[0])
312-
for i in range(4):
313-
temp = torch.matmul(x, y) + z
314-
res = res + temp
315-
return res
316-
317-
N = 128
318-
x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
319-
y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
320-
z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
321-
322-
# CASE 1: no restriction on the amount of accumulation
323-
with config.patch({"realize_acc_reads_size_threshold": float("inf")}):
324-
f_compiled = torch.compile(f)
325-
code = run_and_get_triton_code(f_compiled, x, y, z)
326-
(
327-
FileCheck()
328-
.check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3")
329-
.run(code)
330-
)
331-
332-
# CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes)
333-
# at most 12 / 4 = 3 reads can be accumulated during fusion
334-
with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}):
335-
f_compiled = torch.compile(f)
336-
code = run_and_get_triton_code(f_compiled, x, y, z)
337-
(
338-
FileCheck()
339-
.check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,")
340-
.check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,")
341-
.run(code)
342-
)
343-
344-
# CASE 3: no such fusion allowed
345-
with config.patch({"realize_acc_reads_size_threshold": N**2}):
346-
f_compiled = torch.compile(f)
347-
code = run_and_get_triton_code(f_compiled, x, y, z)
348-
(
349-
FileCheck()
350-
.check("triton_poi_fused_add_0.run(buf1, arg2_1,")
351-
.check("triton_poi_fused_add_0.run(buf3, arg2_1,")
352-
.check("triton_poi_fused_add_0.run(buf4, buf3,")
353-
.check("triton_poi_fused_add_0.run(buf6, arg2_1,")
354-
.check("triton_poi_fused_add_0.run(buf7, buf6,")
355-
.check("triton_poi_fused_add_0.run(buf9, arg2_1,")
356-
.check("triton_poi_fused_add_0.run(buf10, buf9,")
357-
.run(code)
358-
)
359-
360309

361310
if __name__ == "__main__":
362311
from torch._inductor.test_case import run_tests

test/inductor/test_online_softmax.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
instantiate_parametrized_tests,
1414
IS_LINUX,
1515
parametrize,
16-
serialTest,
1716
)
1817
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA
1918

@@ -78,17 +77,12 @@ def f(x):
7877
out, source_codes = run_and_get_code(f, x)
7978
return source_codes[0]
8079

81-
@serialTest()
8280
def test_codegen_3pass_softmax_due_to_disable(self):
83-
with inductor_config.patch(
84-
online_softmax=False,
85-
realize_acc_reads_size_threshold=float("inf"),
86-
):
81+
with inductor_config.patch(online_softmax=False):
8782
wrapper_code = self.get_softmax_wrapper()
8883

8984
self.assertEqual(wrapper_code.count("for r0_offset in"), 3)
9085

91-
@serialTest()
9286
@parametrize("V", [2048, 50304])
9387
@parametrize("use_log_softmax", [False, True])
9488
def test_codegen_online_softmax(self, use_log_softmax, V):

torch/_inductor/choices.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -365,10 +365,6 @@ def can_fuse(
365365
WhyNoFuse(node1, node2)("Fusion will increase peak memory")
366366
return False
367367

368-
if scheduler.fusion_accumulate_large_reads(node1, node2):
369-
WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads")
370-
return False
371-
372368
return True
373369

374370
@staticmethod

torch/_inductor/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,6 @@ def use_autoheuristic(name: str) -> bool:
574574

575575
# Threshold to prevent excessive accumulation of ops in one buffer during lowering
576576
realize_acc_reads_threshold = 8
577-
realize_acc_reads_size_threshold = 3 * (1024**3)
578577

579578
# fallback to eager for random/dropout, this is slow but useful for debugging
580579
fallback_random = False

torch/_inductor/graph.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@
123123
from torch.fx.graph import Graph
124124

125125
from .codegen.wrapper import PythonWrapperCodegen
126-
from .dependencies import Dep
127126
from .scheduler import BaseSchedulerNode
128127

129128
CompiledModule = Union[ModuleType, FileBackedGraphModule]
@@ -486,9 +485,6 @@ def __init__(
486485

487486
self.bw_donated_idxs = get_donated_idxs()
488487

489-
# Cache for dep size hints to avoid expensive recomputation
490-
self.dep_size_hint_cache: dict[Dep, int] = {}
491-
492488
def freeze_runtime_asserts(self) -> None:
493489
self._shape_env.freeze_runtime_asserts()
494490

@@ -574,23 +570,6 @@ def has_feature(
574570
assert isinstance(feature, BackendFeature), feature
575571
return feature in self.get_backend_features(get_device_type(device))
576572

577-
def get_dep_size_hint(self, dep: Dep) -> int:
578-
"""
579-
Get the size hint for a dependency with caching to avoid expensive recomputation.
580-
"""
581-
if dep not in self.dep_size_hint_cache:
582-
res = 0
583-
try:
584-
if not dep.has_unbacked_symbols():
585-
res = dep.numbytes_hint()
586-
except KeyError:
587-
# In at least one test (test/inductor/test_torchbind.py) we
588-
# create a StarDep that doesn't exist in the graph and calling
589-
# `has_unbacked_symbols()` throws an error.
590-
pass
591-
self.dep_size_hint_cache[dep] = res
592-
return self.dep_size_hint_cache[dep]
593-
594573
def get_current_device_or_throw(self) -> torch.device:
595574
if device := self.current_device:
596575
return device

torch/_inductor/ir.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7829,10 +7829,6 @@ def create(data: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]:
78297829

78307830

78317831
class StorageBox(MutableBox):
7832-
"""
7833-
StorageBox allows in-place mutation of Tensors
7834-
"""
7835-
78367832
def is_input_buffer(self) -> bool:
78377833
if isinstance(self.data, (InputBuffer, ReinterpretView)):
78387834
return self.data.get_name() in V.graph.graph_inputs
@@ -7882,17 +7878,10 @@ def realize_hint(self) -> None:
78827878
):
78837879
self.realize()
78847880

7885-
def has_accumulated_enough_reads_by_size(self) -> bool:
7886-
return (
7887-
sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads())
7888-
> config.realize_acc_reads_size_threshold
7889-
)
7890-
78917881
def has_exceeded_max_reads(self) -> bool:
78927882
return isinstance(self.data, Pointwise) and (
78937883
self.num_reads() > config.realize_acc_reads_threshold
78947884
or self.has_large_inner_fn()
7895-
or self.has_accumulated_enough_reads_by_size()
78967885
)
78977886

78987887
def should_realize_on_reuse(self, users: int) -> bool:

torch/_inductor/memory.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,19 @@ def get_freeable_input_buf(
7878
A dictionary containing all freeable input buffers, keyed by their names.
7979
"""
8080

81+
# this function is copied from torch/_inductor/scheduler.py
82+
# TODO: would be nice to remove the try/except block for both places
8183
def _dep_size_hint(dep: Dep) -> int:
82-
return V.graph.get_dep_size_hint(dep)
84+
res = 0
85+
try:
86+
if not dep.has_unbacked_symbols():
87+
res = dep.numbytes_hint()
88+
except KeyError:
89+
# In at least one test (test/inductor/test_torchbind.py) we
90+
# create a StarDep that doesn't exist in the graph and calling
91+
# `has_unbacked_symbols()` throws an error.
92+
pass
93+
return res
8394

8495
# get freeable input buffers' successor nodes and their sizes
8596
# note that different deps can have the same name, so we use name as keys

torch/_inductor/scheduler.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2051,12 +2051,15 @@ class Scheduler:
20512051
optimizations such as fusion, reorder, and graph partition.
20522052
"""
20532053

2054+
__dep_size_hint_cache: dict[Dep, int]
2055+
20542056
def __init__(self, nodes: list[ir.Operation]) -> None:
20552057
with dynamo_timed("Scheduler.__init__"):
20562058
self._init(nodes)
20572059

20582060
def _init(self, nodes: list[ir.Operation]) -> None:
20592061
super().__init__()
2062+
self.__dep_size_hint_cache = {}
20602063
V.graph.scheduler = self
20612064
self.backends: dict[torch.device, BaseScheduling] = {}
20622065
self.post_grad_graph_id = next(_post_grad_graph_counter)
@@ -3502,17 +3505,6 @@ def _find_single_user_inputs(
35023505
return True
35033506
return False
35043507

3505-
def fusion_accumulate_large_reads(
3506-
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
3507-
) -> bool:
3508-
all_reads = (node1.read_writes.reads | node2.read_writes.reads) - (
3509-
node1.read_writes.writes | node2.read_writes.writes
3510-
)
3511-
return (
3512-
sum(self.dep_size_hint(dep) for dep in all_reads)
3513-
> config.realize_acc_reads_size_threshold
3514-
)
3515-
35163508
def are_long_distant_nodes(
35173509
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
35183510
) -> bool:
@@ -4018,7 +4010,20 @@ def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool:
40184010
return False
40194011

40204012
def dep_size_hint(self, dep: Dep) -> int:
4021-
return V.graph.get_dep_size_hint(dep)
4013+
res = 0
4014+
if dep not in self.__dep_size_hint_cache:
4015+
try:
4016+
if not dep.has_unbacked_symbols():
4017+
res = dep.numbytes_hint()
4018+
except KeyError:
4019+
# In at least one test (test/inductor/test_torchbind.py) we
4020+
# create a StarDep that doesn't exist in the graph and calling
4021+
# `has_unbacked_symbols()` throws an error.
4022+
pass
4023+
self.__dep_size_hint_cache[dep] = res
4024+
else:
4025+
res = self.__dep_size_hint_cache[dep]
4026+
return res
40224027

40234028
def score_fusion_memory(
40244029
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode

0 commit comments

Comments
 (0)