
[Helion + torch.compile] prologue / epilogue fusion with pointwise ops#1520

Closed
yf225 wants to merge 7 commits into main from helion_inductor_fusion_v3_fusion_pr1

Conversation

@yf225
Contributor

@yf225 yf225 commented Feb 19, 2026

When a Helion kernel runs inside torch.compile, Inductor's scheduler may identify pointwise ops adjacent to the kernel (e.g. relu, add, dtype casts) that can be fused directly into the kernel's loads or stores, avoiding extra read/write round-trips to global memory.

This PR implements both epilogue fusion (fusing ops that consume a kernel output into the hl.store site) and prologue fusion (fusing ops that produce a kernel input into the hl.load site).
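As a toy illustration of what the fusion buys (pure Python, no Helion or Triton involved; the kernel here is just a stand-in for a compiled kernel, and relu for an adjacent pointwise op):

```python
def kernel_unfused(x):
    tmp = [v * 2.0 for v in x]           # kernel writes tmp back to "global memory"
    return [max(v, 0.0) for v in tmp]    # separate relu pass re-reads tmp

def kernel_fused(x):
    # epilogue fusion: relu applied directly at the store site, no tmp buffer
    return [max(v * 2.0, 0.0) for v in x]
```

Both produce identical results; the fused version skips the intermediate round trip.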

codegen_template_override (entry point called by Inductor scheduler)
     └→ _codegen_with_fusion
          Phase 1 — trace fusion expressions:
          ├→ _build_prologue_specs
          │    └→ _extract_fusion_expr (per node) → self._prologue_specs
          ├→ _build_epilogue_specs
          │    └→ _extract_fusion_expr (per node) → self._epilogue_specs
          │
          Phase 2 — regenerate kernel (specs are now populated):
          └→ _generate_triton_ast
               └→ generate_ast → Helion kernel codegen
                    ├→ hl.store (memory_ops.py) → codegen_epilogue_fusion
                    └→ hl.load  (memory_ops.py) → codegen_prologue_fusion

Dependent PRs:

Design doc: #1346

@meta-cla meta-cla bot added the CLA Signed label Feb 19, 2026
@yf225 yf225 force-pushed the helion_inductor_fusion_v3_fusion_pr1 branch from d792eb4 to 81c2739 on February 19, 2026 21:14
@meta-codesync

meta-codesync bot commented Feb 19, 2026

@yf225 has imported this pull request. If you are a Meta employee, you can view this in D93779623.

@yf225 yf225 force-pushed the helion_inductor_fusion_v3_fusion_pr1 branch 20 times, most recently from a46be33 to ea83788 on February 20, 2026 21:15
@yf225 yf225 marked this pull request as ready for review February 20, 2026 21:16
yf225 added 7 commits March 17, 2026 11:17
…eBuffer base class

Switch from TritonTemplateBuffer to TemplateBuffer base class to enable
prologue/epilogue fusion support. This refactors the template buffer to
use make_kernel_render closure, build_multi_outputs, and the standard
render() path instead of the custom codegen_template_override/
emit_kernel_override/call_kernel overrides.

Also adds get_needed_import_lines() to output_header.py, refactoring
get_needed_imports() to use it.

stack-info: PR: #1723, branch: yf225/stack/70
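A minimal sketch of the refactor this describes, with assumed signatures (the real get_needed_imports/get_needed_import_lines in output_header.py operate on Helion's codegen state, not a raw string):

```python
def get_needed_import_lines(source: str) -> list[str]:
    # New helper: return the individual import lines as a list.
    return [ln for ln in source.splitlines()
            if ln.startswith(("import ", "from "))]

def get_needed_imports(source: str) -> str:
    # Existing entry point becomes a thin wrapper over the line-level helper.
    return "\n".join(get_needed_import_lines(source)) + "\n"
```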
…rameter plumbing

Add infrastructure for prologue/epilogue fusion in generated Triton code:
- store_transform/load_transform hooks in memory_ops for hl.store/hl.load
- Extra params, removed args, and protected arg names in codegen pipeline
- dim_index_exprs on SubscriptIndexing for epilogue fusion offsets
- _ensure_inductor_fusion_config() to enable fusion via config_patches

stack-info: PR: #1724, branch: yf225/stack/71
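The store_transform/load_transform hooks can be pictured as string-to-string rewrites of the value expression at the hl.store/hl.load site. A hypothetical sketch (names and the template-string mechanism are illustrative, not Helion's actual API):

```python
def make_store_transform(epilogue_template: str):
    """Return a hook that splices an epilogue op around the stored value."""
    def transform(value_expr: str) -> str:
        return epilogue_template.format(x=value_expr)
    return transform

# e.g. fuse a relu epilogue into the generated tl.store:
store_transform = make_store_transform("tl.maximum({x}, 0.0)")
```

Applied to the accumulator expression, `store_transform("acc")` yields `tl.maximum(acc, 0.0)`, which is then emitted in place of the bare store value.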

# Conflicts:
#	helion/_compiler/generate_ast.py
…onTemplateBuffer

Add fusion code generation to template_buffer.py:
- _render_with_hooks: setup fusion hooks, read fusion metadata, build
  extra_params, pass store/load transforms conditionally
- _build_call_args: add prologue_source_buffers and extra_params support
- _generate_triton_ast: pass store/load transform and extra_params
- create(): add on_tensor_leaf/on_non_tensor_leaf callback params
- _codegen_epilogue_fusion: emit per-epilogue index definitions and
  STORE_OUTPUT placeholders
- _codegen_prologue_fusion: emit prologue variables and LOAD_INPUT
  placeholders
- _flatten_return_ast: helper for build_multi_outputs traversal
- lower_helion_kernel: compute epilogue_fusable_outputs from stored
  proxy ids and output fusion metadata

stack-info: PR: #1725, branch: yf225/stack/72
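A helper like _flatten_return_ast plausibly walks a returned (possibly nested) tuple expression and yields its leaf expressions, one per output buffer, for the build_multi_outputs traversal. A hedged stand-in using the stdlib ast module:

```python
import ast

def flatten_return(node):
    # Recursively yield the leaf expressions of a (nested) tuple return value.
    if isinstance(node, ast.Tuple):
        for elt in node.elts:
            yield from flatten_return(elt)
    else:
        yield node

fn = ast.parse("def f():\n    return (a, (b, c))").body[0]
ret = fn.body[0]                     # the ast.Return node
leaves = [ast.unparse(n) for n in flatten_return(ret.value)]
```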
- Replace unconditional `self.skipTest` with `requires_torch_version("2.11")`
  guard, now that the required PyTorch-side changes have landed
- Update `expected_num_kernels` values across ~60 tests to reflect
  reduced kernel counts from prologue/epilogue fusion
- Un-skip 3 previously disabled tests:
  `test_autotune_no_fusion_final_has_fusion`,
  `test_inductor_output_code_has_helion_generated_triton_kernel`,
  `test_symint_return_from_tensor_shape`

stack-info: PR: #1727, branch: yf225/stack/74
The previous code from PR #1722 treated ALL unbacked symbols as
data-dependent, but compound expressions like `flag * 2` (= 2*u0)
have unbacked symbols that come from known kernel parameters.

Fix: check whether unbacked symbols are a subset of sym_remap keys.
If yes, substitute concrete values and evaluate. If any symbol is
unknown, raise DataDependentOutputShapeNotSupported.

Un-skip 3 tests that were broken by the overly broad check:
test_kernel_returns_none_in_tuple,
test_kernel_returns_none_first_in_tuple,
test_kernel_returns_tuple_of_scalars.
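The subset check described above can be sketched as follows (names and the callable-based "expression" are illustrative; the real code works on sympy expressions inside Helion):

```python
class DataDependentOutputShapeNotSupported(Exception):
    pass

def resolve_output_size(expr_symbols, evaluate, sym_remap):
    """Evaluate a shape expression if every unbacked symbol it mentions is a
    known kernel parameter; otherwise the shape is truly data-dependent."""
    if not expr_symbols <= sym_remap.keys():
        raise DataDependentOutputShapeNotSupported(
            f"unknown symbols: {expr_symbols - sym_remap.keys()}"
        )
    return evaluate(sym_remap)

# `flag * 2` lowers to 2*u0, where u0 is a known kernel parameter:
size = resolve_output_size({"u0"}, lambda env: 2 * env["u0"], {"u0": 7})
```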
Add ref kernel equivalents and baseline kernel count verification to
test_torch_compile.py. Each test's `f` function now accepts a `_kernels`
kwarg, allowing `functools.partial` to swap helion kernels for eager
refs. The helper `_compile_and_count_kernels` is extracted from
`_run_compile_test`, which gains `ref_kernels` and
`expected_num_kernels_ref` parameters to assert that the ref
baseline produces the expected number of Triton kernels via
torch.compile/inductor.
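The `_kernels` kwarg pattern looks roughly like this (a minimal stand-alone sketch; the dict contents and kernel names are made up):

```python
import functools

def f(x, *, _kernels):
    # Test body calls kernels only through the injected dict.
    return _kernels["double"](x) + 1

helion_kernels = {"double": lambda x: x * 2}   # stand-in for real Helion kernels
ref_kernels = {"double": lambda x: x + x}      # eager reference equivalent

f_helion = functools.partial(f, _kernels=helion_kernels)
f_ref = functools.partial(f, _kernels=ref_kernels)
```

The two partials run the same `f`, so the test harness can compile each and compare both outputs and kernel counts.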
…se 2)

When a Helion kernel's output feeds into a single-axis reduction
(e.g., k_rms_norm(x, w).sum(dim=1)), Inductor previously produced
2 kernels because the scheduler unconditionally blocked reduction
epilogues on templates. This change fuses supported reductions into
the template kernel, producing 1 kernel.

Phase 1 (persistent): When the tile covers the full reduction dim
(e.g., full slice ':' or block_size >= dim_size), the template emits
tl.sum/tl.max/tl.min inline at the store site.

Phase 2 (loop): When the tile is smaller than the reduction dim
(block_size < dim_size), the grid dim is converted to a for-loop
with an accumulator that sums partial reductions across iterations.

Only single-axis reductions with sum/max/min are supported.
Full reductions (sum()), argmax/argmin, welford, and multi-axis
reductions fall through to the 2-kernel path.
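The two phases can be modeled in plain Python (a toy over lists, standing in for the generated tl.sum and the accumulator loop):

```python
def persistent_sum(row):
    # Phase 1: the tile covers the full reduction dim, so the template
    # reduces the whole row inline at the store site.
    return sum(row)

def loop_sum(row, block_size):
    # Phase 2: block_size < dim_size, so the grid dim becomes a for-loop
    # accumulating partial reductions across tiles.
    acc = 0
    for start in range(0, len(row), block_size):
        acc += sum(row[start:start + block_size])
    return acc
```

Both agree on any row; which one is emitted depends only on whether the tile spans the reduction dimension.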
@yf225
Contributor Author

yf225 commented Mar 18, 2026

Closing in favor of the PR diff starting at #1723
