Skip to content

Commit edf1a92

Browse files
yf225 authored and pytorchmergebot committed
[Re-land] [Helion + torch.compile] Refactor TemplateBuffer as extensible base class (#177367)
This is a reland of #177063. Move common fields and methods from TritonTemplateBuffer up to TemplateBuffer so that external template backends (e.g. Helion) can reuse the same mutation-tracking and prologue-fusion infrastructure: - Add mutated_inputs, allowed_prologue_inps params to TemplateBuffer.__init__ - Build mutation_outputs list in base class (parallel to ExternKernel.mutation_outputs) - Move get_allowed_prologue_inps() to base class - Extract _read_deps_from_inputs() helper from extract_read_writes() - Remove can_fuse_multi_output_epilogue() (always returned False, unused) - Simplify TritonTemplateBuffer.__init__() to delegate to super() get_outputs() stays on TritonTemplateBuffer since it is the only subclass that currently passes mutated_inputs; other subclasses (CppTemplateBuffer, CuteDSLTemplateBuffer, etc.) manage their own output lists independently. Pull Request resolved: #177367 Approved by: https://github.com/shunting314 ghstack dependencies: #177302
1 parent c658c67 commit edf1a92

2 files changed

Lines changed: 49 additions & 47 deletions

File tree

torch/_inductor/ir.py

Lines changed: 49 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -5206,15 +5206,18 @@ def constant_to_device(self, device: torch.device) -> IRNode:
52065206

52075207
class TemplateBuffer(OperationBuffer):
52085208
"""
5209-
Represents a Triton (in the future other type) of template operator
5210-
that we can fuse an epilogue onto.
5209+
Base class for template operators that support epilogue and prologue fusion.
5210+
Subclasses: TritonTemplateBuffer (built-in Triton templates),
5211+
HelionTemplateBuffer (Helion kernels), etc.
52115212
"""
52125213

52135214
def __init__(
52145215
self,
52155216
layout: OutputSpec,
52165217
inputs: Sequence[IRNode],
52175218
make_kernel_render: Callable[..., Any] | None,
5219+
mutated_inputs: Iterable[IRNode] | None = None,
5220+
allowed_prologue_inps: OrderedSet[str] | None = None,
52185221
) -> None:
52195222
super().__init__(name=None, layout=layout)
52205223
self.inputs = InputsKernel.unwrap_storage(inputs)
@@ -5224,9 +5227,43 @@ def __init__(
52245227
# Annotations dict for storing metadata (e.g., KernelTemplateChoice)
52255228
self.annotations: dict[str, Any] = {}
52265229

5230+
# Inputs that the kernel mutates in-place
5231+
self.mutated_inputs = mutated_inputs
5232+
self.mutation_outputs: list[MutationOutput] = []
5233+
if mutated_inputs is not None:
5234+
first_input = self.inputs[0]
5235+
assert isinstance(first_input, IRNode), type(first_input)
5236+
device = first_input.get_device()
5237+
self.mutation_outputs = [
5238+
MutationOutput(NoneLayout(device=device), buf, self)
5239+
for buf in mutated_inputs
5240+
]
5241+
# Input buffer names eligible for prologue fusion.
5242+
self.allowed_prologue_inps: OrderedSet[str] = (
5243+
allowed_prologue_inps or OrderedSet()
5244+
)
5245+
52275246
def get_read_writes(self) -> dependencies.ReadWrites:
52285247
return self.extract_read_writes(normalize=True)
52295248

5249+
def _read_deps_from_inputs(self, normalize: bool) -> OrderedSet[dependencies.Dep]:
5250+
"""Build read dependencies from all inputs."""
5251+
reads: OrderedSet[dependencies.Dep] = OrderedSet()
5252+
for inp_raw in self.inputs:
5253+
assert isinstance(inp_raw, (ReinterpretView, Buffer)), type(inp_raw)
5254+
inp: ReinterpretView | Buffer = inp_raw
5255+
assert isinstance(inp.layout, Layout), type(inp.layout)
5256+
inp_indexer = inp.layout.make_indexer()
5257+
5258+
def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any:
5259+
assert len(rindex) == 0
5260+
return ops.load(inp.get_name(), inp_indexer(index))
5261+
5262+
reads |= dependencies.extract_read_writes(
5263+
dummy, inp.get_size(), (), normalize=normalize
5264+
).reads
5265+
return reads
5266+
52305267
def extract_read_writes(self, normalize: bool = False) -> dependencies.ReadWrites:
52315268
name = self.get_name()
52325269
indexer = self.get_layout().make_indexer()
@@ -5238,22 +5275,7 @@ def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any:
52385275
deps = dependencies.extract_read_writes(
52395276
dummy, self.get_size(), (), normalize=normalize
52405277
)
5241-
5242-
for inp in self.inputs:
5243-
assert isinstance(inp, (ReinterpretView, Buffer)), type(inp)
5244-
assert isinstance(inp.layout, Layout), type(inp.layout)
5245-
5246-
indexer = inp.layout.make_indexer()
5247-
5248-
def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any:
5249-
assert len(rindex) == 0
5250-
# pyrefly: ignore [missing-attribute]
5251-
return ops.load(inp.get_name(), indexer(index))
5252-
5253-
deps.reads |= dependencies.extract_read_writes(
5254-
dummy, inp.get_size(), (), normalize=normalize
5255-
).reads
5256-
5278+
deps.reads |= self._read_deps_from_inputs(normalize)
52575279
return deps
52585280

52595281
def get_reduction_size(self) -> Sequence[Expr]:
@@ -5282,14 +5304,8 @@ def is_multi_outputs_template(self) -> bool:
52825304
"""Whether this template produces multiple outputs via MultiOutputLayout."""
52835305
return isinstance(self.layout, MultiOutputLayout)
52845306

5285-
def can_fuse_multi_output_epilogue(self, snode: object) -> bool:
5286-
"""Whether scheduler node can be fused as an epilogue of this multi-output template.
5287-
5288-
Returns ``False`` by default. Subclasses may override to support
5289-
additional fusion patterns (e.g. epilogue fusion with multi-output
5290-
extraction and pointwise operations).
5291-
"""
5292-
return False
5307+
def get_allowed_prologue_inps(self) -> OrderedSet[str]:
5308+
return self.allowed_prologue_inps
52935309

52945310

52955311
class TritonTemplateBuffer(TemplateBuffer):
@@ -5310,19 +5326,12 @@ def __init__(
53105326
We work around this by creating an extra input buffer during the lowering
53115327
and we mark them as mutated inputs.
53125328
"""
5313-
super().__init__(layout, inputs, make_kernel_render)
5314-
self.mutated_inputs = mutated_inputs
5315-
self.outputs: list[Buffer] = [self]
5316-
if mutated_inputs is not None:
5317-
assert isinstance(self.inputs[0], IRNode), type(self.inputs[0])
5318-
device = self.inputs[0].get_device()
5319-
self.outputs += [
5320-
MutationOutput(NoneLayout(device=device), buf, self)
5321-
for buf in mutated_inputs
5322-
]
5323-
5324-
self.allowed_prologue_inps = (
5325-
allowed_prologue_inps if allowed_prologue_inps else OrderedSet()
5329+
super().__init__(
5330+
layout,
5331+
inputs,
5332+
make_kernel_render,
5333+
mutated_inputs=mutated_inputs,
5334+
allowed_prologue_inps=allowed_prologue_inps,
53265335
)
53275336

53285337
self.subgraph_inps: list[IRNode | sympy.Expr | None] | None = None
@@ -5353,10 +5362,7 @@ def get_free_symbol_uses(
53535362
return res
53545363

53555364
def get_outputs(self) -> list[Buffer]:
5356-
return self.outputs
5357-
5358-
def get_allowed_prologue_inps(self) -> OrderedSet[str]:
5359-
return self.allowed_prologue_inps
5365+
return [self, *self.mutation_outputs]
53605366

53615367
def __str__(self) -> str:
53625368
out = f"TritonTemplateBuffer(layout={self.layout})"

torch/_inductor/scheduler.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7751,16 +7751,12 @@ def can_fuse_multi_outputs_template(
77517751
and node2 corresponds to one of its outputs. If so, we further check if
77527752
backend supports this fusion.
77537753
7754-
Delegates to ``TemplateBuffer.can_fuse_multi_output_epilogue`` which
7755-
TemplateBuffer subclasses may override to allow fusion of additional node types.
77567754
"""
77577755
template_buf = node1.get_template_node()
77587756
if not isinstance(template_buf, ir.TemplateBuffer):
77597757
return False
77607758
if not template_buf.is_multi_outputs_template():
77617759
return False
7762-
if template_buf.can_fuse_multi_output_epilogue(node2):
7763-
return True
77647760
return False
77657761

77667762
def fuse(

0 commit comments

Comments (0)