Skip to content

Commit 1b45111

Browse files
williamwen42pytorchmergebot
authored andcommitted
[dynamo, nested graph breaks] fix nested step_graph_break bug where parent stack gets corrupted (#177408)
Pull Request resolved: #177408 Approved by: https://github.com/Lucaskabela ghstack dependencies: #176906, #177090, #177155, #177195
1 parent c5dcefd commit 1b45111

2 files changed

Lines changed: 29 additions & 2 deletions

File tree

test/dynamo/test_nested_graph_breaks.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,6 +1324,34 @@ def gn(x):
13241324
self.assertEqual(gn(inp), inp + 3)
13251325
self.assertEqual(cnts.frame_count, 2)
13261326

1327+
def test_step_graph_break_frame_values_not_corrupted(self):
1328+
"""Bytecode generation bug in step_graph_break corrupted parent frame
1329+
locals when the parent had a non-empty operand stack (num_stack > 0).
1330+
"""
1331+
1332+
def inner(x):
1333+
x = x + 1
1334+
x = x + 1
1335+
torch._dynamo.step_unsupported()
1336+
return x
1337+
1338+
cnts = torch._dynamo.testing.CompileCounter()
1339+
1340+
@torch.compile(backend=cnts)
1341+
def fn(x):
1342+
x = x + 1
1343+
y = (x, inner(x))
1344+
return x, y
1345+
1346+
x = torch.tensor([1.0, 2.0])
1347+
result = fn(x)
1348+
self.assertEqual(result[0], torch.tensor([2.0, 3.0]))
1349+
self.assertEqual(
1350+
result[1], (torch.tensor([2.0, 3.0]), torch.tensor([4.0, 5.0]))
1351+
)
1352+
self.assertEqual(cnts.frame_count, 1)
1353+
self.assertEqual(cnts.op_count, 3)
1354+
13271355
def test_contextmanager_graph_break_in_init(self):
13281356
"""Graph break in _GeneratorContextManager.__init__ when the generator
13291357
function is @torch._disable_dynamo (the DDP pattern)."""

torch/_dynamo/symbolic_convert.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,9 +1605,8 @@ def step_graph_break(self, continue_inst: Instruction) -> None:
16051605
*create_copy(2),
16061606
cg.create_load_const(0),
16071607
cg.create_binary_subscr(),
1608-
create_dup_top(),
16091608
*create_binary_slice(num_stack, None),
1610-
*create_swap(2),
1609+
*create_copy(3),
16111610
cg.create_load_const(0),
16121611
create_instruction("STORE_SUBSCR"),
16131612
]

0 commit comments

Comments
 (0)