Skip to content

[dynamo] Fix dunder attr access on WrapperUserFunctionVariable (lru_cache, wraps)#176934

Closed
guilhermeleobas wants to merge 7 commits intogh/guilhermeleobas/310/basefrom
gh/guilhermeleobas/310/head
Closed

[dynamo] Fix dunder attr access on WrapperUserFunctionVariable (lru_cache, wraps)#176934
guilhermeleobas wants to merge 7 commits intogh/guilhermeleobas/310/basefrom
gh/guilhermeleobas/310/head

Conversation

@guilhermeleobas
Copy link
Copy Markdown
Collaborator

@guilhermeleobas guilhermeleobas commented Mar 9, 2026

Stack from ghstack (oldest at bottom):

WrapperUserFunctionVariable now inherits from BaseUserFunctionVariable
instead of VariableTracker, gaining the shared var_getattr implementation
that handles __name__, __qualname__, __doc__, __module__, __code__,
__dict__, __annotations__, and __type_params__.

This fixes functools.wraps applied to lru_cache-wrapped functions at trace
time — previously, accessing __name__, __dict__, etc. on the wrapper object
would graph-break.

Co-authored Claude Sonnet 4.6

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo

[ghstack-poisoned]
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Mar 9, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176934

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 83 Pending

As of commit 9bc8c84 with merge base 3fab23b (image):
💚 Looks good so far! There are no failures yet. 💚

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

guilhermeleobas added a commit that referenced this pull request Mar 9, 2026
…ache, wraps)

`WrapperUserFunctionVariable` now inherits from `BaseUserFunctionVariable`
instead of `VariableTracker`, gaining the shared `var_getattr` implementation
that handles `__name__`, `__qualname__`, `__doc__`, `__module__`, `__code__`,
`__dict__`, `__annotations__`, and `__type_params__`.

This fixes `functools.wraps` applied to `lru_cache`-wrapped functions at trace
time — previously, accessing `__name__`, `__dict__`, etc. on the wrapper object
graph-break.

Co-authored Claude Sonnet 4.6


ghstack-source-id: 9b69323
Pull-Request: #176934
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Mar 9, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

guilhermeleobas added a commit that referenced this pull request Mar 9, 2026
…ache, wraps)

`WrapperUserFunctionVariable` now inherits from `BaseUserFunctionVariable`
instead of `VariableTracker`, gaining the shared `var_getattr` implementation
that handles `__name__`, `__qualname__`, `__doc__`, `__module__`, `__code__`,
`__dict__`, `__annotations__`, and `__type_params__`.

This fixes `functools.wraps` applied to `lru_cache`-wrapped functions at trace
time — previously, accessing `__name__`, `__dict__`, etc. on the wrapper object
would graph-break.

Co-authored Claude Sonnet 4.6

ghstack-source-id: 9b69323
Pull-Request: #176934
@guilhermeleobas guilhermeleobas marked this pull request as ready for review March 9, 2026 23:39
@guilhermeleobas guilhermeleobas requested a review from zou3519 March 9, 2026 23:39
@pytorch-bot pytorch-bot bot added the ciflow/torchtitan Run TorchTitan integration tests label Mar 10, 2026
@zou3519 zou3519 requested a review from williamwen42 March 10, 2026 14:05
Copy link
Copy Markdown
Contributor

@Lucaskabela Lucaskabela left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few minor comments

def var_getattr(self, tx: "InstructionTranslator", name: str):
fn_dict = self.get_dict_vt(tx)

# missing: __globals__, __closure__, __kwdefautls__, __defaults__
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a TODO we plan on adding later? If so can we make this clear?

Or if these are not supported here inentionally we should call that out

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's intentional.

def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "__dict__":
return self.get_dict_vt(tx)
if name in ("__dict__",):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why a single check for in?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot to undo this change before committing. I'll undo

[ghstack-poisoned]
guilhermeleobas added a commit that referenced this pull request Mar 10, 2026
…ache, wraps)

`WrapperUserFunctionVariable` now inherits from `BaseUserFunctionVariable`
instead of `VariableTracker`, gaining the shared `var_getattr` implementation
that handles `__name__`, `__qualname__`, `__doc__`, `__module__`, `__code__`,
`__dict__`, `__annotations__`, and `__type_params__`.

This fixes `functools.wraps` applied to `lru_cache`-wrapped functions at trace
time — previously, accessing `__name__`, `__dict__`, etc. on the wrapper object
would graph-break.

Co-authored Claude Sonnet 4.6

ghstack-source-id: 88b4ba0
Pull-Request: #176934
Copy link
Copy Markdown
Contributor

@Lucaskabela Lucaskabela left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing - LGTM

@guilhermeleobas
Copy link
Copy Markdown
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 11, 2026
@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: inductor / inductor-cpu-build / build

Details for Dev Infra team Raised by workflow job

@guilhermeleobas
Copy link
Copy Markdown
Collaborator Author

@pytorchbot merge -i

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged while ignoring the following 3 checks: inductor / inductor-cpu-build / build, inductor / unit-test / inductor-test / test (inductor_cpp_wrapper, 1, 2, linux.g5.4xlarge.nvidia.gpu), inductor / unit-test / inductor-test / test (inductor_cpp_wrapper, 2, 2, linux.g5.4xlarge.nvidia.gpu)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@yangw-dev
Copy link
Copy Markdown
Contributor

@pytorchbot revert -m "This is being reverted due to an internal build dependency issue.
The code changes themselves are correct and the tests pass. However, when importing this change internally, it triggers a dependency blocklist violation, please find a meta folks to help fix this internally,see D96281773" -c ghfirst

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Mar 16, 2026

