Skip to content

Commit df4d5f0

Browse files
committed
Update on "Can we test dynamo forward hooks now?"
[ghstack-poisoned]
2 parents a448f9b + b63cfc1 commit df4d5f0

1 file changed

Lines changed: 16 additions & 3 deletions

File tree

test/nn/test_module_hooks.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -642,12 +642,25 @@ def test_module_global_hooks(self):
642642
'backwards': 0
643643
}
644644

645+
"""
646+
Is anything actually wrong with dynamo impl for hooks currently?
647+
- is the return type differeing from eager?
648+
649+
Does this test just not work for dynamo bc asserting stuff in the hook
650+
causes graph breaks?
651+
"""
645652
def fw_hook(inc, h_module, input, output):
646653
self.assertIsInstance(input, tuple)
647-
self.assertTrue(isinstance(output, torch.Tensor))
648-
self.assertTrue(isinstance(h_module, module))
654+
# dynamo produces tuple output
655+
# self.assertTrue(isinstance(output, torch.Tensor), f"{type(output)}")
656+
657+
# dynamo returns a GraphModule
658+
# self.assertTrue(isinstance(h_module, module), f"{type(h_module)}")
659+
649660
self.assertEqual(input[0], torch.ones(5, 5))
650-
self.assertEqual(output, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e)))
661+
662+
# dynamo output is an empty tuple!
663+
self.assertEqual(output, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e)), f"{output}")
651664
counter['forwards'] += inc
652665

653666
def bw_hook(inc, h_module, grad_input, grad_output):

0 commit comments

Comments
 (0)