@@ -5206,15 +5206,18 @@ def constant_to_device(self, device: torch.device) -> IRNode:
52065206
52075207class TemplateBuffer (OperationBuffer ):
52085208 """
5209- Represents a Triton (in the future other type) of template operator
5210- that we can fuse an epilogue onto.
5209+ Base class for template operators that support epilogue and prologue fusion.
5210+ Subclasses: TritonTemplateBuffer (built-in Triton templates),
5211+ HelionTemplateBuffer (Helion kernels), etc.
52115212 """
52125213
52135214 def __init__ (
52145215 self ,
52155216 layout : OutputSpec ,
52165217 inputs : Sequence [IRNode ],
52175218 make_kernel_render : Callable [..., Any ] | None ,
5219+ mutated_inputs : Iterable [IRNode ] | None = None ,
5220+ allowed_prologue_inps : OrderedSet [str ] | None = None ,
52185221 ) -> None :
52195222 super ().__init__ (name = None , layout = layout )
52205223 self .inputs = InputsKernel .unwrap_storage (inputs )
@@ -5224,9 +5227,43 @@ def __init__(
52245227 # Annotations dict for storing metadata (e.g., KernelTemplateChoice)
52255228 self .annotations : dict [str , Any ] = {}
52265229
5230+ # Inputs that the kernel mutates in-place
5231+ self .mutated_inputs = mutated_inputs
5232+ self .mutation_outputs : list [MutationOutput ] = []
5233+ if mutated_inputs is not None :
5234+ first_input = self .inputs [0 ]
5235+ assert isinstance (first_input , IRNode ), type (first_input )
5236+ device = first_input .get_device ()
5237+ self .mutation_outputs = [
5238+ MutationOutput (NoneLayout (device = device ), buf , self )
5239+ for buf in mutated_inputs
5240+ ]
5241+ # Input buffer names eligible for prologue fusion.
5242+ self .allowed_prologue_inps : OrderedSet [str ] = (
5243+ allowed_prologue_inps or OrderedSet ()
5244+ )
5245+
52275246 def get_read_writes (self ) -> dependencies .ReadWrites :
52285247 return self .extract_read_writes (normalize = True )
52295248
5249+ def _read_deps_from_inputs (self , normalize : bool ) -> OrderedSet [dependencies .Dep ]:
5250+ """Build read dependencies from all inputs."""
5251+ reads : OrderedSet [dependencies .Dep ] = OrderedSet ()
5252+ for inp_raw in self .inputs :
5253+ assert isinstance (inp_raw , (ReinterpretView , Buffer )), type (inp_raw )
5254+ inp : ReinterpretView | Buffer = inp_raw
5255+ assert isinstance (inp .layout , Layout ), type (inp .layout )
5256+ inp_indexer = inp .layout .make_indexer ()
5257+
5258+ def dummy (index : Sequence [Any ], rindex : Sequence [Any ]) -> Any :
5259+ assert len (rindex ) == 0
5260+ return ops .load (inp .get_name (), inp_indexer (index ))
5261+
5262+ reads |= dependencies .extract_read_writes (
5263+ dummy , inp .get_size (), (), normalize = normalize
5264+ ).reads
5265+ return reads
5266+
52305267 def extract_read_writes (self , normalize : bool = False ) -> dependencies .ReadWrites :
52315268 name = self .get_name ()
52325269 indexer = self .get_layout ().make_indexer ()
@@ -5238,22 +5275,7 @@ def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any:
52385275 deps = dependencies .extract_read_writes (
52395276 dummy , self .get_size (), (), normalize = normalize
52405277 )
5241-
5242- for inp in self .inputs :
5243- assert isinstance (inp , (ReinterpretView , Buffer )), type (inp )
5244- assert isinstance (inp .layout , Layout ), type (inp .layout )
5245-
5246- indexer = inp .layout .make_indexer ()
5247-
5248- def dummy (index : Sequence [Any ], rindex : Sequence [Any ]) -> Any :
5249- assert len (rindex ) == 0
5250- # pyrefly: ignore [missing-attribute]
5251- return ops .load (inp .get_name (), indexer (index ))
5252-
5253- deps .reads |= dependencies .extract_read_writes (
5254- dummy , inp .get_size (), (), normalize = normalize
5255- ).reads
5256-
5278+ deps .reads |= self ._read_deps_from_inputs (normalize )
52575279 return deps
52585280
52595281 def get_reduction_size (self ) -> Sequence [Expr ]:
@@ -5282,14 +5304,8 @@ def is_multi_outputs_template(self) -> bool:
52825304 """Whether this template produces multiple outputs via MultiOutputLayout."""
52835305 return isinstance (self .layout , MultiOutputLayout )
52845306
5285- def can_fuse_multi_output_epilogue (self , snode : object ) -> bool :
5286- """Whether scheduler node can be fused as an epilogue of this multi-output template.
5287-
5288- Returns ``False`` by default. Subclasses may override to support
5289- additional fusion patterns (e.g. epilogue fusion with multi-output
5290- extraction and pointwise operations).
5291- """
5292- return False
5307+ def get_allowed_prologue_inps (self ) -> OrderedSet [str ]:
5308+ return self .allowed_prologue_inps
52935309
52945310
52955311class TritonTemplateBuffer (TemplateBuffer ):
@@ -5310,19 +5326,12 @@ def __init__(
53105326 We work around this by creating an extra input buffer during the lowering
53115327 and we mark them as mutated inputs.
53125328 """
5313- super ().__init__ (layout , inputs , make_kernel_render )
5314- self .mutated_inputs = mutated_inputs
5315- self .outputs : list [Buffer ] = [self ]
5316- if mutated_inputs is not None :
5317- assert isinstance (self .inputs [0 ], IRNode ), type (self .inputs [0 ])
5318- device = self .inputs [0 ].get_device ()
5319- self .outputs += [
5320- MutationOutput (NoneLayout (device = device ), buf , self )
5321- for buf in mutated_inputs
5322- ]
5323-
5324- self .allowed_prologue_inps = (
5325- allowed_prologue_inps if allowed_prologue_inps else OrderedSet ()
5329+ super ().__init__ (
5330+ layout ,
5331+ inputs ,
5332+ make_kernel_render ,
5333+ mutated_inputs = mutated_inputs ,
5334+ allowed_prologue_inps = allowed_prologue_inps ,
53265335 )
53275336
53285337 self .subgraph_inps : list [IRNode | sympy .Expr | None ] | None = None
@@ -5353,10 +5362,7 @@ def get_free_symbol_uses(
53535362 return res
53545363
53555364 def get_outputs (self ) -> list [Buffer ]:
5356- return self .outputs
5357-
5358- def get_allowed_prologue_inps (self ) -> OrderedSet [str ]:
5359- return self .allowed_prologue_inps
5365+ return [self , * self .mutation_outputs ]
53605366
53615367 def __str__ (self ) -> str :
53625368 out = f"TritonTemplateBuffer(layout={ self .layout } )"
0 commit comments