[Inductor] User kernel unary epilogue fusion #173662

Closed
AmesingFlank wants to merge 10 commits into pytorch:main from AmesingFlank:frank_dev

Conversation

@AmesingFlank
Contributor

@AmesingFlank AmesingFlank commented Jan 28, 2026

This PR enables fusing user-defined Triton kernels with unary epilogues such as relu(). This roughly involves the following changes:

  • Extended the TTIR analysis from tracking only writes to tracking both reads and writes (analyze_kernel_access).
  • A user Triton kernel is deemed eligible for fusion iff it writes to, but does not read from, a tensor initialized with UB values (see UserDefinedTritonKernel::can_fuse_epilogues for the detailed rationale).
  • Created FusedExternTritonKernelSchedulerNode and updated the fusion logic to allow extern kernel nodes (Triton kernels) to epilogue-fuse with pointwise scheduler nodes.
  • To modify the Triton kernel source code, we parse the original source into a Python AST, then identify the expression containing the original value written via tl.store(). We generate an expression for the post-epilogue value and substitute it into the tl.store call.
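As an illustration of the AST-manipulation step described above, here is a minimal, hypothetical sketch (not the PR's actual implementation): it parses a toy kernel source, finds the `tl.store` call, and wraps the stored value in a relu-style `triton_helpers.maximum(0, ...)` expression before unparsing.

```python
# Hedged sketch, NOT the PR's implementation: rewrite the value argument of
# tl.store(ptr, value, mask) via the Python AST. The kernel source below is
# a made-up toy example; `triton_helpers.maximum` mirrors the epilogue
# expression shown in the review thread.
import ast

SRC = """
def my_kernel(in_ptr, out_ptr, N):
    x = tl.load(in_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * 2, mask=mask)
"""

class StoreEpilogueRewriter(ast.NodeTransformer):
    """Wrap the stored value of every tl.store(...) in an epilogue expr."""

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)
        f = node.func
        if (isinstance(f, ast.Attribute) and f.attr == "store"
                and isinstance(f.value, ast.Name) and f.value.id == "tl"):
            # tl.store(ptr, value, ...): the stored value is positional arg 1
            value = node.args[1]
            node.args[1] = ast.Call(
                func=ast.parse("triton_helpers.maximum", mode="eval").body,
                args=[ast.Constant(0), value],
                keywords=[],
            )
        return node

tree = StoreEpilogueRewriter().visit(ast.parse(SRC))
new_src = ast.unparse(ast.fix_missing_locations(tree))
print(new_src)
```

After the rewrite, the store line reads `tl.store(out_ptr + offs, triton_helpers.maximum(0, x * 2), mask=mask)`; the load and the rest of the kernel are untouched.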

@eellison

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @mlazos

@pytorch-bot

pytorch-bot Bot commented Jan 28, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@pytorch-bot

pytorch-bot Bot commented Jan 28, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/173662

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit c6605fa with merge base f26ec24:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla Bot commented Jan 28, 2026

CLA Signed

The committers listed above are authorized under a signed CLA.

@AmesingFlank AmesingFlank changed the title [WIP][Indictor] User kernel unary epilogue fusion [WIP][Inductor] User kernel unary epilogue fusion Jan 28, 2026
@AmesingFlank AmesingFlank force-pushed the frank_dev branch 11 times, most recently from 4a3cf5f to 69413d7 Compare January 29, 2026 15:41
@AmesingFlank AmesingFlank marked this pull request as ready for review January 29, 2026 15:43
@AmesingFlank AmesingFlank changed the title [WIP][Inductor] User kernel unary epilogue fusion [Inductor] User kernel unary epilogue fusion Jan 29, 2026
@AmesingFlank AmesingFlank force-pushed the frank_dev branch 3 times, most recently from 5e60dd8 to 335b8c5 Compare January 29, 2026 17:10
Contributor

@eellison eellison left a comment

Very good start! Left some comments.

Comment thread torch/_inductor/codegen/wrapper.py Outdated
Comment on lines +320 to +326
for op in reversed(epilogue.data.origins): # `origins` contain the ops in reverse
if op.name == "relu":
store_value_expr = f"triton_helpers.maximum(0, {store_value_expr})"
elif op.name == "sigmoid":
store_value_expr = f"tl.sigmoid({store_value_expr})"
else:
raise AssertionError(f"unsupported epilogue op: {op.name}")
Contributor

Ultimately, we'll need to actually codegen the unary fn.

Comment thread torch/_higher_order_ops/triton_kernel_wrap.py Outdated
"tt.atomic_cas": [0],
"tt.atomic_rmw": [0],
"tt.experimental_descriptor_store": [0],
"tt.experimental_tensormap_create": [0],
Contributor

tt.experimental_tensormap_create is not a write op; it just creates a descriptor which will then be used for either loading or storing.

This is unfortunately a part of the code base that has no owner currently. maybe it will be you :)

Contributor Author

I see. Given that this is pre-existing, OK if we look into this in a follow-up?

Comment thread torch/_higher_order_ops/triton_kernel_wrap.py Outdated
MUTATION_OPS = {
WRITE_OPS = {
"tt.store": [0],
"tt.atomic_cas": [0],
Contributor

We're missing a bunch of atomic ops. We should also record which writes are atomic and update the write Dep mode to reflect that. https://github.com/pytorch/pytorch/blob/main/torch/_inductor/dependencies.py#L83

We should also update the typing hint to be Optional[Literal[...]] so it's clearer what the modes can be.
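For concreteness, a sketch of the suggested typing change; the mode names below are illustrative assumptions, not Inductor's actual dependency modes:

```python
# Hedged sketch of an Optional[Literal[...]] write-mode hint, as the review
# suggests. The literal values are made up for illustration.
from typing import Literal, Optional

WriteMode = Optional[Literal["atomic_add", "atomic_cas", "atomic_rmw"]]

def describe(mode: WriteMode) -> str:
    # None means a plain (non-atomic) store; otherwise name the atomic op.
    return "plain store" if mode is None else f"atomic write ({mode})"

print(describe(None))
print(describe("atomic_add"))
```

A type checker such as mypy would then reject any mode string outside the declared literals, which is the "more clear what the modes can be" benefit the review asks for.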

Comment thread torch/_higher_order_ops/triton_kernel_wrap.py Outdated
@functools.cache
def identify_triton_stores(source_code: str):
"""
Parse Python source code of triton kernel and find all tl.store calls.
Contributor

The existing parser uses ttgir, and this uses the ast. Are there any concerns there ?

Contributor Author

The two paths are for different purposes. The TTIR-based analysis is for extracting read/write info out of the code, whereas the AST parsing is so that we can modify certain parts of the code to achieve fusion. I do think these scenarios call for different approaches: it's best to use the IR for program analysis and the AST for source code manipulation.
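A minimal sketch of what an AST-based store finder along the lines of the quoted `identify_triton_stores` might look like; the function name and docstring come from the snippet above, but the body here is an assumption, not the PR's code:

```python
# Hedged sketch of an AST-based tl.store locator. Only the signature and
# @functools.cache decorator appear in the quoted snippet; this body is a
# plausible reconstruction, not the actual PR implementation.
import ast
import functools

@functools.cache
def identify_triton_stores(source_code: str) -> tuple[int, ...]:
    """Parse Python source code of a Triton kernel and find all tl.store
    calls, returning their line numbers."""
    stores = []
    for node in ast.walk(ast.parse(source_code)):
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Attribute)
                and node.func.attr == "store"
                and isinstance(node.func.value, ast.Name)
                and node.func.value.id == "tl"):
            stores.append(node.lineno)
    return tuple(stores)

print(identify_triton_stores("def k(p):\n    tl.store(p, 1)\n"))
```

Note that parsing never executes the kernel, so `tl` need not be importable here; this is purely static source inspection, which is why it can run at compile time.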

Comment thread torch/_inductor/ir.py Outdated
Comment on lines +7132 to +7133
if next(iter(self.mutable_args[0].origins)).name != "empty":
return False
Contributor

Similarly, we should not rely on the origins. You can check that this is a Nop input:

# explicitly set ranges to zeros in order to make a NopKernelSchedulerNode
buffer.data = dataclasses.replace(buffer.data, ranges=[0] * len(size))

Comment thread torch/_inductor/ir.py Outdated
We do this by pruning the `out` tensor allocation and directly writing the relu-output.
"""
# only do epilogue fusion if the kernel has a single output tensor
if len(self.args_read_writes.writes) != 1:
Contributor

Let's also check that the input & output size/stride/dtype are the same, until we have logic to fix up mismatches.

Contributor Author

Good call. I updated the PR to include that check in Scheduler::can_fuse.
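The compatibility check discussed in this thread might look roughly like the following sketch; the `Buf` stand-in and the helper name are hypothetical, not Inductor's real IR types or the PR's actual `Scheduler::can_fuse` code:

```python
# Hedged sketch of a size/stride/dtype compatibility gate for epilogue
# fusion. `Buf` is a made-up stand-in for an Inductor buffer.
from dataclasses import dataclass

@dataclass
class Buf:
    size: tuple
    stride: tuple
    dtype: str

def layouts_match(a: Buf, b: Buf) -> bool:
    # Epilogue fusion is only safe when the kernel output and the pointwise
    # epilogue agree on shape, memory layout, and element type; otherwise
    # substituting the epilogue value into tl.store would write wrong data.
    return a.size == b.size and a.stride == b.stride and a.dtype == b.dtype

print(layouts_match(Buf((4, 4), (4, 1), "f32"), Buf((4, 4), (4, 1), "f32")))
print(layouts_match(Buf((4, 4), (4, 1), "f32"), Buf((4, 4), (1, 4), "f32")))
```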

Comment thread torch/_inductor/scheduler.py Outdated
Comment on lines +1485 to +1490
if isinstance(node, ir.UserDefinedTritonKernel) and node.can_fuse_epilogues():
numel = math.prod(node.mutable_args[0].shape)
rnumel = 1
device = node.get_device_or_error()
# pyrefly: ignore [bad-assignment]
self.group = (device, (numel, rnumel))
Contributor

So, this will only support fusion for one of the outputs. I think supporting only a single epilogue is fine initially. Potentially we could see which output has a potential epilogue fusion while we have this limitation.

Comment thread torch/_inductor/scheduler.py Outdated
Comment on lines +2296 to +2297
def fusable_pointwise_ops(cls):
return OrderedSet(["relu", "sigmoid"])
Contributor

I guess this is because of our current codegen limitations?

Contributor Author

Correct. I've updated the PR to properly call node.codegen(), so this is now removed.

@AmesingFlank AmesingFlank force-pushed the frank_dev branch 4 times, most recently from df2508a to 1ec9830 Compare January 31, 2026 03:53
@pytorchmergebot
Collaborator

Reverting PR 173662 failed

Reason: Command git -C /home/runner/work/pytorch/pytorch revert --no-edit 615c79fa101e6d79144bf47ce8334d20c9787b2d returned non-zero exit code 1

Auto-merging test/inductor/test_triton_kernels.py
Auto-merging torch/_higher_order_ops/triton_kernel_wrap.py
CONFLICT (content): Merge conflict in torch/_higher_order_ops/triton_kernel_wrap.py
Auto-merging torch/_inductor/codegen/wrapper.py
Auto-merging torch/_inductor/config.py
Auto-merging torch/_inductor/ir.py
Auto-merging torch/_inductor/utils.py
error: could not revert 615c79fa101... [Inductor] User kernel unary epilogue fusion (#173662)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git revert --continue".
hint: You can instead skip this commit with "git revert --skip".
hint: To abort and get back to the state before "git revert",
hint: run "git revert --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Raised by workflow job

@zou3519
Contributor

zou3519 commented Mar 6, 2026

I'm going to disable the test until we fix it

@desertfire
Contributor

If it's only cpp_wrapper failures, I can take care of that. I recently enabled more cpp-wrapper tests on CI, so this was a landing race.

desertfire added a commit that referenced this pull request Mar 6, 2026
Summary: #173662 added more tests to test/inductor/test_triton_kernels.py, and #175416 enabled cpp-wrapper tests on test/inductor/test_triton_kernels.py. So there was a land race, and #173662 didn't have the failing CI signal at landing time.

Forward fix by updating the code checking target for cpp-wrapper.

[ghstack-poisoned]
desertfire added a commit that referenced this pull request Mar 6, 2026
ghstack-source-id: 5244ba6
Pull Request resolved: #176745
@laithsakka
Contributor

This PR increases instruction count on mm_loop_inductor_gpu by 10%.

@laithsakka
Contributor

[Screenshot: benchmark results, 2026-03-06]

desertfire added a commit that referenced this pull request Mar 6, 2026
Summary:
1. #173662 added more tests to test/inductor/test_triton_kernels.py, and #175416 enabled cpp-wrapper tests on test/inductor/test_triton_kernels.py. So there was a land race, and #173662 didn't have the failing CI signal at landing time.

Forward fix by updating the code checking target for cpp-wrapper.

2. #176353 also had a land race. Skip now; the fix is coming later.

[ghstack-poisoned]
desertfire added a commit that referenced this pull request Mar 6, 2026
ghstack-source-id: c856a94
Pull Request resolved: #176745
AmesingFlank added a commit that referenced this pull request Mar 7, 2026
pytorchmergebot pushed a commit that referenced this pull request Mar 7, 2026
Pull Request resolved: #176745
Approved by: https://github.com/AmesingFlank, https://github.com/zou3519
AmesingFlank added a commit that referenced this pull request Mar 7, 2026
AmesingFlank added a commit that referenced this pull request Mar 8, 2026