 from torch._higher_order_ops.hints_wrap import hints_wrapper
 from torch._inductor.compile_fx import split_const_gm
 from torch._subclasses import FakeTensorMode
-from torch.export import default_decompositions, Dim, export, unflatten
+from torch.export import (
+    default_decompositions,
+    Dim,
+    export,
+    export_for_training,
+    unflatten,
+)
 from torch.export._trace import (
     _export,
     _export_to_torch_ir,
@@ -931,7 +937,7 @@ def forward(self, x):
                 # z = 4
                 return x + y + z + w2

-        ep = export(M(), (torch.randn(2, 3),), strict=False)
+        ep = export(M(), (torch.randn(2, 3),), strict=False).run_decompositions({})
         self.assertEqual(list(ep.graph_signature.buffers_to_mutate.values()), ["buf"])
         self.assertTrue(
             torch.allclose(ep.module()(torch.ones(2, 3) + 1), torch.ones(2, 3) * 12)
@@ -980,7 +986,7 @@ def forward(self, x):
                 # z = 3 + 3
                 return x + y + z

-        ep = export(M(), (torch.randn(2, 3),), strict=False)
+        ep = export(M(), (torch.randn(2, 3),), strict=False).run_decompositions({})
         self.assertEqual(
             list(ep.graph_signature.buffers_to_mutate.values()), ["buf_0", "buf_1"]
         )
@@ -1275,45 +1281,14 @@ def forward(self, x):
         )
         check_users_for_graph(ep.graph)

-    @unittest.skipIf(IS_FBCODE, "Broken in fbcode")
-    def test_export_predispatch_custom_ops_warnings(self):
-        @torch.library.custom_op("mylib::foo", mutates_args={})
-        def foo(x: torch.Tensor) -> torch.Tensor:
-            return x.sin()
-
-        @foo.register_fake
-        def _(x):
-            return torch.empty_like(x)
-
-        class Foo(torch.nn.Module):
-            def forward(self, x):
-                return foo(x)
-
-        x = torch.randn(3)
-
-        # Assert no warnings
-        with warnings.catch_warnings():
-            warnings.simplefilter("error")
-            torch.export.export(Foo(), (x,))
-
+    def test_export_custom_op_lib(self):
         ops_registered_before = set(torch.ops.mylib)

         # Assert warning for CompositeImplictAutograd op
         with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
             lib.define("foo123(Tensor x) -> Tensor")
             lib.impl("foo123", lambda x: x.sin(), "CompositeImplicitAutograd")

-            class Bar(torch.nn.Module):
-                def forward(self, x):
-                    return torch.ops.mylib.foo123(x)
-
-            with self.assertWarnsRegex(
-                UserWarning, "CompositeImplicitAutograd and have functional schema"
-            ):
-                with warnings.catch_warnings():
-                    warnings.simplefilter("always")
-                    torch.export.export(Bar(), (x,))
-
         ops_registered_after = set(torch.ops.mylib)
         self.assertEqual(ops_registered_after, ops_registered_before)

@@ -3006,7 +2981,7 @@ def forward(self, x):
                 return x.cos() + y.cos()

         foo = Module()
-        gm = export(foo, (torch.tensor([2, 3, 5]),))
+        gm = export(foo, (torch.tensor([2, 3, 5]),)).run_decompositions({})

         view_count = 0
         for node in gm.graph.nodes:
@@ -4059,7 +4034,7 @@ class Module(torch.nn.Module):
             def forward(self, x):
                 return x.to("cpu")

-        ep = export(Module(), (torch.tensor(1, device="cpu"),))
+        ep = export(Module(), (torch.tensor(1, device="cpu"),)).run_decompositions({})
         ops = []
         for node in ep.graph.nodes:
             if node.op == "call_function":
@@ -4077,7 +4052,7 @@ def forward(self, x):
             Module(),
             (torch.tensor([1, 2], device="cpu"),),
             dynamic_shapes={"x": {0: Dim("i")}},
-        )
+        ).run_decompositions({})
         ops = []
         for node in ep.graph.nodes:
             if node.op == "call_function":
@@ -4096,14 +4071,16 @@ def forward(self, x):
         with self.assertRaisesRegex(
             RuntimeError, "cannot mutate tensors with frozen storage"
         ):
-            export(Module(), (torch.tensor(1, device="cpu"),))
+            export(Module(), (torch.tensor(1, device="cpu"),)).run_decompositions({})

     def test_float_conversion(self):
         class Module(torch.nn.Module):
             def forward(self, x):
                 return x.float()

-        ep = export(Module(), (torch.tensor(1, dtype=torch.float),))
+        ep = export(Module(), (torch.tensor(1, dtype=torch.float),)).run_decompositions(
+            {}
+        )
         ops = []
         for node in ep.graph.nodes:
             if node.op == "call_function":
@@ -4122,7 +4099,9 @@ def forward(self, x):
         with self.assertRaisesRegex(
             RuntimeError, "cannot mutate tensors with frozen storage"
         ):
