Skip to content

Commit 59b048f

Browse files
dsashidhpytorchmergebot
authored andcommitted
[dynamo] Replace raw None with ConstantVariable on stack (#169325)
Fixes #168994 Fixes the issue where raw None values were being pushed onto the dynamo stack instead of wrapped ConstantVariable(None) objects. This caused crashes when code expected VariableTracker methods. - Updated push( ) signature: removed Optional[VariableTracker] - Fixed bytecode handlers (BEGIN_FINALLY, WITH_CLEANUP_START, etc.) to push ConstantVariable.create(None) instead of raw None - Updated assertions to check for ConstantVariable with value is None instead of raw None - Fixed DebuggingVariable.call_function( ) to return ConstantVariable.create(None) Pull Request resolved: #169325 Approved by: https://github.com/williamwen42
1 parent b244229 commit 59b048f

3 files changed

Lines changed: 35 additions & 27 deletions

File tree

test/dynamo/test_misc.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6353,6 +6353,29 @@ def fn(x):
63536353
res2 = opt_fn(x)
63546354
self.assertEqual(res, res2)
63556355

6356+
def test_function_return_none_creates_constant_variable(self):
6357+
"""
6358+
Test that functions returning None properly return ConstantVariable.create(None)
6359+
instead of raw None, which would violate the stack's type contract.
6360+
6361+
Regression test for: Avoid using Optional[VariableTracker]
6362+
"""
6363+
6364+
def gn(x):
6365+
return
6366+
6367+
torch._dynamo.config.reorderable_logging_functions.add(gn)
6368+
6369+
@torch.compile(backend="eager")
6370+
def fn(x):
6371+
x = x + 1
6372+
if gn(x) is None:
6373+
return x + 2
6374+
return x + 4
6375+
6376+
# If this doesn't crash, the test passes
6377+
fn(torch.ones(3))
6378+
63566379
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
63576380
def test_tensor_ctor_list_of_tensor(self):
63586381
def fn(x):

torch/_dynamo/symbolic_convert.py

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,11 +1758,11 @@ def run(self) -> None:
17581758
# twice is not an issue (second stop is a no op).
17591759
self.output.mark_bytecode_tracing_stop()
17601760

1761-
def push(self, val: VariableTracker | None) -> None:
1762-
assert val is None or isinstance(val, VariableTracker), (
1761+
def push(self, val: VariableTracker) -> None:
1762+
assert isinstance(val, VariableTracker), (
17631763
f"push expects VariableTracker, got {typestr(val)}"
17641764
)
1765-
self.stack.append(val) # type: ignore[arg-type]
1765+
self.stack.append(val)
17661766

17671767
def push_many(self, vals: list[VariableTracker]) -> None:
17681768
for val in vals:
@@ -2111,20 +2111,6 @@ def SETUP_FINALLY(self, inst: Instruction) -> None:
21112111
assert inst.target is not None
21122112
self.block_stack.append(BlockStackEntry(inst, inst.target, len(self.stack)))
21132113

2114-
def BEGIN_FINALLY(self, inst: Instruction) -> None:
2115-
self.push(None)
2116-
2117-
def WITH_CLEANUP_START(self, inst: Instruction) -> None:
2118-
exit, exc = self.popn(2)
2119-
assert exc is None
2120-
self.push(exc)
2121-
2122-
self.push(exit.call_function(self, [CONSTANT_VARIABLE_NONE] * 3, {}))
2123-
2124-
def WITH_CLEANUP_FINISH(self, inst: Instruction) -> None:
2125-
self.popn(2)
2126-
self.push(None)
2127-
21282114
def FOR_ITER(self, inst: Instruction) -> None:
21292115
it = self.pop().realize()
21302116
self.push(it)
@@ -2736,15 +2722,14 @@ def LOAD_METHOD(self, inst: Instruction) -> None:
27362722
self.PUSH_NULL(inst)
27372723
self.push(obj)
27382724
else:
2739-
self.push(obj)
2740-
self.push(None)
2725+
raise AssertionError(
2726+
"LOAD_METHOD should have been rewritten to LOAD_ATTR. We should never reach here."
2727+
)
27412728

27422729
def CALL_METHOD(self, inst: Instruction) -> None:
2743-
args = self.popn(inst.argval)
2744-
dummy = self.pop()
2745-
assert dummy is None
2746-
fn = self.pop()
2747-
self.call_function(fn, args, {})
2730+
raise AssertionError(
2731+
"CALL_METHOD should have been rewritten to CALL_FUNCTION. This function should never be called."
2732+
)
27482733

27492734
def _load_attr(self, attr: Any) -> None:
27502735
obj = self.pop()

torch/_dynamo/variables/misc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2026,16 +2026,15 @@ def is_reorderable_logging_function(
20262026
and obj in torch._dynamo.config.reorderable_logging_functions
20272027
)
20282028

2029-
# type: ignore[override]
20302029
def call_function(
20312030
self,
20322031
tx: "InstructionTranslator",
20332032
args: Sequence[VariableTracker],
20342033
kwargs: dict[str, VariableTracker],
2035-
) -> None:
2034+
) -> VariableTracker:
20362035
if tx.export:
20372036
# For export cases, we can just make debugging functions no-ops
2038-
return
2037+
return CONSTANT_VARIABLE_NONE
20392038

20402039
if not self.can_reorder_logs(self.value, args, kwargs):
20412040
unimplemented(
@@ -2049,6 +2048,7 @@ def call_function(
20492048
)
20502049

20512050
tx.debug_locals.append((self, list(args)))
2051+
return CONSTANT_VARIABLE_NONE
20522052

20532053
def reconstruct(self, codegen: "PyCodegen") -> None:
20542054
assert self.source is not None

0 commit comments

Comments
 (0)