Skip to content

Nvfuser guard patch#75016

Closed
jjsjann123 wants to merge 5 commits intopytorch:masterfrom
jjsjann123:nvfuser_guard_patch
Closed

Nvfuser guard patch#75016
jjsjann123 wants to merge 5 commits intopytorch:masterfrom
jjsjann123:nvfuser_guard_patch

Conversation

@jjsjann123
Copy link
Copy Markdown
Collaborator

@jjsjann123 jjsjann123 commented Mar 31, 2022

Fixes issue where CudaFusionGuard would return false on backward graph because requires_grad flag doesn't match.

This is due to the fact that autodiff uses GradMode switch to turn on/off requires_grad, which is not taken into consideration by nvfuser guard. We verified the implementation under TensorType::matchTensor.

  • Add python test to verify no fallback is observed

@facebook-github-bot
Copy link
Copy Markdown
Contributor

facebook-github-bot commented Mar 31, 2022

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 1b0b917 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Mar 31, 2022
@jjsjann123 jjsjann123 requested a review from Krovatkin March 31, 2022 14:31
Copy link
Copy Markdown
Contributor

@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

(guard_tensor_type->device().value() != tensor.device())) ||
(guard_tensor_type->requiresGrad().has_value() &&
guard_tensor_type->requiresGrad().value() != tensor.requires_grad())) {
guard_tensor_type->requiresGrad().value() !=
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made this mistake as well :/ aa99df5

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making me feel better~

@eellison
Copy link
Copy Markdown
Contributor

have some failing tests

@facebook-github-bot
Copy link
Copy Markdown
Contributor

@eellison has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@eellison
Copy link
Copy Markdown
Contributor

eellison commented Apr 1, 2022

@pytorchbot merge this please

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 1, 2022

Hey @jjsjann123.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@jjsjann123 jjsjann123 deleted the nvfuser_guard_patch branch April 1, 2022 19:35
davidberard98 added a commit that referenced this pull request Apr 1, 2022
fixes merge issue between #75016 and #73322

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Apr 1, 2022
fixes merge issue between #75016 and #73322

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Apr 1, 2022
fixes merge issue between #75016 and #73322

ghstack-source-id: a79f74a
Pull Request resolved: #75134
@malfet
Copy link
Copy Markdown
Contributor

malfet commented Apr 1, 2022

@pytorchbot revert this

@malfet
Copy link
Copy Markdown
Contributor

malfet commented Apr 1, 2022

Temporary reverting in order to resolve land race between this PR and 96c8799
cc: @atalman , @eellison

pytorchmergebot added a commit that referenced this pull request Apr 1, 2022
This reverts commit d86181f.

Reverted #75016 on behalf of https://github.com/malfet
@davidberard98
Copy link
Copy Markdown
Contributor

sorry didn't notice this when I landed the ci stuff, @jjsjann123 could you rebase this and also change RUN_CUDA -> RUN_NVFUSER to avoid rocm failures?

@malfet
Copy link
Copy Markdown
Contributor

malfet commented Apr 1, 2022

Oh, so it needed to be reverted anyway ;)

@davidberard98
Copy link
Copy Markdown
Contributor

yeah, either this or the CI changes would have needed to be reverted

@Chillee
Copy link
Copy Markdown
Collaborator

Chillee commented Apr 6, 2022

Will reland in #75303 if @jjsjann123 doesn't mind :)

@jjsjann123
Copy link
Copy Markdown
Collaborator Author

Will reland in #75303 if @jjsjann123 doesn't mind :)

ohohoh, somehow haven't got back to this one. Thanks for taking care of it.

@Chillee
Copy link
Copy Markdown
Collaborator

Chillee commented Apr 6, 2022

np - useful for me :)

pytorchmergebot pushed a commit that referenced this pull request Apr 6, 2022
Reland of #75016 with `USE_CUDA` => `USE_NVFUSER`
Pull Request resolved: #75303
Approved by: https://github.com/jjsjann123, https://github.com/davidberard98
facebook-github-bot pushed a commit that referenced this pull request Apr 7, 2022
Summary:
Reland of #75016 with `USE_CUDA` => `USE_NVFUSER`

Pull Request resolved: #75303
Approved by: https://github.com/jjsjann123, https://github.com/davidberard98

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/5994d684840a6b3be37bfa033af92c891ba257a6

Reviewed By: b0noI, davidberard98

Differential Revision: D35420569

Pulled By: Chillee

fbshipit-source-id: e25a26b85f5056ba7b5b73448b03aa9926ce00df
jjsjann123 added a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
Fixes issue where CudaFusionGuard would return false on backward graph because `requires_grad` flag doesn't match.

This is due to the fact that autodiff uses GradMode switch to turn on/off requires_grad, which is not taken into consideration by nvfuser guard. We verified the implementation under `TensorType::matchTensor`.

- [x] Add python test to verify no fallback is observed
Pull Request resolved: pytorch/pytorch#75016
Approved by: https://github.com/eellison
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
jjsjann123 added a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022
Fixes issue where CudaFusionGuard would return false on backward graph because `requires_grad` flag doesn't match.

This is due to the fact that autodiff uses GradMode switch to turn on/off requires_grad, which is not taken into consideration by nvfuser guard. We verified the implementation under `TensorType::matchTensor`.

- [x] Add python test to verify no fallback is observed
Pull Request resolved: pytorch/pytorch#75016
Approved by: https://github.com/eellison
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
Fixes issue where CudaFusionGuard would return false on backward graph because `requires_grad` flag doesn't match.

This is due to the fact that autodiff uses GradMode switch to turn on/off requires_grad, which is not taken into consideration by nvfuser guard. We verified the implementation under `TensorType::matchTensor`.

- [x] Add python test to verify no fallback is observed
Pull Request resolved: pytorch#75016
Approved by: https://github.com/eellison
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
Reland of pytorch#75016 with `USE_CUDA` => `USE_NVFUSER`
Pull Request resolved: pytorch#75303
Approved by: https://github.com/jjsjann123, https://github.com/davidberard98
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed oncall: jit Add this issue/PR to JIT oncall triage queue open source Reverted

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants