[Helion + torch.compile] prologue / epilogue fusion with pointwise ops#1520
Closed
…eBuffer base class

Switch from TritonTemplateBuffer to TemplateBuffer base class to enable prologue/epilogue fusion support. This refactors the template buffer to use the `make_kernel_render` closure, `build_multi_outputs`, and the standard `render()` path instead of the custom `codegen_template_override`/`emit_kernel_override`/`call_kernel` overrides.

Also adds `get_needed_import_lines()` to output_header.py, refactoring `get_needed_imports()` to use it.

stack-info: PR: #1723, branch: yf225/stack/70
…rameter plumbing

Add infrastructure for prologue/epilogue fusion in generated Triton code:
- `store_transform`/`load_transform` hooks in memory_ops for `hl.store`/`hl.load`
- Extra params, removed args, and protected arg names in the codegen pipeline
- `dim_index_exprs` on `SubscriptIndexing` for epilogue fusion offsets
- `_ensure_inductor_fusion_config()` to enable fusion via `config_patches`

stack-info: PR: #1724, branch: yf225/stack/71
# Conflicts:
#	helion/_compiler/generate_ast.py
…onTemplateBuffer

Add fusion code generation to template_buffer.py:
- `_render_with_hooks`: set up fusion hooks, read fusion metadata, build `extra_params`, and pass store/load transforms conditionally
- `_build_call_args`: add `prologue_source_buffers` and `extra_params` support
- `_generate_triton_ast`: pass store/load transforms and `extra_params`
- `create()`: add `on_tensor_leaf`/`on_non_tensor_leaf` callback params
- `_codegen_epilogue_fusion`: emit per-epilogue index definitions and STORE_OUTPUT placeholders
- `_codegen_prologue_fusion`: emit prologue variables and LOAD_INPUT placeholders
- `_flatten_return_ast`: helper for `build_multi_outputs` traversal
- `lower_helion_kernel`: compute `epilogue_fusable_outputs` from stored proxy ids and output fusion metadata

stack-info: PR: #1725, branch: yf225/stack/72
- Replace unconditional `self.skipTest` with `requires_torch_version("2.11")`
guard, now that the required PyTorch-side changes have landed
- Update `expected_num_kernels` values across ~60 tests to reflect
reduced kernel counts from prologue/epilogue fusion
- Un-skip 3 previously disabled tests:
`test_autotune_no_fusion_final_has_fusion`,
`test_inductor_output_code_has_helion_generated_triton_kernel`,
`test_symint_return_from_tensor_shape`
stack-info: PR: #1727, branch: yf225/stack/74
The previous code from PR #1722 treated ALL unbacked symbols as data-dependent, but compound expressions like `flag * 2` (= 2*u0) have unbacked symbols that come from known kernel parameters. Fix: check whether unbacked symbols are a subset of sym_remap keys. If yes, substitute concrete values and evaluate. If any symbol is unknown, raise DataDependentOutputShapeNotSupported. Un-skip 3 tests that were broken by the overly broad check: test_kernel_returns_none_in_tuple, test_kernel_returns_none_first_in_tuple, test_kernel_returns_tuple_of_scalars.
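The subset check described above can be sketched in plain Python. The names (`resolve_size`, the string-keyed `sym_remap`, the exception class) are illustrative stand-ins, not Helion's actual internals, which operate on sympy symbols.

```python
class DataDependentOutputShapeNotSupported(RuntimeError):
    """Hypothetical stand-in for Helion's internal exception."""

def resolve_size(unbacked_symbols, sym_remap, evaluate):
    """If every unbacked symbol is a known kernel parameter (a key in
    sym_remap), substitute concrete values and evaluate; otherwise the
    output shape is truly data-dependent and we raise."""
    unknown = set(unbacked_symbols) - set(sym_remap)
    if unknown:
        raise DataDependentOutputShapeNotSupported(
            f"unknown unbacked symbols: {sorted(unknown)}"
        )
    return evaluate(sym_remap)

# `flag * 2` lowers to 2*u0; with u0 known from kernel params, it evaluates:
size = resolve_size({"u0"}, {"u0": 5}, lambda env: 2 * env["u0"])
print(size)  # 10
```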
Add ref kernel equivalents and baseline kernel count verification to test_torch_compile.py. Each test's `f` function now accepts a `_kernels` kwarg, allowing `functools.partial` to swap helion kernels for eager refs. The helper `_compile_and_count_kernels` is extracted from `_run_compile_test`, which gains `ref_kernels` and `expected_num_kernels_ref` parameters to assert that the ref baseline produces the expected number of Triton kernels via torch.compile/inductor.
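The `_kernels` kwarg plus `functools.partial` swap can be sketched as follows. The kernel bodies here are trivial placeholders standing in for a Helion kernel and its eager reference; only the swapping pattern mirrors the test harness.

```python
import functools

def add_one_kernel(x):  # stands in for a @helion.kernel implementation
    return [v + 1 for v in x]

def add_one_ref(x):     # plain eager equivalent
    return [v + 1 for v in x]

KERNELS = {"add_one": add_one_kernel}
REF_KERNELS = {"add_one": add_one_ref}

def f(x, *, _kernels=KERNELS):
    # The test body resolves kernels through the mapping, so the same `f`
    # can be compiled with either the helion kernels or the eager refs.
    return _kernels["add_one"](x)

# The ref baseline is just a partial application with swapped kernels:
f_ref = functools.partial(f, _kernels=REF_KERNELS)
assert f([1, 2]) == f_ref([1, 2]) == [2, 3]
```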
…se 2)

When a Helion kernel's output feeds into a single-axis reduction (e.g., `k_rms_norm(x, w).sum(dim=1)`), Inductor previously produced 2 kernels because the scheduler unconditionally blocked reduction epilogues on templates. This change fuses supported reductions into the template kernel, producing 1 kernel.

Phase 1 (persistent): When the tile covers the full reduction dim (e.g., full slice `:` or `block_size >= dim_size`), the template emits `tl.sum`/`tl.max`/`tl.min` inline at the store site.

Phase 2 (loop): When the tile is smaller than the reduction dim (`block_size < dim_size`), the grid dim is converted to a for-loop with an accumulator that sums partial reductions across iterations.

Only single-axis reductions with sum/max/min are supported. Full reductions (`sum()`), argmax/argmin, welford, and multi-axis reductions fall through to the 2-kernel path.
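The two phases can be illustrated in plain Python (not Triton) for a sum epilogue `out[m] = sum_k f(x[m, k])`; the function names are illustrative, not the generated kernel's.

```python
def fused_persistent(row, f):
    # Phase 1: the tile covers the full reduction dim, so the reduction
    # happens inline at the store site (analogue of emitting tl.sum there).
    return sum(f(v) for v in row)

def fused_loop(row, f, block_size):
    # Phase 2: the tile is smaller than the reduction dim; the grid dim
    # becomes a loop whose accumulator sums partial reductions per tile.
    acc = 0.0
    for start in range(0, len(row), block_size):
        acc += sum(f(v) for v in row[start:start + block_size])
    return acc

row = [1.0, 2.0, 3.0, 4.0]
square = lambda v: v * v
assert fused_persistent(row, square) == fused_loop(row, square, block_size=2) == 30.0
```

Both variants keep the pointwise producer `f` and the reduction in a single pass, which is what lets Inductor emit one kernel instead of two.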
Closing in favor of the PR diff starting at #1723
When a Helion kernel runs inside torch.compile, Inductor's scheduler may identify pointwise ops adjacent to the kernel (e.g. `relu`, `add`, dtype casts) that can be fused directly into the kernel's loads or stores, avoiding extra read/write round-trips to global memory.

This commit implements both epilogue fusion (fusing ops that consume a kernel output into the `hl.store` site) and prologue fusion (fusing ops that produce a kernel input into the `hl.load` site).

Dependent PRs:
- `ExternalTritonTemplateKernel` for external template prologue/epilogue fusion: pytorch#176571 (in review)

Design doc: #1346
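What epilogue fusion buys can be shown with a plain-Python analogy (no Triton or Helion here; `kernel_unfused`/`kernel_fused` are illustrative names, and `v * 2.0` stands in for the kernel's real computation):

```python
def relu(v):
    return v if v > 0.0 else 0.0

def kernel_unfused(xs):
    # Unfused: the kernel writes its raw output to memory, then a second
    # kernel re-reads it to apply the pointwise epilogue.
    tmp = [v * 2.0 for v in xs]
    return [relu(v) for v in tmp]

def kernel_fused(xs):
    # Epilogue fused: relu runs at the store site, so the intermediate
    # never makes a round-trip through global memory.
    return [relu(v * 2.0) for v in xs]

xs = [-1.0, 3.0]
assert kernel_unfused(xs) == kernel_fused(xs) == [0.0, 6.0]
```

Prologue fusion is the mirror image: the producer op (e.g. a dtype cast feeding the kernel) runs at the load site instead of in a separate preceding kernel.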