Skip to content

[inductor] Support autotune restore_value for user-defined Triton kernels#139851

Closed
aakhundov wants to merge 5 commits intogh/aakhundov/16/basefrom
gh/aakhundov/16/head
Closed

[inductor] Support autotune restore_value for user-defined Triton kernels#139851
aakhundov wants to merge 5 commits intogh/aakhundov/16/basefrom
gh/aakhundov/16/head

Conversation

@aakhundov
Copy link
Contributor

@aakhundov aakhundov commented Nov 6, 2024

Stack from ghstack (oldest at bottom):

This PR adds support for the restore_value argument of the
@triton.autotune for the user-defined Triton kernels in PT2.

The kernel.restore_idx are extracted in the
ir.UserDefinedTritonKernel and the corresponding arg names are
placed into the triton_meta["restore_value"]. From there, those
are added to the existing mutated_arg_names in the caching autotuner
infra which already exists and leads to the listed argss being cloned.
This achieves the equivalent effect to the native restore_value.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

[ghstack-poisoned]
@aakhundov aakhundov requested a review from zou3519 as a code owner November 6, 2024 04:53
@pytorch-bot
Copy link

pytorch-bot bot commented Nov 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139851

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job

As of commit 8191aa2 with merge base 157c18a (image):

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@aakhundov aakhundov added the topic: not user facing topic category label Nov 6, 2024
@aakhundov aakhundov changed the title [inductor] Add restore_value to user-defined Triton kernel support [inductor] Support autotune restore_value for user-defined Triton kernels in PT2 Nov 6, 2024
@aakhundov aakhundov changed the title [inductor] Support autotune restore_value for user-defined Triton kernels in PT2 [inductor] Support autotune restore_value for user-defined Triton kernels Nov 6, 2024
[ghstack-poisoned]
aakhundov added a commit that referenced this pull request Nov 6, 2024
This PR adds support for the `restore_value` argument of the
`triton.autotune` for the user-defined Triton kernels in PT2.

The `kernel.restore_idx` are extracted in the
`ir.UserDefinedTritonKernel` and the corresponding arg names are
placed into the `triton_meta["restore_value"]`. From there, those
are added to the existing `mutated_arg_names` in the caching autotuner
infra which already exists and leads to the listed argss being cloned.
This achieves the equivalent effect to the native `restore_value`.

ghstack-source-id: a027031
Pull Request resolved: #139851
[ghstack-poisoned]
@aakhundov aakhundov requested a review from eellison November 7, 2024 05:31
[ghstack-poisoned]
aakhundov added a commit that referenced this pull request Nov 7, 2024
This PR adds support for the `restore_value` argument of the
`triton.autotune` for the user-defined Triton kernels in PT2.

The `kernel.restore_idx` are extracted in the
`ir.UserDefinedTritonKernel` and the corresponding arg names are
placed into the `triton_meta["restore_value"]`. From there, those
are added to the existing `mutated_arg_names` in the caching autotuner
infra which already exists and leads to the listed argss being cloned.
This achieves the equivalent effect to the native `restore_value`.

ghstack-source-id: 03ea337
Pull Request resolved: #139851
Comment on lines +713 to +717
# move as many positional arguments from dicts to args as we
# can to circumvent the bug with the kwargs and pre_/post_hook:
# https://github.com/triton-lang/triton/issues/5082
# TODO: remove this when the Triton issue above is fixed
args = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this feels really bad, and confusing. Maybe only do this when pre/post hook is used?

Copy link
Contributor Author

@aakhundov aakhundov Nov 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this stage of the Triton code, we can't tell when pre/post hook is used. Just submitted a PR so that we'll be able to tell: triton-lang/triton#5092. Can update then in a follow-up. Although, there is another PR fixing the mentioned issue that should make this whole thing unnecessary: triton-lang/triton#5083.

Copy link
Contributor Author

@aakhundov aakhundov Nov 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate why it's bad, btw? Is this incorrect? I'm just moving the args around, but not skipping or duplicating any passed to the kernel? It may not always help the issue, but in most cases it should (i.e., when the args used in the hook are not coming from @triton.heuristic or smth).

Comment on lines +718 to +726
kwargs = kwargs.copy()
constant_args = constant_args.copy()
for name in kernel.arg_names:
if name in kwargs:
args.append(kwargs.pop(name))
elif name in constant_args:
args.append(constant_args.pop(name))
else:
break
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can write this for loop better, you dont need copy/pop, just read indices?

Copy link
Contributor Author

@aakhundov aakhundov Nov 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to pop, because I'm passing the kwargs and constant_args together with the collected args, no? (That's also why I'm copying kwargs and constant_args, as I mutate the dicts). Otherwise, there will be duplicate argument passing in the line following this?

[ghstack-poisoned]
aakhundov added a commit that referenced this pull request Nov 7, 2024
This PR adds support for the `restore_value` argument of the
`triton.autotune` for the user-defined Triton kernels in PT2.

The `kernel.restore_idx` are extracted in the
`ir.UserDefinedTritonKernel` and the corresponding arg names are
placed into the `triton_meta["restore_value"]`. From there, those
are added to the existing `mutated_arg_names` in the caching autotuner
infra which already exists and leads to the listed argss being cloned.
This achieves the equivalent effect to the native `restore_value`.

ghstack-source-id: 9f806fb
Pull Request resolved: #139851
@aakhundov
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 7, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@bertmaher
Copy link
Contributor

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / build

Details for Dev Infra team Raised by workflow job

@bertmaher
Copy link
Contributor

@pytorchbot merge -f 'macos-py3-arm64 seems to be throwing spurious cancellations and cant possibly be a failure on this diff'

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…nels (pytorch#139851)

This PR adds support for the `restore_value` argument of the
`@triton.autotune` for the user-defined Triton kernels in PT2.

The `kernel.restore_idx` are extracted in the
`ir.UserDefinedTritonKernel` and the corresponding arg names are
placed into the `triton_meta["restore_value"]`. From there, those
are added to the existing `mutated_arg_names` in the caching autotuner
infra which already exists and leads to the listed argss being cloned.
This achieves the equivalent effect to the native `restore_value`.

Pull Request resolved: pytorch#139851
Approved by: https://github.com/oulgen
@github-actions github-actions bot deleted the gh/aakhundov/16/head branch December 9, 2024 02:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants