[dynamo] Fix tracing partially initialized tensor subclass during dispatch#175397

Closed

azahed98 wants to merge 3 commits into gh/azahed98/5/base from gh/azahed98/5/head

Conversation

@azahed98
Contributor

@azahed98 azahed98 commented Feb 20, 2026

Stack from ghstack (oldest at bottom):

Fixes an edge case found with Diffusers+TorchAO+Dynamo where a tensor subclass can have its `__init__` traced, which calls `__tensor_flatten__` prior to init. In this case, attributes used in `__tensor_flatten__` result in an error since they are not yet initialized.

This PR instead adds an escape hatch to skip faking at the start of `VariableBuilder.wrap_tensor` in this case.

**Test Plan:** The original error can be reproduced with [this script](https://gist.github.com/sayakpaul/929678132809874c5dbf9c5215460d33). I was unable to reproduce the error end-to-end without the Diffusers or TorchAO dependencies, so I instead added a unit test that uses mocks to check that the escape hatch is taken.
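
For reference, the failure shape looks roughly like this (a minimal sketch with a hypothetical subclass; not a verified repro of the Diffusers+TorchAO case, and `__torch_dispatch__` is omitted for brevity):

```python
import torch

class MySubclass(torch.Tensor):
    @staticmethod
    def __new__(cls, data, scale):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype)

    def __init__(self, data, scale):
        # These attributes exist only after __init__ finishes.
        self.inner = data
        self.scale = scale

    def __tensor_flatten__(self):
        # If Dynamo fakes `self` while tracing __init__, `inner` and `scale`
        # don't exist yet and this raises AttributeError.
        return ["inner"], {"scale": self.scale}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        return MySubclass(inner_tensors["inner"], meta["scale"])
```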

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo

@pytorch-bot

pytorch-bot Bot commented Feb 20, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/175397

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 545bf4c with merge base f72a552:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

azahed98 added a commit that referenced this pull request Feb 20, 2026
@pytorch-bot

pytorch-bot Bot commented Feb 20, 2026

This PR needs a `release notes:` label

If your changes are user facing and intended to be part of the release notes, please use a label starting with `release notes:`.

If not, please add the `topic: not user facing` label.

To add a label, you can comment to pytorchbot, for example:
`@pytorchbot label "topic: not user facing"`

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

@anijain2305 anijain2305 left a comment

Can you compare it with other tensor subclasses like DTensor etc.? Most of the hesitation stems from not understanding how tensor subclass init is handled in Dynamo. My understanding was that we just graph break, but I might be wrong.

@azahed98
Contributor Author

azahed98 commented Feb 25, 2026

> Can you compare it with other tensor subclasses like DTensor etc.? Most of the hesitation stems from not understanding how tensor subclass init is handled in Dynamo. My understanding was that we just graph break, but I might be wrong.

@anijain2305 Did some digging, and what I found is that DTensor graph breaks on `__init__` because it has an explicit `@torch._disable_dynamo`. TorchAO doesn't have this, so we end up tracing through `__init__`.

I also realized that this issue is specific to compiling `__init__` as a root frame, since this wouldn't be an issue if there's already a VariableTracker for `self`. However, now that I'm looking closer, this escape hatch might result in skipping guards or capturing tensor ops in `__init__`?

In that case, I'm thinking our best options are:

  1. Default to a graph break on tensor subclass `__init__` (and maybe have a config flag to change that behavior).
  2. Just update TorchAO to put the disable around its subclasses' `__init__` methods (see the sketch below), but then we are still open to this issue with other user subclasses (albeit an edge case).
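
A sketch of what option 2 could look like (the `@torch._disable_dynamo` decorator is the one DTensor uses; the subclass here is hypothetical):

```python
import torch

class LibTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype)

    # Option 2: explicitly graph break on construction, mirroring DTensor's
    # @torch._disable_dynamo on __init__, so Dynamo never traces through a
    # partially initialized self.
    @torch._disable_dynamo
    def __init__(self, data):
        self.inner = data
```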

Edit 1:
I tried some small tests: ops on the data tensor are captured, but ops on `self` result in an `Unsupported` error from Dynamo.

@azahed98
Contributor Author

@anijain2305 I tried adding a skip for root frame subclass `__init__` in `CatchErrorsWrapper`:

```diff
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -2287,6 +2288,20 @@ class CatchErrorsWrapper:
         ):
             # nametuple constructor/_make
             return ConvertFrameReturn()
+
+        if (
+            frame.f_code.co_name == "__init__"
+            and frame.f_code.co_argcount > 0
+            and frame.f_code.co_varnames
+            and is_traceable_wrapper_subclass(
+                frame.f_locals.get(frame.f_code.co_varnames[0])
+            )
+        ):
+            # Skip tracing __init__ of traceable wrapper subclasses: self is
+            # partially initialized at this point (attributes set by __init__
+            # don't exist yet), so faking it would call __tensor_flatten__ and
+            # crash. Run eagerly instead, matching @torch._disable_dynamo behavior.
+            return ConvertFrameReturn()
         if torch._dynamo.utils.get_optimize_ddp_mode() == "ddp_optimizer":
             ddp_module = DistributedDataParallel._get_active_ddp_module()
+    is_traceable_wrapper_subclass,
```

This resolves the issue by skipping the frame instead, which should be fine since we realistically will only encounter this issue if `__init__` comes immediately after a graph break region. Shall I change this PR to this diff?

@anijain2305
Contributor

> @anijain2305 I tried adding a skip for root frame subclass `__init__` in `CatchErrorsWrapper`: […]
>
> This resolves the issue by skipping the frame instead, which should be fine since we realistically will only encounter this issue if `__init__` comes immediately after a graph break region. Shall I change this PR to this diff?

Yes, this makes sense. At the time, we might not have had the `is_traceable_wrapper_subclass` util; now it makes sense.

@sayakpaul

@azahed98 any ETA on landing this? 👀

… during dispatch"

[ghstack-poisoned]
@azahed98
Contributor Author

@sayakpaul I'll start the merge of the stack today -- just need to fix some CI failures from changes to the linter.

… during dispatch"

[ghstack-poisoned]
@azahed98
Contributor Author

OK, looks like I resolved the ghstack dupe issue. Re-requesting review to unblock the merge.

@azahed98 azahed98 requested a review from anijain2305 March 12, 2026 22:56
@azahed98 azahed98 added the topic: not user facing topic category label Mar 13, 2026
@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #175660

pytorchmergebot pushed a commit that referenced this pull request Mar 13, 2026
Fixes an error with the `TENSOR_SUBCLASS_METADATA_MATCH` guard when the tensor subclass has a SymInt in its metadata. In this scenario, `deepcopy` of the metadata propagates through the SymInt down to the ShapeEnv, FakeMode, and then FakeTensors, causing an error due to a missing data pointer.

This PR replaces SymInts in the metadata with an `_AnyCompare` object that always returns `True` for equality checks. This assumes dynamic-shape checks will handle correctness.

**Test Plan:** The original error can be reproduced with [this script](https://gist.github.com/sayakpaul/929678132809874c5dbf9c5215460d33) (if run on the previous commit in this stack). This PR adds a regression test with a manually injected SymInt in the metadata, then compiles with `fullgraph=True` and checks for no recompiles.

Pull Request resolved: #175596
Approved by: https://github.com/anijain2305
ghstack dependencies: #175397
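
A minimal sketch of the always-equal sentinel described in that commit message (the real `_AnyCompare` may differ in detail):

```python
class _AnyCompare:
    """Stand-in for a SymInt in guard metadata: compares equal to anything,
    deferring actual correctness to the dynamic-shape guards."""

    def __eq__(self, other):
        return True

    def __ne__(self, other):
        return False

    def __hash__(self):
        return 0  # constant hash, consistent with the always-True __eq__

# A metadata dict like {"block_size": some_symint} becomes
# {"block_size": _AnyCompare()} before the guard's deepcopy-and-compare,
# so deepcopy never touches the SymInt.
assert _AnyCompare() == 7 and _AnyCompare() == "anything"
```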
pytorchmergebot pushed a commit that referenced this pull request Mar 13, 2026
…sure refcycle (#175660)

Fixes a potential reference cycle that can block `swap_tensors` during or after compile. This reference cycle comes from a closure over a `MetaConverter` object within the `_empty_create_subclass` function defined in `MetaConverter.empty_create_subclass`.

This PR moves `_empty_create_subclass` to be a method of `MetaConverter` instead, adding additional arguments and moving imports as needed.

**Test Plan:** The original error can be reproduced with [this script](https://gist.github.com/sayakpaul/929678132809874c5dbf9c5215460d33) (if run on the previous commit in this stack). This PR adds a unit test that checks that weakrefs created by `MetaConverter` are cleaned up when it is manually deleted, even if garbage collection is disabled.

Pull Request resolved: #175660
Approved by: https://github.com/anijain2305
ghstack dependencies: #175397, #175596
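
A simplified sketch of the cycle and the fix described in that commit message (hypothetical helper names; the real `MetaConverter` code is more involved):

```python
class MetaConverter:  # simplified; helper names here are hypothetical
    def meta_tensor(self, t):
        return t  # stub for the real meta-conversion logic

    # Before: a nested function that closes over `self`. If this closure is
    # retained by state the converter also references (as in the real code),
    # a reference cycle forms that only the cyclic GC can break, which can
    # block swap_tensors during or after compile.
    def empty_create_subclass_before(self, t):
        def _empty_create_subclass(inner):
            return self.meta_tensor(inner)
        return _empty_create_subclass

    # After: the helper is a method taking explicit arguments, so no closure
    # cell capturing the converter is created.
    def _empty_create_subclass(self, inner):
        return self.meta_tensor(inner)
```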
@sayakpaul

Thanks for landing it. I will try this out next week and get back!

@sayakpaul

@azahed98 I bring bad news, I'm afraid.

I tried it out, but it seems like it's contingent on pytorch/ao#4088.

@sayakpaul

sayakpaul commented Mar 23, 2026

Opened a PR: huggingface/diffusers#13276. Hopefully this gets resolved. @lordaarush do you want to test it as well?

EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…patch (pytorch#175397)

Pull Request resolved: pytorch#175397
Approved by: https://github.com/anijain2305, https://github.com/williamwen42

EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…75596)

EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…sure refcycle (pytorch#175660)

AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
…patch (pytorch#175397)

AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
…75596)

AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
…sure refcycle (pytorch#175660)
@sayakpaul

huggingface/diffusers#13276 was merged and this is working really well now! Thanks @azahed98

@github-actions github-actions Bot deleted the gh/azahed98/5/head branch May 7, 2026 02:25