❌ 🤖 pytorchbot command failed:

Got EOF while in a quoted string```
Try `@pytorchbot --help` for more info.

@yangw-dev
Copy link
Copy Markdown
Contributor

@pytorchbot revert -m "This is being reverted due to an internal build dependency issue.The code changes themselves are correct and the tests pass. However, when importing this change internally, it triggers a dependency blocklist violation, please find a meta folks to help fix this internally,see D96281773" -c ghfirst

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

@guilhermeleobas your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Mar 16, 2026
…e (lru_cache, wraps) (#176934)"

This reverts commit 609e0ed.

Reverted #176934 on behalf of https://github.com/yangw-dev due to This is being reverted due to an internal build dependency issue.The code changes themselves are correct and the tests pass. However, when importing this change internally, it triggers a dependency blocklist violation, please find a meta folks to help fix this internally,see D96281773 ([comment](#176934 (comment)))
@yangw-dev
Copy link
Copy Markdown
Contributor

yangw-dev commented Mar 16, 2026

i apologize, this should not be reverted. the diff train is a bit messed now today. remerge this, nothing need author to do

@yangw-dev
Copy link
Copy Markdown
Contributor

@pytorchbot merge -i

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: inductor / inductor-cpu-build / build

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@yangw-dev
Copy link
Copy Markdown
Contributor

@pytorchbot merge -f 'accidentally revert this pr, due to misinfo from internal diff train'

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@guilhermeleobas
Copy link
Copy Markdown
Collaborator Author

@yangw-dev Do I need to do anything?

AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 24, 2026
…ache, wraps) (pytorch#176934)

`WrapperUserFunctionVariable` now inherits from `BaseUserFunctionVariable`
instead of `VariableTracker`, gaining the shared `var_getattr` implementation
that handles `__name__`, `__qualname__`, `__doc__`, `__module__`, `__code__`,
`__dict__`, `__annotations__`, and `__type_params__`.

This fixes `functools.wraps` applied to `lru_cache`-wrapped functions at trace
time — previously, accessing `__name__`, `__dict__`, etc. on the wrapper object
would graph-break.

Co-authored Claude Sonnet 4.6

Pull Request resolved: pytorch#176934
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…ache, wraps) (pytorch#176934)

`WrapperUserFunctionVariable` now inherits from `BaseUserFunctionVariable`
instead of `VariableTracker`, gaining the shared `var_getattr` implementation
that handles `__name__`, `__qualname__`, `__doc__`, `__module__`, `__code__`,
`__dict__`, `__annotations__`, and `__type_params__`.

This fixes `functools.wraps` applied to `lru_cache`-wrapped functions at trace
time — previously, accessing `__name__`, `__dict__`, etc. on the wrapper object
would graph-break.

Co-authored Claude Sonnet 4.6

Pull Request resolved: pytorch#176934
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42
ghstack dependencies: pytorch#176623
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…ache, wraps) (pytorch#176934)

`WrapperUserFunctionVariable` now inherits from `BaseUserFunctionVariable`
instead of `VariableTracker`, gaining the shared `var_getattr` implementation
that handles `__name__`, `__qualname__`, `__doc__`, `__module__`, `__code__`,
`__dict__`, `__annotations__`, and `__type_params__`.

This fixes `functools.wraps` applied to `lru_cache`-wrapped functions at trace
time — previously, accessing `__name__`, `__dict__`, etc. on the wrapper object
would graph-break.

Co-authored Claude Sonnet 4.6

Pull Request resolved: pytorch#176934
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…e (lru_cache, wraps) (pytorch#176934)"

This reverts commit 609e0ed.

Reverted pytorch#176934 on behalf of https://github.com/yangw-dev due to This is being reverted due to an internal build dependency issue.The code changes themselves are correct and the tests pass. However, when importing this change internally, it triggers a dependency blocklist violation, please find a meta folks to help fix this internally,see D96281773 ([comment](pytorch#176934 (comment)))
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…ache, wraps) (pytorch#176934)

`WrapperUserFunctionVariable` now inherits from `BaseUserFunctionVariable`
instead of `VariableTracker`, gaining the shared `var_getattr` implementation
that handles `__name__`, `__qualname__`, `__doc__`, `__module__`, `__code__`,
`__dict__`, `__annotations__`, and `__type_params__`.

This fixes `functools.wraps` applied to `lru_cache`-wrapped functions at trace
time — previously, accessing `__name__`, `__dict__`, etc. on the wrapper object
would graph-break.

Co-authored Claude Sonnet 4.6

Pull Request resolved: pytorch#176934
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
…e (lru_cache, wraps) (pytorch#176934)"

This reverts commit 609e0ed.

Reverted pytorch#176934 on behalf of https://github.com/yangw-dev due to This is being reverted due to an internal build dependency issue.The code changes themselves are correct and the tests pass. However, when importing this change internally, it triggers a dependency blocklist violation, please find a meta folks to help fix this internally,see D96281773 ([comment](pytorch#176934 (comment)))
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
…ache, wraps) (pytorch#176934)

`WrapperUserFunctionVariable` now inherits from `BaseUserFunctionVariable`
instead of `VariableTracker`, gaining the shared `var_getattr` implementation
that handles `__name__`, `__qualname__`, `__doc__`, `__module__`, `__code__`,
`__dict__`, `__annotations__`, and `__type_params__`.

This fixes `functools.wraps` applied to `lru_cache`-wrapped functions at trace
time — previously, accessing `__name__`, `__dict__`, etc. on the wrapper object
would graph-break.

Co-authored Claude Sonnet 4.6

Pull Request resolved: pytorch#176934
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants