Skip to content

Commit ddfde89

Browse files
committed
Update on "[Helion + torch.compile] Refactor TemplateBuffer as extensible base class"
Move common fields and methods up from TritonTemplateBuffer to TemplateBuffer so that all template subclasses (Triton, CuteDSL, external backends) share them: - Add mutated_inputs, allowed_prologue_inps to TemplateBuffer.__init__ - Move mutation_outputs setup from TritonTemplateBuffer to base class - Move get_outputs(), get_allowed_prologue_inps() up - Extract _read_deps_from_inputs() helper from extract_read_writes() - Remove can_fuse_multi_output_epilogue() (unused) - Simplify TritonTemplateBuffer to delegate to super().__init__() - Remove redundant self.outputs from CppTemplateBuffer cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo [ghstack-poisoned]
2 parents 865a3f6 + 51f28e8 commit ddfde89

1 file changed

Lines changed: 4 additions & 6 deletions

File tree

torch/_inductor/ir.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5223,9 +5223,7 @@ def __init__(
52235223
# Annotations dict for storing metadata (e.g., KernelTemplateChoice)
52245224
self.annotations: dict[str, Any] = {}
52255225

5226-
# Inputs that the kernel mutates in-place (parallel to
5227-
# ExternKernel.mutation_outputs — kept separate from self.outputs
5228-
# so subclasses can freely use self.outputs for other purposes).
5226+
# Inputs that the kernel mutates in-place
52295227
self.mutated_inputs = mutated_inputs
52305228
self.mutation_outputs: list[MutationOutput] = []
52315229
if mutated_inputs is not None:
@@ -5335,9 +5333,6 @@ def __init__(
53355333
self.subgraph_inps: list[IRNode | sympy.Expr | None] | None = None
53365334
self.subgraph_outs: list[IRNode | None] | None = None
53375335

5338-
def get_outputs(self) -> list[Buffer]:
5339-
return [self, *self.mutation_outputs]
5340-
53415336
@cache_on_self_and_args("TritonTemplateBuffer")
53425337
def get_free_symbol_uses(
53435338
self, unbacked_only: bool = False
@@ -5362,6 +5357,9 @@ def get_free_symbol_uses(
53625357

53635358
return res
53645359

5360+
def get_outputs(self) -> list[Buffer]:
5361+
return [self, *self.mutation_outputs]
5362+
53655363
def __str__(self) -> str:
53665364
out = f"TritonTemplateBuffer(layout={self.layout})"
53675365
return out

0 commit comments

Comments
 (0)