🐛 Describe the bug
When a property getter reads an attribute via self.__dict__.get("attr", default) instead of self.attr, Dynamo does not track the mutation of that attribute. If the property is then read from a @torch.compiler.disable-decorated method later in the same compiled frame, it returns the stale (pre-mutation) value.
This is a regression: it passes on 2.11.0.dev20260116 and fails on 2.12.0.dev20260305.
Minimal repro
```python
import torch
from torch.compiler import disable


class Container:
    def __init__(self):
        self._len_value = 0

    @property
    def _len(self):
        # Using __dict__.get instead of self._len_value triggers the bug
        return self.__dict__.get("_len_value", 0)

    @_len.setter
    def _len(self, value):
        self._len_value = value

    def add(self, n):
        self._len = self._len + n

    @disable()
    def __len__(self):
        return self._len


c = Container()


@torch.compile(backend="eager")
def f(x):
    c.add(x.shape[0])  # mutates c._len_value via property setter
    return len(c)  # reads c._len via @disable -> property getter -> __dict__.get


result = f(torch.randn(5))
assert result == 5, f"Expected 5, got {result}"
```
Note:
The issue is specifically __dict__.get() in the property getter. Replacing it with direct attribute access (return self._len_value) or getattr makes the test pass.
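For comparison, here is a plain-Python sketch (no PyTorch, so the @disable decorator and torch.compile are dropped) with the three getter styles from the note side by side. The DirectAttrContainer and GetattrContainer classes and the run helper are hypothetical names for illustration, and getattr(self, "_len_value", 0) is an assumed spelling of the note's "getattr" workaround. In eager Python all three read the same instance attribute and agree, which suggests the stale value in the compiled repro is specific to how Dynamo handles the __dict__.get read.

```python
# Plain-Python sketch (no PyTorch): the three getter styles from the note
# read the same instance attribute, so outside torch.compile they agree.


class DictGetContainer:
    """Getter reads through the instance __dict__, as in the repro."""

    def __init__(self):
        self._len_value = 0

    @property
    def _len(self):
        return self.__dict__.get("_len_value", 0)

    @_len.setter
    def _len(self, value):
        self._len_value = value

    def add(self, n):
        self._len = self._len + n

    def __len__(self):
        return self._len


class DirectAttrContainer(DictGetContainer):
    """Hypothetical workaround variant: direct attribute access."""

    @property
    def _len(self):
        return self._len_value

    @_len.setter
    def _len(self, value):
        self._len_value = value


class GetattrContainer(DictGetContainer):
    """Hypothetical workaround variant: getattr with an assumed default of 0."""

    @property
    def _len(self):
        return getattr(self, "_len_value", 0)

    @_len.setter
    def _len(self, value):
        self._len_value = value


def run(cls, n=5):
    """Hypothetical helper: mutate via the setter, then read back via len()."""
    c = cls()
    c.add(n)  # writes _len_value through the property setter
    return len(c)  # __len__ -> property getter


for cls in (DictGetContainer, DirectAttrContainer, GetattrContainer):
    # Each line prints the class name followed by 5.
    print(cls.__name__, run(cls))
```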
cc @chauhang @penguinwu @kurtamohler
Versions
- Passes:
torch 2.10.0, torch 2.11.0.dev20260116
- Fails:
torch 2.12.0.dev20260305