Skip to content

Commit 1fca0ec

Browse files
committed
[Helion + torch.compile] Handle multi-output templates in prologue fusion dtype heuristic
TemplateBuffer subclasses whose layout is a MultiOutputLayout (e.g. Helion kernels) don't have a single dtype. Raise an explicit NotImplementedError from TemplateBuffer.dtype for this case, and guard the scheduler's low-precision prologue-fusion heuristic with is_multi_outputs_template() so it skips the dtype check rather than crashing. [ghstack-poisoned]
1 parent 6c88a20 commit 1fca0ec

2 files changed

Lines changed: 11 additions & 1 deletion

File tree

torch/_inductor/ir.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5257,6 +5257,14 @@ def __init__(
5257 5257
allowed_prologue_inps or OrderedSet()
5258 5258
)
5259 5259

5260+
@property
5261+
def dtype(self) -> torch.dtype:
5262+
if isinstance(self.layout, MultiOutputLayout):
5263+
raise NotImplementedError(
5264+
"Multi-output templates do not have a single dtype"
5265+
)
5266+
return self.get_layout().dtype
5267+
5260 5268
def get_read_writes(self) -> dependencies.ReadWrites:
5261 5269
return self.extract_read_writes(normalize=True)
5262 5270

torch/_inductor/scheduler.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5593,8 +5593,10 @@ def check_prologue_fusion_heuristics_fusable(
5593 5593
def low_prec_fp(dtype: torch.dtype) -> bool:
5594 5594
return dtype.itemsize <= 2 and dtype.is_floating_point
5595 5595

5596+
template_buf = template_node.get_template_node_or_throw()
5596 5597
if (
5597-
low_prec_fp(template_node.get_template_node_or_throw().dtype)
5598+
not template_buf.is_multi_outputs_template()
5599+
and low_prec_fp(template_buf.dtype)
5598 5600
and not prologue_node.can_codegen_in_low_precision()
5599 5601
):
5600 5602
why(

0 commit comments

Comments
 (0)