@@ -245,6 +245,7 @@ def verify_aot_autograd(
245245 * ,
246246 test_mutation : bool = False ,
247247 decompositions : Optional [Dict ] = None ,
248+ dynamic : bool = False ,
248249 ):
249250 for keep_input_mutations in [True , False ]:
250251 # Some tests pass in a callable for inp, to generate the inputs
@@ -294,15 +295,17 @@ def verify_aot_autograd(
294295 fw_compiler = partial (extract_graph , graph_cell = fw_graph_cell ),
295296 bw_compiler = nop ,
296297 decompositions = decompositions ,
297- keep_inference_input_mutations = keep_input_mutations
298+ keep_inference_input_mutations = keep_input_mutations ,
299+ dynamic = dynamic
298300 )
299301 else :
300302 compiled_f = aot_function (
301303 f ,
302304 fw_compiler = partial (extract_graph , graph_cell = fw_graph_cell ),
303305 bw_compiler = nop ,
304306 decompositions = decompositions ,
305- keep_inference_input_mutations = keep_input_mutations
307+ keep_inference_input_mutations = keep_input_mutations ,
308+ dynamic = dynamic
306309 )
307310 ref_out , ref_grad = outs_and_grads (f , graph_inps , inp )
308311 test_out , test_grad = outs_and_grads (compiled_f , graph_inps_copy , inp_copy )
@@ -366,21 +369,17 @@ def f(a, b):
366369 self .verify_aot_autograd (f , inp )
367370
368371 # Test for bug occurring at the intersection of fake tensors & functionalization.
369- @patch ("torch._functorch.config.use_dynamic_shapes" , True )
370- @patch ("torch._functorch.config.use_fake_tensor" , True )
371372 def test_squeeze_mutation (self ):
372373 def f (a ):
373374 b = a .clone ().squeeze (- 1 )
374375 b .add_ (1. )
375376 return a + b
376377
377378 inp = [torch .randn (3 , 1 , requires_grad = True )]
378- self .verify_aot_autograd (f , inp )
379+ self .verify_aot_autograd (f , inp , dynamic = True )
379380 inp = [torch .randn (3 , 1 , requires_grad = False )]
380- self .verify_aot_autograd (f , inp )
381+ self .verify_aot_autograd (f , inp , dynamic = True )
381382
382- @patch ("torch._functorch.config.use_dynamic_shapes" , True )
383- @patch ("torch._functorch.config.use_fake_tensor" , True )
384383 def test_embedding_bag_view (self ):
385384 # Backwards pass tries to wrap a sparse tensor in a FunctionalTensorWrapper;
386385 # test that this works even though the sparse tensor has no storage.
@@ -395,7 +394,7 @@ def forward(self, x, y):
395394
396395 x = torch .arange (3 )
397396 y = torch .arange (3 )
398- self .verify_aot_autograd (F (), [x , y ])
397+ self .verify_aot_autograd (F (), [x , y ], dynamic = True )
399398
400399 @patch ("functorch.compile.config.use_fake_tensor" , True )
401400 def test_input_mutation_simple (self ):
@@ -1714,8 +1713,6 @@ def bn(x):
17141713 for a , b in zip (ref , res ):
17151714 assert torch .allclose (a , b )
17161715
1717- @patch ("functorch.compile.config.use_dynamic_shapes" , True )
1718- @patch ("functorch.compile.config.use_fake_tensor" , True )
17191716 def test_output_op_depending_on_symint (self ):
17201717 """
17211718 It won't be obvious from reading this test what it's testing for. We should probably make it into a more
@@ -1738,12 +1735,10 @@ def f(x):
17381735 # TODO: assert outputs of fwd graph trace to correct symint
17391736
17401737 # e2e test that fails without symint clone fix
1741- af = aot_function (f , nop , partition_fn = partial (min_cut_rematerialization_partition , compiler = "inductor" ))
1738+ af = aot_function (f , nop , partition_fn = partial (min_cut_rematerialization_partition , compiler = "inductor" ), dynamic = True )
17421739 out = af (inp )
17431740 self .assertEqual (out , f (inp ))
17441741
1745- @patch ("functorch.compile.config.use_dynamic_shapes" , True )
1746- @patch ("functorch.compile.config.use_fake_tensor" , True )
17471742 def test_default_partitioner_saves_symints_not_tensors_for_bw (self ):
17481743 """
17491744 In this test, the important thing is that primals_1 is **only** needed in the backward
@@ -1764,7 +1759,7 @@ def f(a):
17641759 d = b .masked_fill_ (c , 0 )
17651760 return d
17661761
1767- compiled_f = aot_function (f , nop )
1762+ compiled_f = aot_function (f , nop , dynamic = True )
17681763 inp_ref = torch .ones (2 , 2 , requires_grad = True )
17691764 inp_test = torch .ones (2 , 2 , requires_grad = True )
17701765
@@ -1859,14 +1854,15 @@ def get_num_ins_outs(fx_g):
18591854 return tuple (len (i ) for i in get_ins_outs (fx_g ))
18601855
18611856
1862- def get_fw_bw_graph (f , inps , partitioner = min_cut_rematerialization_partition ):
1857+ def get_fw_bw_graph (f , inps , partitioner = min_cut_rematerialization_partition , dynamic = False ):
18631858 fw_graph_cell = [None ]
18641859 bw_graph_cell = [None ]
18651860 aot_function (f ,
18661861 fw_compiler = partial (extract_graph , graph_cell = fw_graph_cell ),
18671862 bw_compiler = partial (extract_graph , graph_cell = bw_graph_cell ),
18681863 partition_fn = partitioner ,
1869- decompositions = default_decompositions )(* inps ).sum ().backward ()
1864+ decompositions = default_decompositions ,
1865+ dynamic = dynamic )(* inps ).sum ().backward ()
18701866 return (fw_graph_cell [0 ], bw_graph_cell [0 ])
18711867
18721868
@@ -1933,8 +1929,6 @@ def f(x, mod_weight, mod_bias):
19331929 self .assertEqual (get_num_ins_outs (fw_graph ), (3 , 6 ))
19341930 self .assertEqual (get_num_ins_outs (bw_graph ), (6 , 3 ))
19351931
1936- @patch ("functorch.compile.config.use_dynamic_shapes" , True )
1937- @patch ("functorch.compile.config.use_fake_tensor" , True )
19381932 @unittest .skipIf (not USE_NETWORKX , "networkx not available" )
19391933 def test_min_cut_partitioner_save_shape (self ):
19401934
@@ -1943,7 +1937,7 @@ def f(x):
19431937 return s
19441938
19451939 inp = [torch .ones ([10 , 10 ], requires_grad = True )]
1946- fw_graph , bw_graph = get_fw_bw_graph (f , inp )
1940+ fw_graph , bw_graph = get_fw_bw_graph (f , inp , dynamic = True )
19471941 _ , fw_output = get_ins_outs (fw_graph )
19481942 self .assertEqual (get_num_ins_outs (fw_graph ), (1 , 3 ))
19491943 self .assertEqual (get_num_ins_outs (bw_graph ), (3 , 1 ))
@@ -1968,14 +1962,12 @@ def f(a, b, c):
19681962 x = sb [0 ] + sc [0 ]
19691963 a_sz = (x , a .size (0 ))
19701964 return torch .cat ([a .expand (a_sz ), b , c ])
1971- fw_graph , bw_graph = get_fw_bw_graph (f , inp )
1965+ fw_graph , bw_graph = get_fw_bw_graph (f , inp , dynamic = True )
19721966 self .assertEqual (get_num_ins_outs (fw_graph ), (3 , 5 ))
19731967 self .assertEqual (get_num_ins_outs (bw_graph ), (5 , 3 ))
19741968 _ , outs = get_ins_outs (fw_graph )
19751969 self .assertTrue (all ([is_sym_node (n ) for n in outs [1 :]]))
19761970
1977- @patch ("functorch.compile.config.use_dynamic_shapes" , True )
1978- @patch ("functorch.compile.config.use_fake_tensor" , True )
19791971 def test_default_partitioner_output_tensor_shape_tensor (self ):
19801972
19811973 inp = [
@@ -2004,7 +1996,8 @@ def f(a, b, c, d):
20041996 fw_compiler = partial (extract_graph , graph_cell = fw_graph_cell ),
20051997 bw_compiler = partial (extract_graph , graph_cell = bw_graph_cell ),
20061998 partition_fn = default_partition ,
2007- decompositions = default_decompositions )(* inp )
1999+ decompositions = default_decompositions ,
2000+ dynamic = True )(* inp )
20082001 fw_graph = fw_graph_cell [0 ]
20092002 (compiled_outs [0 ].sum () + compiled_outs [2 ].sum ()).backward ()
20102003 bw_graph = bw_graph_cell [0 ]
@@ -2037,8 +2030,6 @@ def f(a, b, c, d):
20372030 # TODO(whc) we should learn to return torch.Sizes
20382031 self .assertFalse (isinstance (compiled_outs [1 ], torch .Size ))
20392032
2040- @patch ("functorch.compile.config.use_dynamic_shapes" , True )
2041- @patch ("functorch.compile.config.use_fake_tensor" , True )
20422033 @unittest .skipIf (not USE_NETWORKX , "networkx not available" )
20432034 def test_min_cut_partitioner_output_tensor_shape_tensor (self ):
20442035
@@ -2068,7 +2059,8 @@ def f(a, b, c, d):
20682059 fw_compiler = partial (extract_graph , graph_cell = fw_graph_cell ),
20692060 bw_compiler = partial (extract_graph , graph_cell = bw_graph_cell ),
20702061 partition_fn = min_cut_rematerialization_partition ,
2071- decompositions = default_decompositions )(* inp )
2062+ decompositions = default_decompositions ,
2063+ dynamic = True )(* inp )
20722064 fw_graph = fw_graph_cell [0 ]
20732065 (compiled_outs [0 ].sum () + compiled_outs [2 ].sum ()).backward ()
20742066 bw_graph = bw_graph_cell [0 ]
@@ -2617,7 +2609,7 @@ def create_new_arg(x):
26172609 except DynamicOutputShapeException :
26182610 self .skipTest ("Dynamic output shape operation in trace" )
26192611
2620- def _test_aot_autograd_helper (self , device , dtype , op ):
2612+ def _test_aot_autograd_helper (self , device , dtype , op , dynamic = False ):
26212613 if not op .supports_autograd :
26222614 self .skipTest ("Op does not support autograd" )
26232615
@@ -2639,7 +2631,7 @@ def f(args):
26392631 c_args , c_kwargs = pytree .tree_unflatten (cur_flat_args , args_spec )
26402632 return op .op (* c_args , ** c_kwargs )
26412633
2642- compiled_f = compiled_function (f , nop , nop )
2634+ compiled_f = compiled_function (f , nop , nop , dynamic = dynamic )
26432635 try :
26442636 _test_aot_autograd_forwards_backwards_helper (self , f , compiled_f , args )
26452637 except GuardOnDataDependentSymNode :
@@ -2651,7 +2643,7 @@ def f(args):
26512643 else :
26522644 raise
26532645
2654- def _test_aot_autograd_module_helper (self , device , dtype , training , module_info ):
2646+ def _test_aot_autograd_module_helper (self , device , dtype , training , module_info , * , dynamic = False ):
26552647 module_cls = module_info .module_cls
26562648 module_inputs = module_info .module_inputs_func (module_info , device = device , dtype = dtype ,
26572649 requires_grad = True , training = training )
@@ -2696,7 +2688,7 @@ def f(params_buffers_args):
26962688 named_params = dict (m .named_parameters (remove_duplicate = False ))
26972689 named_buffers = dict (m .named_buffers (remove_duplicate = False ))
26982690 num_params_buffers = len (named_params ) + len (named_buffers )
2699- compiled_f = aot_function (f , nop , num_params_buffers = num_params_buffers )
2691+ compiled_f = aot_function (f , nop , num_params_buffers = num_params_buffers , dynamic = dynamic )
27002692 params_buffers_args = [named_params , named_buffers , args ]
27012693 _test_aot_autograd_forwards_backwards_helper (self , f , compiled_f , params_buffers_args )
27022694
@@ -2708,13 +2700,11 @@ def test_aot_autograd_exhaustive(self, device, dtype, op):
27082700 _test_aot_autograd_helper (self , device , dtype , op )
27092701
27102702 @ops (op_db , allowed_dtypes = (torch .float ,))
2711- @patch ("functorch.compile.config.use_dynamic_shapes" , True )
2712- @patch ("functorch.compile.config.use_fake_tensor" , True )
27132703 @patch ("functorch.compile.config.use_functionalize" , True )
27142704 @skipOps ('TestEagerFusionOpInfo' , 'test_aot_autograd_symbolic_exhaustive' ,
27152705 aot_autograd_failures | symbolic_aot_autograd_failures )
27162706 def test_aot_autograd_symbolic_exhaustive (self , device , dtype , op ):
2717- _test_aot_autograd_helper (self , device , dtype , op )
2707+ _test_aot_autograd_helper (self , device , dtype , op , dynamic = True )
27182708
27192709
27202710aot_autograd_module_failures = set ({
@@ -2754,13 +2744,11 @@ def test_aot_autograd_module_exhaustive(self, device, dtype, training, module_in
27542744 _test_aot_autograd_module_helper (self , device , dtype , training , module_info )
27552745
27562746 @modules (module_db , allowed_dtypes = (torch .float ,))
2757- @patch ("functorch.compile.config.use_dynamic_shapes" , True )
2758- @patch ("functorch.compile.config.use_fake_tensor" , True )
27592747 @patch ("functorch.compile.config.use_functionalize" , True )
27602748 @decorateForModules (unittest .expectedFailure ,
27612749 aot_autograd_module_failures | symbolic_aot_autograd_module_failures )
27622750 def test_aot_autograd_symbolic_module_exhaustive (self , device , dtype , training , module_info ):
2763- _test_aot_autograd_module_helper (self , device , dtype , training , module_info )
2751+ _test_aot_autograd_module_helper (self , device , dtype , training , module_info , dynamic = True )
27642752
27652753
27662754only_for = ("cpu" )
0 commit comments