-            export(Module(), (torch.tensor(1, dtype=torch.float),))
+            export(Module(), (torch.tensor(1, dtype=torch.float),)).run_decompositions(
+                {}
+            )

     def test_module(self):
         class MyLinear(torch.nn.Module):
@@ -4260,7 +4239,6 @@ def forward(self, x):
42604239 "torch.ops.aten._assert_async.msg" , 1 , exactly = True
42614240 ).run (ep .graph_module .code )
42624241
4263- @testing .expectedFailureRetraceabilityNonStrict
42644242 def test_decomp_item_in_prim_after_decomposition (self ):
42654243 class M (torch .nn .Module ):
42664244 def forward (self , x ):
@@ -4269,71 +4247,18 @@ def forward(self, x):

         decomp_table = {**_decomp_table_to_post_autograd_aten(), **decomposition_table}

-        ep = export(M(), (torch.randn(2, 2),)).run_decompositions(decomp_table)
-
-        # The difference seems fine because export_for_training catches const tensor little differently.
-        # Training IR produces:
-        # graph():
-        # %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
-        # %x : [num_users=1] = placeholder[target=x]
-        # %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
-        # %detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,), kwargs = {})
-        # %_assert_async : [num_users=0] = call_function[target=torch.ops.aten._assert_async.msg](args = (%detach_, Fail), kwargs = {})
-        # return (x,)
-        #
-        # Pre-dispatch functionalization produces:
-        # graph():
-        # %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
-        # %x : [num_users=1] = placeholder[target=x]
-        # %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
-        # %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
-        # %_assert_async : [num_users=0] = call_function[target=torch.ops.aten._assert_async.msg](args = (%detach, Fail), kwargs = {})
-        # return (x,)
-        #
-        # Retracing:
-        # graph():
-        # %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
-        # %x : [num_users=1] = placeholder[target=x]
-        # %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {})
-        # %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
-        # %_assert_async : [num_users=0] = call_function[target=torch.ops.aten._assert_async.msg](args = (%detach, Fail), kwargs = {})
-        # return (x,)
-        # The difference comes from the fact that prim has registration for aten.detach while not for aten.detach_.
-        # The diference in retracing comes from the fact that we retrace at pre-dispatch level while the usual flow
-        # traces to post-dispatch.
-        if is_training_ir_test(self._testMethodName):
-            self.assertExpectedInline(
-                str(ep.graph_module.code).strip(),
-                """\
+        ep = export_for_training(M(), (torch.randn(2, 2),)).run_decompositions(
+            decomp_table
+        )
+
+        self.assertExpectedInline(
+            str(ep.graph_module.code).strip(),
+            """\
 def forward(self, c_lifted_tensor_0, x):
     lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
     _assert_async = torch.ops.aten._assert_async.msg(lift_fresh_copy, 'Fail'); lift_fresh_copy = _assert_async = None
     return (x,)""",
-            )
-        elif is_retracebility_test(self._testMethodName):
-            self.assertExpectedInline(
-                str(ep.graph_module.code).strip(),
-                """\
-def forward(self, c_lifted_tensor_0, x):
-    clone = torch.ops.prims.clone.default(c_lifted_tensor_0, memory_format = torch.preserve_format); c_lifted_tensor_0 = None
-    view_of = torch.ops.prims.view_of.default(clone); clone = None
-    view_of_1 = torch.ops.prims.view_of.default(view_of); view_of = None
-    view_of_2 = torch.ops.prims.view_of.default(view_of_1); view_of_1 = None
-    _assert_async = torch.ops.aten._assert_async.msg(view_of_2, 'Fail'); view_of_2 = _assert_async = None
-    return (x,)""",
-            )
-        else:
-            self.assertExpectedInline(
-                str(ep.graph_module.code).strip(),
-                """\
-def forward(self, c_lifted_tensor_0, x):
-    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
-    view_of = torch.ops.prims.view_of.default(lift_fresh_copy); lift_fresh_copy = None
-    view_of_1 = torch.ops.prims.view_of.default(view_of); view_of = None
-    view_of_2 = torch.ops.prims.view_of.default(view_of_1); view_of_1 = None
-    _assert_async = torch.ops.aten._assert_async.msg(view_of_2, 'Fail'); view_of_2 = _assert_async = None
-    return (x,)""",
-            )
+        )

     def test_decomp_batch_norm_functional_predispatch(self):
         class ConvBatchnorm(torch.nn.Module):
@@ -5034,7 +4959,7 @@ class M(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.lift_fresh_copy(x)

-        ep = export(M(), (torch.ones(6, 4),))
+        ep = export(M(), (torch.ones(6, 4),)).run_decompositions({})
         found = False

         op = "torch.ops.aten.clone.default"
@@ -5650,7 +5575,6 @@ def forward(self, x):
         unflattened = unflatten(ep)
         self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))

-    @testing.expectedFailureRetraceability  # Retracing tensor constants results in buffers
     def test_nested_module_with_constant_buffer(self):
         class M1(torch.nn.Module):
             def __init__(self) -> None:
@@ -5666,37 +5590,22 @@ def forward(self, x):
                 return m(x) * x

         inps = (torch.randn(3, 3),)
-        ep = export(M2(), inps)
+        ep = export_for_training(M2(), inps).run_decompositions({})
         self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))

         self.assertEqual(len(ep.state_dict), 0)
         self.assertEqual(len(ep.constants), 1)
-
-        if is_training_ir_test(self._testMethodName):
-            self.assertExpectedInline(
-                str(ep.graph).strip(),
-                """\
+        self.assertExpectedInline(
+            str(ep.graph).strip(),
+            """\
 graph():
     %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
     %x : [num_users=2] = placeholder[target=x]
     %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
     %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lift_fresh_copy), kwargs = {})
     %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
     return (mul,)""",
-            )
-        else:
-            self.assertExpectedInline(
-                str(ep.graph).strip(),
-                """\
-graph():
-    %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
-    %x : [num_users=2] = placeholder[target=x]
-    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
-    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
-    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %detach), kwargs = {})
-    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
-    return (mul,)""",
-            )
+        )

         unflattened = unflatten(ep)
         self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
