Skip to content

Commit 648a664

Browse files
yf225 authored and pytorchmergebot committed
[Helion + torch.compile] Fix MultiOutput write deps to eliminate fusion workarounds (#177062)
MultiOutput.get_read_writes() now produces proper MemoryDep writes from FixedLayout instead of inheriting StarDep from InputsKernel. This lets the scheduler match template-output writes with downstream epilogue reads without the manual StarDep→MemoryDep rewrite that was in FusedSchedulerNode.fuse().

Also fixes score_fusion_memory to use name-based matching for templates (a view/reshape between template output and epilogue can produce different index expressions) and fixes the buggy duplicate isinstance check.

Pull Request resolved: #177062
Approved by: https://github.com/jansel
1 parent a7e6b9d commit 648a664

2 files changed

Lines changed: 39 additions & 29 deletions

File tree

torch/_inductor/ir.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8834,6 +8834,30 @@ def get_inputs_that_alias_output(self) -> Sequence[str]:
88348834
and len(inp.get_inputs_that_alias_output()) > 0
88358835
]
88368836

8837+
def get_read_writes(self) -> dependencies.ReadWrites:
8838+
# Reads: StarDep on parent (we don't know which elements of the
8839+
# packed output we index into — conservative is correct).
8840+
reads: OrderedSet[dependencies.Dep] = OrderedSet()
8841+
for inp in self.inputs:
8842+
if isinstance(inp, IRNode):
8843+
reads.add(dependencies.StarDep(inp.get_name()))
8844+
8845+
# Writes: build proper MemoryDep from our FixedLayout so the
8846+
# scheduler can match our write with downstream epilogue reads.
8847+
name = self.get_name()
8848+
indexer = self.get_layout().make_indexer()
8849+
8850+
def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any:
8851+
assert len(rindex) == 0
8852+
return ops.store(name, indexer(index), "fake")
8853+
8854+
write_rw = dependencies.extract_read_writes(dummy, self.get_size(), ())
8855+
return dependencies.ReadWrites(
8856+
reads=reads,
8857+
writes=write_rw.writes,
8858+
index_exprs=OrderedSet(),
8859+
)
8860+
88378861

88388862
class AllocatingMultiOutput(MultiOutput):
88398863
"""MultiOutput with Inductor-controlled allocation for .out() variant ops.

torch/_inductor/scheduler.py

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,28 +1913,7 @@ def fuse(
19131913
assert node1.scheduler is node2.scheduler
19141914
assert isinstance(node1, (SchedulerNode, FusedSchedulerNode))
19151915
if node1.is_template() and isinstance(node2, ExternKernelSchedulerNode):
1916-
# Fuse multi outputs template and its outputs
1917-
# * Node1 has memorydep of MultiOutput in reads
1918-
# * Node2 has StarDep of MultiOutput in writes
1919-
# Rewrite the Node2' StarDep to MemoryDep, because calculate score_fusion_memory
1920-
# of the template node and its epilogue requires the same type of dependencies
1921-
assert isinstance(node2.node, MultiOutput)
1922-
assert len(node2.read_writes.writes) == 1
1923-
assert isinstance(next(iter(node2.read_writes.writes)), StarDep)
1924-
name = next(iter(node2.read_writes.writes)).name
1925-
template_nodes = [node for node in node1.get_nodes() if node.is_template()]
1926-
assert len(template_nodes) == 1
1927-
template_node = template_nodes[0]
1928-
assert len(template_node.read_writes.writes) == 1
1929-
write = next(iter(template_node.read_writes.writes))
1930-
assert isinstance(write, MemoryDep)
1931-
node2.read_writes.writes = OrderedSet(
1932-
[
1933-
MemoryDep(
1934-
name, write.index, write.var_names, write.size, write.mode
1935-
),
1936-
]
1937-
)
1916+
assert isinstance(node2.node, ir.MultiOutput)
19381917
else:
19391918
assert isinstance(node2, (SchedulerNode, FusedSchedulerNode))
19401919
nodes = list(itertools.chain(node1.get_nodes(), node2.get_nodes()))
@@ -6148,21 +6127,28 @@ def _construct_return_value(
61486127
score = MixOrderReduction.get_fusion_score(node1, node2)
61496128
return _construct_return_value(score, 0, True)
61506129

6151-
# for evaluating fusion memory scores of UserDefinedTritonKernel,
6152-
# we use a slightly different logic which allows matching StarDep with MemoryDep in certain scenarios.
6153-
# (See the checks we make in `can_fuse_epilogue()` that makes this possible)
6130+
# For UserDefinedTritonKernel, the write deps are StarDep that won't
6131+
# match the epilogue's MemoryDep via set intersection. For templates,
6132+
# a view/reshape between the template output and epilogue can produce
6133+
# different index expressions that don't match via set intersection.
6134+
# Fall back to name-based matching so that the fusion score reflects
6135+
# the actual shared buffers.
61546136
if (
6155-
isinstance(node1.node, ir.UserDefinedTritonKernel)
6156-
and node1.node.can_fuse_epilogue()
6137+
(
6138+
isinstance(node1.node, ir.UserDefinedTritonKernel)
6139+
and node1.node.can_fuse_epilogue()
6140+
)
6141+
or node1.is_template()
6142+
or node2.is_template()
61576143
):
61586144
node1_deps = node1.read_writes.reads | node1.read_writes.writes
61596145
node2_deps = node2.read_writes.reads | node2.read_writes.writes
61606146

61616147
def _match(dep1: Dep, dep2: Dep):
61626148
if dep1 == dep2:
61636149
return True
6164-
if (isinstance(dep1, StarDep) and isinstance(dep2, MemoryDep)) or (
6165-
isinstance(dep1, StarDep) and isinstance(dep2, MemoryDep)
6150+
if isinstance(dep1, (StarDep, MemoryDep)) and isinstance(
6151+
dep2, (StarDep, MemoryDep)
61666152
):
61676153
return dep1.name == dep2.name
61686154
return False

0 commit comments

Comments
 (0)