Skip to content

Commit 609e0ed

Browse files
guilhermeleobas authored and pytorchmergebot committed
[dynamo] Fix dunder attr access on WrapperUserFunctionVariable (lru_cache, wraps) (#176934)
`WrapperUserFunctionVariable` now inherits from `BaseUserFunctionVariable` instead of `VariableTracker`, gaining the shared `var_getattr` implementation that handles `__name__`, `__qualname__`, `__doc__`, `__module__`, `__code__`, `__dict__`, `__annotations__`, and `__type_params__`. This fixes `functools.wraps` applied to `lru_cache`-wrapped functions at trace time — previously, accessing `__name__`, `__dict__`, etc. on the wrapper object would graph-break. Co-authored-by: Claude Sonnet 4.6. Pull Request resolved: #176934. Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42
1 parent 004e7e6 commit 609e0ed

3 files changed

Lines changed: 242 additions & 52 deletions

File tree

test/dynamo/test_functions.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4295,6 +4295,165 @@ def fn(x):
42954295
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
42964296
self.assertEqual(fn(x), opt_fn(x))
42974297

4298+
def test_wrapper_user_function_hasattr(self):
4299+
# WrapperUserFunctionVariable (e.g. lru_cache-wrapped fn) passed to a
4300+
# decorator that calls functools.wraps at tracing time should not graph
4301+
# break on hasattr(fn, '__dict__').
4302+
@functools.lru_cache
4303+
def cached_fn(x):
4304+
return x * 2
4305+
4306+
def retry(func):
4307+
@functools.wraps(func)
4308+
def wrapper(*args, **kwargs):
4309+
return func(*args, **kwargs)
4310+
4311+
return wrapper
4312+
4313+
def fn(x):
4314+
return retry(cached_fn)(x)
4315+
4316+
x = torch.tensor(2.0, device="cpu")
4317+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
4318+
self.assertEqual(fn(x), opt_fn(x))
4319+
4320+
def test_wraps_on_lru_cache_preserves_name(self):
4321+
# functools.wraps copies __name__ and __qualname__ from the wrapped fn;
4322+
# when applied to an lru_cache-wrapped function at trace time,
4323+
# WrapperUserFunctionVariable must expose those attributes.
4324+
@functools.lru_cache
4325+
def my_op(x):
4326+
"""my docstring"""
4327+
return x + 1
4328+
4329+
def apply_wraps(func):
4330+
@functools.wraps(func)
4331+
def inner(*args, **kwargs):
4332+
return func(*args, **kwargs)
4333+
4334+
return inner
4335+
4336+
def fn(x):
4337+
wrapped = apply_wraps(my_op)
4338+
if wrapped.__name__ != "my_op":
4339+
raise AssertionError(f"Expected 'my_op', got {wrapped.__name__!r}")
4340+
if wrapped.__qualname__ != my_op.__qualname__:
4341+
raise AssertionError(
4342+
f"Expected {my_op.__qualname__!r}, got {wrapped.__qualname__!r}"
4343+
)
4344+
if wrapped.__doc__ != "my docstring":
4345+
raise AssertionError(
4346+
f"Expected 'my docstring', got {wrapped.__doc__!r}"
4347+
)
4348+
return wrapped(x)
4349+
4350+
x = torch.tensor(1.0)
4351+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
4352+
self.assertEqual(fn(x), opt_fn(x))
4353+
4354+
def test_wraps_on_lru_cache_copies_annotations(self):
4355+
# functools.wraps should copy __annotations__ from an lru_cache-wrapped fn.
4356+
@functools.lru_cache
4357+
def annotated_fn(x: torch.Tensor) -> torch.Tensor:
4358+
return x * 3
4359+
4360+
def decorator(func):
4361+
@functools.wraps(func)
4362+
def inner(*args, **kwargs):
4363+
return func(*args, **kwargs)
4364+
4365+
return inner
4366+
4367+
def fn(x):
4368+
return decorator(annotated_fn)(x)
4369+
4370+
x = torch.tensor(2.0)
4371+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
4372+
self.assertEqual(fn(x), opt_fn(x))
4373+
4374+
def test_wraps_on_nested_fn(self):
4375+
# functools.wraps applied to a locally defined (NestedUserFunctionVariable)
4376+
# function should also work without graph breaks.
4377+
def fn(x):
4378+
def inner_op(t):
4379+
"""inner doc"""
4380+
return t * 2
4381+
4382+
@functools.wraps(inner_op)
4383+
def wrapper(*args, **kwargs):
4384+
return inner_op(*args, **kwargs)
4385+
4386+
if wrapper.__name__ != "inner_op":
4387+
raise AssertionError(f"Expected 'inner_op', got {wrapper.__name__!r}")
4388+
if wrapper.__doc__ != "inner doc":
4389+
raise AssertionError(f"Expected 'inner doc', got {wrapper.__doc__!r}")
4390+
return wrapper(x)
4391+
4392+
x = torch.tensor(3.0)
4393+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
4394+
self.assertEqual(fn(x), opt_fn(x))
4395+
4396+
def test_wraps_stacked_on_lru_cache(self):
4397+
# Stacking two functools.wraps layers over an lru_cache-wrapped fn.
4398+
@functools.lru_cache
4399+
def base_fn(x):
4400+
return x - 1
4401+
4402+
def outer_decorator(func):
4403+
@functools.wraps(func)
4404+
def middle(*args, **kwargs):
4405+
return func(*args, **kwargs)
4406+
4407+
return middle
4408+
4409+
def inner_decorator(func):
4410+
@functools.wraps(func)
4411+
def innermost(*args, **kwargs):
4412+
return func(*args, **kwargs)
4413+
4414+
return innermost
4415+
4416+
def fn(x):
4417+
return inner_decorator(outer_decorator(base_fn))(x)
4418+
4419+
x = torch.tensor(5.0)
4420+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
4421+
self.assertEqual(fn(x), opt_fn(x))
4422+
4423+
def test_lru_cache_dunder_name_access(self):
4424+
# Accessing __name__ on an lru_cache-wrapped function during tracing
4425+
# should return the original function's name as a constant.
4426+
@functools.lru_cache
4427+
def compute(x):
4428+
return x + 10
4429+
4430+
def fn(x):
4431+
name = compute.__name__
4432+
if name != "compute":
4433+
raise AssertionError(f"Expected 'compute', got {name!r}")
4434+
return compute(x)
4435+
4436+
x = torch.tensor(1.0)
4437+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
4438+
self.assertEqual(fn(x), opt_fn(x))
4439+
4440+
def test_lru_cache_dunder_doc_access(self):
4441+
# Accessing __doc__ on an lru_cache-wrapped function during tracing.
4442+
@functools.lru_cache
4443+
def documented_fn(x):
4444+
"""returns x squared"""
4445+
return x**2
4446+
4447+
def fn(x):
4448+
doc = documented_fn.__doc__
4449+
if doc != "returns x squared":
4450+
raise AssertionError(f"Expected 'returns x squared', got {doc!r}")
4451+
return documented_fn(x)
4452+
4453+
x = torch.tensor(3.0)
4454+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
4455+
self.assertEqual(fn(x), opt_fn(x))
4456+
42984457
def test_functools_cache_guard(self):
42994458
class Foo:
43004459
@functools.lru_cache # noqa: B019

torch/_dynamo/variables/builder.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,11 +1392,19 @@ def build_key_value(
13921392
elif inspect.getattr_static(value, "__script_if_tracing_wrapper", False):
13931393
self.install_guards(GuardBuilder.TYPE_MATCH)
13941394
return WrapperUserFunctionVariable(
1395-
value, "__original_fn", source=self.source
1395+
value,
1396+
"__original_fn",
1397+
source=self.source,
1398+
mutation_type=AttributeMutationExisting(),
13961399
)
13971400
elif is_lru_cache_wrapped_function(value):
13981401
self.install_guards(GuardBuilder.TYPE_MATCH)
1399-
return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source)
1402+
return WrapperUserFunctionVariable(
1403+
value,
1404+
"__wrapped__",
1405+
source=self.source,
1406+
mutation_type=AttributeMutationExisting(),
1407+
)
14001408
elif value is sys.exc_info or (
14011409
sys.version_info >= (3, 11) and value is sys.exception
14021410
):
@@ -1406,7 +1414,10 @@ def build_key_value(
14061414
):
14071415
self.install_guards(GuardBuilder.TYPE_MATCH)
14081416
return WrapperUserFunctionVariable(
1409-
value, "_torchdynamo_inline", source=self.source
1417+
value,
1418+
"_torchdynamo_inline",
1419+
source=self.source,
1420+
mutation_type=AttributeMutationExisting(),
14101421
)
14111422
elif value is collections.namedtuple:
14121423
self.install_guards(GuardBuilder.ID_MATCH)

torch/_dynamo/variables/functions.py

Lines changed: 69 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -395,12 +395,14 @@ def get_source(self) -> Source | None:
395395
def get_dict_vt(self, tx: "InstructionTranslator") -> "DunderDictVariable":
396396
if self.dict_vt is None:
397397
dict_proxy: dict[str, VariableTracker] = {}
398-
if hasattr(self, "fn"): # Use `.get_function()` instead?
398+
399+
if not istype(self, NestedUserFunctionVariable):
400+
fn = self.get_function()
399401
dict_proxy = {
400402
name: VariableTracker.build(
401403
tx, value, source=self.source and AttrSource(self.source, name)
402404
)
403-
for name, value in self.fn.__dict__.items()
405+
for name, value in fn.__dict__.items()
404406
}
405407
self.dict_vt = variables.DunderDictVariable.create(tx, self, dict_proxy)
406408
return self.dict_vt
@@ -445,9 +447,47 @@ def get_code(self) -> types.CodeType:
445447
def has_self(self) -> bool:
446448
raise NotImplementedError
447449

450+
def get_function(self) -> types.FunctionType:
451+
raise NotImplementedError
452+
448453
def get_module(self) -> str:
449454
return self.get_globals()["__name__"]
450455

456+
def var_getattr(self, tx: "InstructionTranslator", name: str):
457+
fn_dict = self.get_dict_vt(tx)
458+
459+
# missing: __globals__, __closure__, __kwdefaults__, __defaults__
460+
if name in ("__name__", "__qualname__", "__doc__", "__module__", "__code__"):
461+
val = getattr(self, f"get_{name[2:-2]}")()
462+
if fn_dict.contains(name):
463+
return fn_dict.getitem(name)
464+
return ConstantVariable.create(
465+
val, source=self.source and AttrSource(self.source, name)
466+
)
467+
elif name == "__dict__":
468+
return fn_dict
469+
elif name == "__annotations__":
470+
return fn_dict.getitem_or_default(
471+
name,
472+
lambda: variables.ConstDictVariable(
473+
{},
474+
mutation_type=ValueMutationNew(),
475+
),
476+
)
477+
elif name == "__type_params__":
478+
return fn_dict.getitem_or_default(
479+
name,
480+
lambda: variables.TupleVariable(
481+
[],
482+
mutation_type=ValueMutationNew(),
483+
),
484+
)
485+
else:
486+
if fn_dict.contains(name):
487+
return fn_dict.getitem(name)
488+
else:
489+
raise_observed_exception(AttributeError, tx)
490+
451491
def call_function(
452492
self,
453493
tx: "InstructionTranslator",
@@ -636,7 +676,7 @@ def bind_args(
636676

637677
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
638678
if name == "__dict__":
639-
return self.get_dict_vt(tx)
679+
return super().var_getattr(tx, name)
640680
elif name in cmp_name_to_op_mapping:
641681
return variables.GetAttrVariable(self, name)
642682
source = self.get_source()
@@ -1858,55 +1898,24 @@ def _get_function_impl(self, _converting: set[int]) -> types.FunctionType:
18581898
return func
18591899

18601900
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
1861-
fn_dict = self.get_dict_vt(tx)
1862-
1863-
# Some dunder attributes (__name__, __doc__, etc) are stored in the C
1864-
# field slot. I guess it won't be too bad if we store them in the
1865-
# __dict__ field.
1866-
1867-
# annotations should be stored in the __dict__ field
1868-
if name == "__annotations__":
1869-
return self.get_dict_vt(tx).getitem_or_default(
1870-
name,
1871-
lambda: variables.ConstDictVariable(
1872-
{},
1873-
source=self.source and AttrSource(self.source, "__annotations__"),
1874-
mutation_type=ValueMutationNew(),
1875-
),
1876-
)
1877-
elif name == "__code__":
1878-
return self.code
1879-
elif name == "__defaults__":
1901+
if name in (
1902+
"__annotations__",
1903+
"__dict__",
1904+
"__doc__",
1905+
"__code__",
1906+
"__module__",
1907+
"__name__",
1908+
"__qualname__",
1909+
"__type_params__",
1910+
):
1911+
return super().var_getattr(tx, name)
1912+
if name == "__defaults__":
18801913
d = getattr(self, "defaults", None)
18811914
return d.as_python_constant() if d else ConstantVariable.create(None)
1882-
elif name == "__dict__":
1883-
return self.get_dict_vt(tx)
1884-
elif name == "__type_params__":
1885-
return fn_dict.getitem_or_default(
1886-
name,
1887-
lambda: variables.TupleVariable(
1888-
[],
1889-
source=self.source and AttrSource(self.source, "__type_params__"),
1890-
),
1891-
)
1892-
elif name in ("__name__", "__qualname__", "__doc__", "__module__"):
1893-
val = getattr(self, f"get_{name[2:-2]}")()
1894-
return fn_dict.getitem_or_default(
1895-
name,
1896-
lambda: ConstantVariable.create(
1897-
val, source=self.source and AttrSource(self.source, name)
1898-
),
1899-
)
19001915
elif name in cmp_name_to_op_mapping:
19011916
return variables.GetAttrVariable(self, name)
19021917
else:
1903-
if fn_dict.contains(name):
1904-
return fn_dict.getitem(name)
1905-
else:
1906-
# should `var_getattr` raise AttributeError if not found?
1907-
# I'm wondering if this method is a helper that it is faster
1908-
# than going through BuiltinVariable(getattr).call_function(...)
1909-
raise_observed_exception(AttributeError, tx)
1918+
return super().var_getattr(tx, name)
19101919

19111920
def has_closure(self) -> bool:
19121921
return self.closure is not None
@@ -2335,7 +2344,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None:
23352344
codegen.extend_output(create_call_function(1, False))
23362345

23372346

2338-
class WrapperUserFunctionVariable(VariableTracker):
2347+
class WrapperUserFunctionVariable(BaseUserFunctionVariable):
23392348
"""
23402349
Used to represent a wrapper object that contains the actual callable as an
23412350
attribute. For example, torch.jit.script/trace have the original function at
@@ -2348,14 +2357,25 @@ def __init__(self, wrapper_obj: Any, attr_to_trace: str, **kwargs: Any) -> None:
23482357
self.wrapper_obj = wrapper_obj
23492358
self.attr_to_trace = attr_to_trace
23502359

2360+
def get_module(self) -> str:
2361+
return self.wrapper_obj.__module__
2362+
2363+
def get_name(self) -> str:
2364+
return self.wrapper_obj.__name__
2365+
2366+
def get_code(self) -> types.CodeType:
2367+
return self.get_function().__code__
2368+
23512369
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
23522370
if name == self.attr_to_trace:
23532371
val = getattr(self.wrapper_obj, self.attr_to_trace)
23542372
source = self.source and AttrSource(self.source, name)
23552373
return VariableTracker.build(tx, val, source)
2356-
23572374
return super().var_getattr(tx, name)
23582375

2376+
def get_function(self):
2377+
return getattr(self.wrapper_obj, self.attr_to_trace)
2378+
23592379
def self_args(self) -> list[VariableTracker]:
23602380
return []
23612381

0 commit comments

Comments (0)