Commit b7524b0

tugsbayasgalan authored and pytorchmergebot committed
Make test_export training IR compatible (#138517)
In this PR, I make test_export compatible with the training IR. The idea is that when we flip the IR to the non-functional training IR, all of these tests should stay green. The changes involve reading through each test case and adding the necessary decompositions, etc., to make the tests pass. For example, if a test expects to see mutated buffers returned, we need to obtain them by running run_decompositions.

Differential Revision: [D64732360](https://our.internmc.facebook.com/intern/diff/D64732360)

Pull Request resolved: #138517
Approved by: https://github.com/avikchaudhuri
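To make the pattern concrete, here is a minimal sketch of the two call shapes this PR applies throughout the test file (the `Toy` module and its inputs are illustrative, not taken from the test suite): `export_for_training` yields the non-functional training IR, while chaining `.run_decompositions({})` lowers to the functionalized ATen IR that the old assertions were written against, e.g. surfacing mutated buffers in `graph_signature.buffers_to_mutate`.

```python
import torch
from torch.export import export, export_for_training


class Toy(torch.nn.Module):  # illustrative module, not from the test file
    def forward(self, x):
        return x.float()


example = (torch.tensor(1, dtype=torch.float),)

# Training IR: mutations/views may still be present in the graph, so buffer
# mutations are not necessarily reflected in the graph signature yet.
ep_train = export_for_training(Toy(), example)

# An empty decomp table functionalizes the program without decomposing
# further; this is the form the updated tests assert against.
ep_aten = export(Toy(), example).run_decompositions({})
print(ep_aten.graph)
```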
1 parent 904816d · commit b7524b0

1 file changed

test/export/test_export.py (77 additions & 144 deletions)
@@ -33,7 +33,13 @@
 from torch._higher_order_ops.hints_wrap import hints_wrapper
 from torch._inductor.compile_fx import split_const_gm
 from torch._subclasses import FakeTensorMode
-from torch.export import default_decompositions, Dim, export, unflatten
+from torch.export import (
+    default_decompositions,
+    Dim,
+    export,
+    export_for_training,
+    unflatten,
+)
 from torch.export._trace import (
     _export,
     _export_to_torch_ir,
@@ -931,7 +937,7 @@ def forward(self, x):
                 # z = 4
                 return x + y + z + w2
 
-        ep = export(M(), (torch.randn(2, 3),), strict=False)
+        ep = export(M(), (torch.randn(2, 3),), strict=False).run_decompositions({})
         self.assertEqual(list(ep.graph_signature.buffers_to_mutate.values()), ["buf"])
         self.assertTrue(
             torch.allclose(ep.module()(torch.ones(2, 3) + 1), torch.ones(2, 3) * 12)
@@ -980,7 +986,7 @@ def forward(self, x):
                 # z = 3 + 3
                 return x + y + z
 
-        ep = export(M(), (torch.randn(2, 3),), strict=False)
+        ep = export(M(), (torch.randn(2, 3),), strict=False).run_decompositions({})
         self.assertEqual(
             list(ep.graph_signature.buffers_to_mutate.values()), ["buf_0", "buf_1"]
         )
@@ -1275,45 +1281,14 @@ def forward(self, x):
         )
         check_users_for_graph(ep.graph)
 
-    @unittest.skipIf(IS_FBCODE, "Broken in fbcode")
-    def test_export_predispatch_custom_ops_warnings(self):
-        @torch.library.custom_op("mylib::foo", mutates_args={})
-        def foo(x: torch.Tensor) -> torch.Tensor:
-            return x.sin()
-
-        @foo.register_fake
-        def _(x):
-            return torch.empty_like(x)
-
-        class Foo(torch.nn.Module):
-            def forward(self, x):
-                return foo(x)
-
-        x = torch.randn(3)
-
-        # Assert no warnings
-        with warnings.catch_warnings():
-            warnings.simplefilter("error")
-            torch.export.export(Foo(), (x,))
-
+    def test_export_custom_op_lib(self):
         ops_registered_before = set(torch.ops.mylib)
 
         # Assert warning for CompositeImplictAutograd op
         with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
             lib.define("foo123(Tensor x) -> Tensor")
             lib.impl("foo123", lambda x: x.sin(), "CompositeImplicitAutograd")
 
-            class Bar(torch.nn.Module):
-                def forward(self, x):
-                    return torch.ops.mylib.foo123(x)
-
-            with self.assertWarnsRegex(
-                UserWarning, "CompositeImplicitAutograd and have functional schema"
-            ):
-                with warnings.catch_warnings():
-                    warnings.simplefilter("always")
-                    torch.export.export(Bar(), (x,))
-
         ops_registered_after = set(torch.ops.mylib)
         self.assertEqual(ops_registered_after, ops_registered_before)
 
@@ -3006,7 +2981,7 @@ def forward(self, x):
                 return x.cos() + y.cos()
 
         foo = Module()
-        gm = export(foo, (torch.tensor([2, 3, 5]),))
+        gm = export(foo, (torch.tensor([2, 3, 5]),)).run_decompositions({})
 
         view_count = 0
         for node in gm.graph.nodes:
@@ -4059,7 +4034,7 @@ class Module(torch.nn.Module):
             def forward(self, x):
                 return x.to("cpu")
 
-        ep = export(Module(), (torch.tensor(1, device="cpu"),))
+        ep = export(Module(), (torch.tensor(1, device="cpu"),)).run_decompositions({})
         ops = []
         for node in ep.graph.nodes:
             if node.op == "call_function":
@@ -4077,7 +4052,7 @@ def forward(self, x):
             Module(),
             (torch.tensor([1, 2], device="cpu"),),
             dynamic_shapes={"x": {0: Dim("i")}},
-        )
+        ).run_decompositions({})
         ops = []
         for node in ep.graph.nodes:
             if node.op == "call_function":
@@ -4096,14 +4071,16 @@ def forward(self, x):
         with self.assertRaisesRegex(
             RuntimeError, "cannot mutate tensors with frozen storage"
         ):
-            export(Module(), (torch.tensor(1, device="cpu"),))
+            export(Module(), (torch.tensor(1, device="cpu"),)).run_decompositions({})
 
     def test_float_conversion(self):
         class Module(torch.nn.Module):
             def forward(self, x):
                 return x.float()
 
-        ep = export(Module(), (torch.tensor(1, dtype=torch.float),))
+        ep = export(Module(), (torch.tensor(1, dtype=torch.float),)).run_decompositions(
+            {}
+        )
         ops = []
         for node in ep.graph.nodes:
             if node.op == "call_function":
@@ -4122,7 +4099,9 @@ def forward(self, x):
         with self.assertRaisesRegex(
             RuntimeError, "cannot mutate tensors with frozen storage"
         ):
-            export(Module(), (torch.tensor(1, dtype=torch.float),))
+            export(Module(), (torch.tensor(1, dtype=torch.float),)).run_decompositions(
+                {}
+            )
 
     def test_module(self):
         class MyLinear(torch.nn.Module):
@@ -4260,7 +4239,6 @@ def forward(self, x):
             "torch.ops.aten._assert_async.msg", 1, exactly=True
         ).run(ep.graph_module.code)
 
-    @testing.expectedFailureRetraceabilityNonStrict
     def test_decomp_item_in_prim_after_decomposition(self):
         class M(torch.nn.Module):
             def forward(self, x):
@@ -4269,71 +4247,18 @@ def forward(self, x):
 
         decomp_table = {**_decomp_table_to_post_autograd_aten(), **decomposition_table}
 
-        ep = export(M(), (torch.randn(2, 2),)).run_decompositions(decomp_table)
-
-        # The difference seems fine because export_for_training catches const tensor little differently.
-        # Training IR produces:
-        # graph():
-        #     %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
-        #     %x : [num_users=1] = placeholder[target=x]
-        #     %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
-        #     %detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,), kwargs = {})
-        #     %_assert_async : [num_users=0] = call_function[target=torch.ops.aten._assert_async.msg](args = (%detach_, Fail), kwargs = {})
-        #     return (x,)
-        #
-        # Pre-dispatch functionalization produces:
-        # graph():
-        #     %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
-        #     %x : [num_users=1] = placeholder[target=x]
-        #     %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
-        #     %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
-        #     %_assert_async : [num_users=0] = call_function[target=torch.ops.aten._assert_async.msg](args = (%detach, Fail), kwargs = {})
-        #     return (x,)
-        #
-        # Retracing:
-        # graph():
-        #     %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
-        #     %x : [num_users=1] = placeholder[target=x]
-        #     %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {})
-        #     %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
-        #     %_assert_async : [num_users=0] = call_function[target=torch.ops.aten._assert_async.msg](args = (%detach, Fail), kwargs = {})
-        #     return (x,)
-        # The difference comes from the fact that prim has registration for aten.detach while not for aten.detach_.
-        # The diference in retracing comes from the fact that we retrace at pre-dispatch level while the usual flow
-        # traces to post-dispatch.
-        if is_training_ir_test(self._testMethodName):
-            self.assertExpectedInline(
-                str(ep.graph_module.code).strip(),
-                """\
+        ep = export_for_training(M(), (torch.randn(2, 2),)).run_decompositions(
+            decomp_table
+        )
+
+        self.assertExpectedInline(
+            str(ep.graph_module.code).strip(),
+            """\
 def forward(self, c_lifted_tensor_0, x):
     lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
     _assert_async = torch.ops.aten._assert_async.msg(lift_fresh_copy, 'Fail'); lift_fresh_copy = _assert_async = None
     return (x,)""",
-            )
-        elif is_retracebility_test(self._testMethodName):
-            self.assertExpectedInline(
-                str(ep.graph_module.code).strip(),
-                """\
-def forward(self, c_lifted_tensor_0, x):
-    clone = torch.ops.prims.clone.default(c_lifted_tensor_0, memory_format = torch.preserve_format); c_lifted_tensor_0 = None
-    view_of = torch.ops.prims.view_of.default(clone); clone = None
-    view_of_1 = torch.ops.prims.view_of.default(view_of); view_of = None
-    view_of_2 = torch.ops.prims.view_of.default(view_of_1); view_of_1 = None
-    _assert_async = torch.ops.aten._assert_async.msg(view_of_2, 'Fail'); view_of_2 = _assert_async = None
-    return (x,)""",
-            )
-        else:
-            self.assertExpectedInline(
-                str(ep.graph_module.code).strip(),
-                """\
-def forward(self, c_lifted_tensor_0, x):
-    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
-    view_of = torch.ops.prims.view_of.default(lift_fresh_copy); lift_fresh_copy = None
-    view_of_1 = torch.ops.prims.view_of.default(view_of); view_of = None
-    view_of_2 = torch.ops.prims.view_of.default(view_of_1); view_of_1 = None
-    _assert_async = torch.ops.aten._assert_async.msg(view_of_2, 'Fail'); view_of_2 = _assert_async = None
-    return (x,)""",
-            )
+        )
 
     def test_decomp_batch_norm_functional_predispatch(self):
         class ConvBatchnorm(torch.nn.Module):
@@ -5034,7 +4959,7 @@ class M(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.lift_fresh_copy(x)
 
-        ep = export(M(), (torch.ones(6, 4),))
+        ep = export(M(), (torch.ones(6, 4),)).run_decompositions({})
         found = False
 
         op = "torch.ops.aten.clone.default"
@@ -5650,7 +5575,6 @@ def forward(self, x):
         unflattened = unflatten(ep)
         self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
 
-    @testing.expectedFailureRetraceability  # Retracing tensor constants results in buffers
     def test_nested_module_with_constant_buffer(self):
         class M1(torch.nn.Module):
             def __init__(self) -> None:
@@ -5666,37 +5590,22 @@ def forward(self, x):
                 return m(x) * x
 
         inps = (torch.randn(3, 3),)
-        ep = export(M2(), inps)
+        ep = export_for_training(M2(), inps).run_decompositions({})
         self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
 
         self.assertEqual(len(ep.state_dict), 0)
         self.assertEqual(len(ep.constants), 1)
-
-        if is_training_ir_test(self._testMethodName):
-            self.assertExpectedInline(
-                str(ep.graph).strip(),
-                """\
+        self.assertExpectedInline(
+            str(ep.graph).strip(),
+            """\
 graph():
     %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
     %x : [num_users=2] = placeholder[target=x]
     %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lift_fresh_copy), kwargs = {})
     %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
     return (mul,)""",
-            )
-        else:
-            self.assertExpectedInline(
-                str(ep.graph).strip(),
-                """\
-graph():
-    %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
-    %x : [num_users=2] = placeholder[target=x]
-    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
-    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
-    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %detach), kwargs = {})
-    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
-    return (mul,)""",
-            )
+        )
 
         unflattened = unflatten(ep)
         self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
@@ -5718,7 +5627,7 @@ def forward(self, x):
 
         inps = (torch.randn(3, 3),)
         # Strict export segfaults (Issue #128109)
-        ep = torch.export.export(M2(), inps, strict=False)
+        ep = export_for_training(M2(), inps, strict=False).run_decompositions({})
         self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
 
         self.assertEqual(len(ep.state_dict), 0)
@@ -5732,10 +5641,13 @@ def forward(self, x):
     %x : [num_users=2] = placeholder[target=x]
     %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False})
     %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%ones,), kwargs = {})
-    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
-    %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
+    %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach,), kwargs = {})
     %detach_2 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_1,), kwargs = {})
-    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach, %detach_2), kwargs = {})
+    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
+    %detach_3 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
+    %detach_4 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_3,), kwargs = {})
+    %detach_5 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_4,), kwargs = {})
+    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach_2, %detach_5), kwargs = {})
     %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {})
     %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
     return (mul_1,)""",
@@ -6836,7 +6748,7 @@ def forward(self, x):
         ep = export(
             Foo(),
             (torch.randn(4, 4),),
-        )
+        ).run_decompositions({})
         # check correct lines are in stack trace
         trace_mul = [node for node in ep.graph.nodes if node.name == "mul"][0].meta.get(
             "stack_trace", ""
@@ -7387,19 +7299,7 @@ def forward(self, x):
 
         inps = (torch.ones(5),)
 
-        ep = torch.export.export(M(), inps)
-        self.assertExpectedInline(
-            str(ep.graph_module.code.strip()),
-            """\
-def forward(self, x):
-    cos = torch.ops.aten.cos.default(x)
-    auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = x, z = cos); x = cos = None
-    getitem_3 = auto_functionalized[3]; auto_functionalized = None
-    cos_1 = torch.ops.aten.cos.default(getitem_3)
-    return (getitem_3, getitem_3, cos_1)""",
-        )
-
-        ep = torch.export._trace._export(M(), inps, pre_dispatch=True)
+        ep = export_for_training(M(), inps).run_decompositions({})
         self.assertExpectedInline(
             str(ep.graph_module.code.strip()),
             """\
@@ -8824,7 +8724,40 @@ def outer_body_fn(x, y):
         x = torch.randn(2, 4)
         y = torch.ones(4)
 
-        ep = export(M(), (x, y))
+        ep_for_training = torch.export.export_for_training(M(), (x, y))
+        self.assertExpectedInline(
+            normalize_gm(
+                ep_for_training.graph_module.print_readable(print_output=False)
+            ),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, x: "f32[2, 4]", y: "f32[4]"):
+        add: "f32[2, 4]" = torch.ops.aten.add.Tensor(x, y); x = None
+
+        hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
+        hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (add, y), {}, hints = {'outer_body': True}); hints_wrapper_body_graph_0 = add = y = None
+        getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
+        return (getitem,)
+
+    class hints_wrapper_body_graph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
+            hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
+            hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (arg0_1, arg1_1), {}, hints = {'inner_body': True}); hints_wrapper_body_graph_0 = arg0_1 = arg1_1 = None
+            getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
+
+            abs_1: "f32[2, 4]" = torch.ops.aten.abs.default(getitem); getitem = None
+            return (abs_1,)
+
+        class hints_wrapper_body_graph_0(torch.nn.Module):
+            def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
+                relu: "f32[2, 4]" = torch.ops.aten.relu.default(arg0_1); arg0_1 = None
+
+                add: "f32[2, 4]" = torch.ops.aten.add.Tensor(relu, arg1_1); relu = arg1_1 = None
+                return (add,)
+""",
+        )
+
+        ep = export(M(), (x, y)).run_decompositions({})
         export_res = ep.module()(x, y)
         ref_res = M()(x, y)
         self.assertEqual(export_res, ref_res)
