[inductor] Support autotune restore_value for user-defined Triton kernels #139851
aakhundov wants to merge 5 commits into gh/aakhundov/16/base
This PR adds support for the `restore_value` argument of `triton.autotune` for user-defined Triton kernels in PT2. The `kernel.restore_idx` indices are extracted in `ir.UserDefinedTritonKernel`, and the corresponding arg names are placed into `triton_meta["restore_value"]`. From there, they are added to the existing `mutated_arg_names` in the caching autotuner infra, which already causes the listed args to be cloned. This achieves an effect equivalent to the native `restore_value`. ghstack-source-id: a027031 Pull Request resolved: #139851
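The index-to-name step described above can be sketched in plain Python. This is only an illustration, not the Inductor implementation: the helper names `collect_restore_value_names` and `merge_mutated_arg_names`, the kernel signature, and the sample values are all hypothetical; only `restore_idx`, `triton_meta["restore_value"]`, and `mutated_arg_names` come from the description.

```python
def collect_restore_value_names(arg_names, restore_idx):
    """Translate positional restore indices into argument names (hypothetical helper)."""
    return [arg_names[i] for i in restore_idx]


def merge_mutated_arg_names(mutated_arg_names, restore_value_names):
    """Append restore_value names to the mutated names, keeping order and skipping duplicates."""
    merged = list(mutated_arg_names)
    for name in restore_value_names:
        if name not in merged:
            merged.append(name)
    return merged


# Hypothetical kernel signature and autotune metadata, for illustration only.
arg_names = ["in_ptr", "out_ptr", "n_elements", "BLOCK_SIZE"]
restore_idx = [1]  # assumed to be what restore_value=["out_ptr"] resolves to

triton_meta = {"restore_value": collect_restore_value_names(arg_names, restore_idx)}

# Names that an autotuner following this scheme would clone before benchmarking.
to_clone = merge_mutated_arg_names(["in_ptr"], triton_meta["restore_value"])
print(to_clone)
```

Keeping the merged result as an ordered, de-duplicated list means an argument that is both mutated and listed in `restore_value` would be cloned only once.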
    # move as many positional arguments from dicts to args as we
    # can to circumvent the bug with the kwargs and pre_/post_hook:
    # https://github.com/triton-lang/triton/issues/5082
    # TODO: remove this when the Triton issue above is fixed
    args = []
this feels really bad, and confusing. Maybe only do this when pre/post hook is used?
At this stage of the Triton code, we can't tell when pre/post hook is used. Just submitted a PR so that we'll be able to tell: triton-lang/triton#5092. Can update then in a follow-up. Although, there is another PR fixing the mentioned issue that should make this whole thing unnecessary: triton-lang/triton#5083.
Could you elaborate why it's bad, btw? Is this incorrect? I'm just moving the args around, not skipping or duplicating any that are passed to the kernel. It may not always help with the issue, but in most cases it should (i.e., when the args used in the hook are not coming from `@triton.heuristics` or something similar).
    kwargs = kwargs.copy()
    constant_args = constant_args.copy()
    for name in kernel.arg_names:
        if name in kwargs:
            args.append(kwargs.pop(name))
        elif name in constant_args:
            args.append(constant_args.pop(name))
        else:
            break
You can write this for loop better: you don't need copy/pop, just read indices?
I need to pop, because I'm passing the kwargs and constant_args together with the collected args, no? (That's also why I'm copying kwargs and constant_args, as I mutate the dicts). Otherwise, there will be duplicate argument passing in the line following this?
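The point about popping can be checked with a small standalone sketch. It wraps the same loop in a hypothetical helper, `split_leading_positional`, and calls a stand-in `fake_kernel` (both names are assumptions for illustration, not part of the PR) to show that passing the un-popped dicts alongside the collected `args` would supply the same arguments twice.

```python
def split_leading_positional(arg_names, kwargs, constant_args):
    """Move the leading run of known arguments into a positional list.

    Works on copies, so the caller's dicts are left untouched.
    """
    kwargs = kwargs.copy()
    constant_args = constant_args.copy()
    args = []
    for name in arg_names:
        if name in kwargs:
            args.append(kwargs.pop(name))
        elif name in constant_args:
            args.append(constant_args.pop(name))
        else:
            # Stop at the first name found in neither dict.
            break
    return args, kwargs, constant_args


def fake_kernel(x, y, n, BLOCK=128):
    """Stand-in for a kernel launch; just echoes its arguments."""
    return (x, y, n, BLOCK)


arg_names = ["x", "y", "n", "BLOCK"]
kwargs = {"x": 1, "y": 2}
constant_args = {"n": 3, "BLOCK": 64}

args, rest_kwargs, rest_constants = split_leading_positional(
    arg_names, kwargs, constant_args
)
result = fake_kernel(*args, **rest_kwargs, **rest_constants)

# Without popping, the names already passed positionally are passed again.
try:
    fake_kernel(*args, **kwargs, **constant_args)
    duplicate_error = None
except TypeError as exc:
    duplicate_error = str(exc)
```

Here every argument ends up positional and the remaining dicts are empty; the second call raises a `TypeError` because `x` is supplied both positionally and by keyword.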
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours).
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours).
Merge failed. Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / build
@pytorchbot merge -f 'macos-py3-arm64 seems to be throwing spurious cancellations and cant possibly be a failure on this diff'
Merge started: Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
…nels (pytorch#139851) This PR adds support for the `restore_value` argument of `@triton.autotune` for user-defined Triton kernels in PT2. The `kernel.restore_idx` indices are extracted in `ir.UserDefinedTritonKernel`, and the corresponding arg names are placed into `triton_meta["restore_value"]`. From there, they are added to the existing `mutated_arg_names` in the caching autotuner infra, which already causes the listed args to be cloned. This achieves an effect equivalent to the native `restore_value`. Pull Request resolved: pytorch#139851 Approved by: https://github.com/oulgen
Stack from ghstack (oldest at bottom):
This PR adds support for the `restore_value` argument of `@triton.autotune` for user-defined Triton kernels in PT2. The `kernel.restore_idx` indices are extracted in `ir.UserDefinedTritonKernel`, and the corresponding arg names are placed into `triton_meta["restore_value"]`. From there, they are added to the existing `mutated_arg_names` in the caching autotuner infra, which already causes the listed args to be cloned. This achieves an effect equivalent to the native `restore_value`.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang
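Why cloning matters during autotuning can be illustrated with a toy, pure-Python model. It is a sketch under stated assumptions, not the caching autotuner: `benchmark_configs` and `add_one_inplace` are hypothetical names, a Python list stands in for a tensor, and `copy.deepcopy` stands in for cloning.

```python
import copy


def add_one_inplace(buf):
    """Stand-in for a kernel that mutates its argument in place."""
    for i in range(len(buf)):
        buf[i] += 1


def benchmark_configs(kernel, args, configs, mutated_arg_names, runs_per_config=3):
    """Run every config several times, cloning the listed args for each trial."""
    for _config in configs:
        for _ in range(runs_per_config):
            trial_args = {
                name: copy.deepcopy(value) if name in mutated_arg_names else value
                for name, value in args.items()
            }
            kernel(**trial_args)


data = [0, 0, 0]
configs = ["cfg_a", "cfg_b"]  # placeholder configs; their contents are irrelevant here

# With the argument listed, each trial works on a copy and the input survives.
benchmark_configs(add_one_inplace, {"buf": data}, configs, mutated_arg_names={"buf"})
restored = list(data)

# Without it, every benchmarking run mutates the caller's data.
benchmark_configs(add_one_inplace, {"buf": data}, configs, mutated_arg_names=set())
corrupted = list(data)
print(restored, corrupted)
```

In this model, two configs with three runs each apply six unwanted in-place updates to the caller's data when the argument is not cloned.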