@@ -5718,7 +5627,7 @@ def forward(self, x):

         inps = (torch.randn(3, 3),)
         # Strict export segfaults (Issue #128109)
-        ep = torch.export.export(M2(), inps, strict=False)
+        ep = export_for_training(M2(), inps, strict=False).run_decompositions({})
         self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))

         self.assertEqual(len(ep.state_dict), 0)
@@ -5732,10 +5641,13 @@ def forward(self, x):
     %x : [num_users=2] = placeholder[target=x]
     %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False})
     %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%ones,), kwargs = {})
-    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
-    %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
+    %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach,), kwargs = {})
     %detach_2 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_1,), kwargs = {})
-    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach, %detach_2), kwargs = {})
+    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
+    %detach_3 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
+    %detach_4 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_3,), kwargs = {})
+    %detach_5 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_4,), kwargs = {})
+    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach_2, %detach_5), kwargs = {})
     %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {})
     %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
     return (mul_1,)""",
@@ -6836,7 +6748,7 @@ def forward(self, x):
         ep = export(
             Foo(),
             (torch.randn(4, 4),),
-        )
+        ).run_decompositions({})
         # check correct lines are in stack trace
         trace_mul = [node for node in ep.graph.nodes if node.name == "mul"][0].meta.get(
             "stack_trace", ""
@@ -7387,19 +7299,7 @@ def forward(self, x):
         inps = (torch.ones(5),)

-        ep = torch.export.export(M(), inps)
-        self.assertExpectedInline(
-            str(ep.graph_module.code.strip()),
-            """\
-def forward(self, x):
-    cos = torch.ops.aten.cos.default(x)
-    auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = x, z = cos); x = cos = None
-    getitem_3 = auto_functionalized[3]; auto_functionalized = None
-    cos_1 = torch.ops.aten.cos.default(getitem_3)
-    return (getitem_3, getitem_3, cos_1)""",
-        )
-
-        ep = torch.export._trace._export(M(), inps, pre_dispatch=True)
+        ep = export_for_training(M(), inps).run_decompositions({})
         self.assertExpectedInline(
             str(ep.graph_module.code.strip()),
             """\
@@ -8824,7 +8724,40 @@ def outer_body_fn(x, y):
         x = torch.randn(2, 4)
         y = torch.ones(4)

-        ep = export(M(), (x, y))
+        ep_for_training = torch.export.export_for_training(M(), (x, y))
+        self.assertExpectedInline(
+            normalize_gm(
+                ep_for_training.graph_module.print_readable(print_output=False)
+            ),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, x: "f32[2, 4]", y: "f32[4]"):
+        add: "f32[2, 4]" = torch.ops.aten.add.Tensor(x, y); x = None
+
+        hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
+        hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (add, y), {}, hints = {'outer_body': True}); hints_wrapper_body_graph_0 = add = y = None
+        getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
+        return (getitem,)
+
+    class hints_wrapper_body_graph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
+            hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
+            hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (arg0_1, arg1_1), {}, hints = {'inner_body': True}); hints_wrapper_body_graph_0 = arg0_1 = arg1_1 = None
+            getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
+
+            abs_1: "f32[2, 4]" = torch.ops.aten.abs.default(getitem); getitem = None
+            return (abs_1,)
+
+        class hints_wrapper_body_graph_0(torch.nn.Module):
+            def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
+                relu: "f32[2, 4]" = torch.ops.aten.relu.default(arg0_1); arg0_1 = None
+
+                add: "f32[2, 4]" = torch.ops.aten.add.Tensor(relu, arg1_1); relu = arg1_1 = None
+                return (add,)
+""",
+        )
+
+        ep = export(M(), (x, y)).run_decompositions({})
         export_res = ep.module()(x, y)
         ref_res = M()(x, y)
         self.assertEqual(export_res, ref_res)