[Inductor] User kernel unary epilogue fusion #173662
AmesingFlank wants to merge 10 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/173662
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (3 Unrelated Failures) As of commit c6605fa with merge base f26ec24:
- FLAKY: the following job failed, but it was likely due to flakiness present on trunk.
- BROKEN TRUNK: the following job failed, but it was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
- UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 4a3cf5f to 69413d7.
Force-pushed from 5e60dd8 to 335b8c5.
eellison left a comment:
Very good start! Left some comments.
```python
for op in reversed(epilogue.data.origins):  # `origins` contains the ops in reverse
    if op.name == "relu":
        store_value_expr = f"triton_helpers.maximum(0, {store_value_expr})"
    elif op.name == "sigmoid":
        store_value_expr = f"tl.sigmoid({store_value_expr})"
    else:
        raise AssertionError(f"unsupported epilogue op: {op.name}")
```
Ultimately, we'll need to actually codegen the unary fn.
| "tt.atomic_cas": [0], | ||
| "tt.atomic_rmw": [0], | ||
| "tt.experimental_descriptor_store": [0], | ||
| "tt.experimental_tensormap_create": [0], |
`tt.experimental_tensormap_create` is not a write op; it just creates a descriptor which will then be used for either loading or storing.
This is unfortunately a part of the codebase that has no owner currently. Maybe it will be you :)
I see. Given that this is pre-existing, OK if we look into this in a follow-up?
```python
WRITE_OPS = {  # renamed from MUTATION_OPS
    "tt.store": [0],
    "tt.atomic_cas": [0],
```
We're missing a bunch of atomic ops... we should also record which writes are atomic and update the write Dep mode to reflect that: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/dependencies.py#L83
We should also update the typing hint to be `Optional[Literal[...]]` so it's clearer what the modes can be.
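A minimal sketch of what that typing could look like; the helper and the concrete mode names below are hypothetical, not the actual set used in torch/_inductor/dependencies.py:

```python
from typing import Literal, Optional

# Hypothetical sketch: narrow the mode from Optional[str] to a closed set.
WriteMode = Optional[Literal["atomic_add", "atomic_cas", "atomic_rmw"]]

def record_write(buffer_name: str, mode: WriteMode = None) -> tuple[str, WriteMode]:
    # A type checker now rejects any mode string outside the Literal set,
    # whereas a plain Optional[str] silently accepts anything.
    return (buffer_name, mode)
```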
```python
@functools.cache
def identify_triton_stores(source_code: str):
    """
    Parse the Python source code of a triton kernel and find all tl.store calls.
```
The existing parser uses ttgir, and this uses the ast. Are there any concerns there?
The two paths are for different purposes. The TTIR-based analysis is for extracting read/write info out of the code, whereas the AST parsing is so that we can modify certain parts of the code to achieve fusion. I do think these scenarios call for different approaches, and it's best to use the IR for program analysis and the AST for source code manipulation.
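For illustration, a minimal self-contained sketch of the AST-based technique; this shows the general approach, not necessarily the PR's exact implementation:

```python
import ast
import functools

@functools.cache
def identify_triton_stores(source_code: str):
    """Collect every `tl.store(...)` call node in a Triton kernel's source."""
    stores = []

    class StoreVisitor(ast.NodeVisitor):
        def visit_Call(self, node: ast.Call) -> None:
            func = node.func
            # Match attribute calls of the exact form `tl.store(...)`.
            if (
                isinstance(func, ast.Attribute)
                and func.attr == "store"
                and isinstance(func.value, ast.Name)
                and func.value.id == "tl"
            ):
                stores.append(node)
            self.generic_visit(node)

    StoreVisitor().visit(ast.parse(source_code))
    return stores
```

Because the calls are located as AST nodes (with line/column info), the fusion pass can rewrite the stored value expression in place rather than pattern-matching raw text.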
```python
if next(iter(self.mutable_args[0].origins)).name != "empty":
    return False
```
Similarly, we should not rely on the origins. You can check that this is a Nop input:
pytorch/torch/_inductor/lowering.py, lines 3675 to 3676 at a182b08
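A rough sketch of that suggestion; `get_defining_op` is a hypothetical accessor, though `ir.NopKernel` is a real Inductor IR class:

```python
from torch._inductor import ir

def is_nop_allocation(buf) -> bool:
    # Sketch only: instead of string-matching the FX origin name ("empty"),
    # test whether the node producing the buffer is a no-op allocation.
    producer = buf.get_defining_op()  # hypothetical accessor
    return isinstance(producer, ir.NopKernel)
```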
Force-pushed from 335b8c5 to 9a51e40.
```python
    We do this by pruning the `out` tensor allocation and directly writing the relu-output.
    """
    # only do epilogue fusion if the kernel has a single output tensor
    if len(self.args_read_writes.writes) != 1:
```
Let's also check that the input & output size/stride/dtype are the same, at least until we have logic to fix up mismatches.
Good call. I updated the PR to include that check in Scheduler::can_fuse.
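A minimal sketch of that check, assuming Inductor IR buffers with the usual `get_size`/`get_stride`/`get_dtype` accessors; the helper itself is hypothetical:

```python
def same_layout(a, b) -> bool:
    # Hypothetical helper: fusion is only safe when the user kernel's output
    # and the epilogue's input agree on size, stride, and dtype.
    return (
        a.get_size() == b.get_size()
        and a.get_stride() == b.get_stride()
        and a.get_dtype() == b.get_dtype()
    )
```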
```python
if isinstance(node, ir.UserDefinedTritonKernel) and node.can_fuse_epilogues():
    numel = math.prod(node.mutable_args[0].shape)
    rnumel = 1
    device = node.get_device_or_error()
    # pyrefly: ignore [bad-assignment]
    self.group = (device, (numel, rnumel))
```
So this will only support fusion for one of the outputs. I think supporting only a single epilogue is fine initially. While we have this limitation, we could potentially see which output has a potential epilogue fusion.
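To make the group computation concrete, a small worked example; the shape and device here are illustrative:

```python
import math

# A kernel whose single mutated output has shape (128, 256) gets group key
# (device, (32768, 1)), the same key a pointwise scheduler node over 32768
# elements would get, which is what makes the two nodes fusion candidates.
shape = (128, 256)
numel = math.prod(shape)  # 32768
rnumel = 1                # no reduction dimension for a pointwise epilogue
group = ("cuda", (numel, rnumel))  # device shown as a string for brevity
```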
```python
def fusable_pointwise_ops(cls):
    return OrderedSet(["relu", "sigmoid"])
```
I guess this is because of our current codegen limitations?
Correct. I've updated the PR to properly call node.codegen(), so this is now removed.
Force-pushed from df2508a to 1ec9830.
Reverting PR 173662 failed. Reason: Command. Details for Dev Infra team: raised by workflow job.
I'm going to disable the test until we fix it.
If it's only cpp_wrapper failures, I can take care of that. I recently enabled more cpp-wrapper tests on CI, so this was a landing race.
This PR increases the instruction count on mm_loop_inductor_gpu by 10%.
Summary: 1. #173662 added more tests to test/inductor/test_triton_kernels.py, and #175416 enabled cpp-wrapper tests on test/inductor/test_triton_kernels.py. So there was a land race, and #173662 didn't have the failing CI signal at landing time. Forward fix by updating the code-checking target for cpp-wrapper. 2. #176353 also had a land race. Skip now; the fix is coming later. Pull Request resolved: #176745. Approved by: https://github.com/AmesingFlank, https://github.com/zou3519

This PR enables fusing user-defined triton kernels with unary epilogues such as relu(). This roughly involves the following changes:
- Detecting when fusion is possible (see `UserDefinedTritonKernel::can_fuse_epilogues` for detailed rationale).
- Adding `FusedExternTritonKernelSchedulerNode` and updating fusion logic to allow extern kernel nodes (triton kernels) to epilogue-fuse with pointwise scheduler nodes.
- Rewriting the user kernel's `tl.store()`: we generate an expr for the value after the epilogue and substitute that into the `tl.store`.

@eellison
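For context, a minimal end-to-end example of the pattern this PR targets; the kernel, shapes, and launch parameters are illustrative, not taken from the PR's tests:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)  # fusion rewrites this store

def f(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    add_kernel[(triton.cdiv(n, 1024),)](x, y, out, n, BLOCK_SIZE=1024)
    # The unary epilogue: with this PR, Inductor can fold the relu into the
    # user kernel's tl.store instead of launching a separate pointwise kernel.
    return torch.relu(out)

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
print(torch.compile(f)(x, y))
```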
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @mlazos