|
13 | 13 | from math import sqrt |
14 | 14 | from pathlib import Path |
15 | 15 | from torch.multiprocessing import Process |
16 | | -from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph, wrap |
| 16 | +from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap |
| 17 | +from torch.fx.node import Target |
17 | 18 | from torch.fx.experimental import shape_prop |
18 | 19 | from torch.fx.immutable_collections import immutable_dict, immutable_list |
19 | 20 | from copy import deepcopy |
@@ -957,6 +958,146 @@ def forward(self, x): |
957 | 958 | # Test shape propagation and make sure results match actual |
958 | 959 | self.assertEqual(output_shape, ref_out.shape) |
959 | 960 |
|
| 961 | + def test_interpreter(self): |
| 962 | + class MyModule(torch.nn.Module): |
| 963 | + def __init__(self): |
| 964 | + super().__init__() |
| 965 | + self.param = torch.nn.Parameter(torch.rand(3, 4)) |
| 966 | + self.linear = torch.nn.Linear(4, 5) |
| 967 | + |
| 968 | + def forward(self, x): |
| 969 | + return self.linear(x + self.param).clamp(min=0.0, max=1.0) |
| 970 | + |
| 971 | + m = MyModule() |
| 972 | + gm = torch.fx.symbolic_trace(m) |
| 973 | + |
| 974 | + interpreter = Interpreter(gm) |
| 975 | + input = torch.randn(3, 4) |
| 976 | + self.assertEqual(interpreter.run(input), gm(input)) |
| 977 | + self.assertEqual(interpreter.run(input), m(input)) |
| 978 | + |
| 979 | + def test_interpreter_run_node_override(self): |
| 980 | + class MyModule(torch.nn.Module): |
| 981 | + def __init__(self): |
| 982 | + super().__init__() |
| 983 | + self.param = torch.nn.Parameter(torch.rand(3, 4)) |
| 984 | + self.linear = torch.nn.Linear(4, 5) |
| 985 | + |
| 986 | + def forward(self, x): |
| 987 | + return self.linear(x + self.param).clamp(min=0.0, max=1.0) |
| 988 | + |
| 989 | + m = MyModule() |
| 990 | + gm = torch.fx.symbolic_trace(m) |
| 991 | + |
| 992 | + class RunNodeInterpreter(Interpreter): |
| 993 | + def __init__(self, module): |
| 994 | + super().__init__(module) |
| 995 | + |
| 996 | + def run_node(self, n : Node) -> Any: |
| 997 | + result = super().run_node(n) |
| 998 | + n.cached_value = result |
| 999 | + return result |
| 1000 | + |
| 1001 | + input = torch.randn(3, 4) |
| 1002 | + RunNodeInterpreter(gm).run(input) |
| 1003 | + for node in gm.graph.nodes: |
| 1004 | + assert hasattr(node, 'cached_value') |
| 1005 | + |
| 1006 | + def test_interpreter_onthefly_swap(self): |
| 1007 | + |
| 1008 | + def fn(x): |
| 1009 | + return torch.sigmoid(x).neg() |
| 1010 | + |
| 1011 | + gm = torch.fx.symbolic_trace(fn) |
| 1012 | + |
| 1013 | + class NegSigmSwapInterpreter(Interpreter): |
| 1014 | + def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: |
| 1015 | + if target == torch.sigmoid: |
| 1016 | + return torch.neg(*args, **kwargs) |
| 1017 | + return super().call_function(n) |
| 1018 | + |
| 1019 | + def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: |
| 1020 | + if target == 'neg': |
| 1021 | + call_self, *args_tail = args |
| 1022 | + return call_self.sigmoid(*args_tail, **kwargs) |
| 1023 | + return super().call_method(n) |
| 1024 | + |
| 1025 | + input = torch.randn(3, 4) |
| 1026 | + result = NegSigmSwapInterpreter(gm).run(input) |
| 1027 | + self.assertEqual(result, torch.neg(input).sigmoid()) |
| 1028 | + |
| 1029 | + def test_interpreter_partial_eval(self): |
| 1030 | + class MyModule(torch.nn.Module): |
| 1031 | + def __init__(self): |
| 1032 | + super().__init__() |
| 1033 | + self.param = torch.nn.Parameter(torch.rand(3, 4)) |
| 1034 | + self.linear = torch.nn.Linear(4, 5) |
| 1035 | + |
| 1036 | + def forward(self, x): |
| 1037 | + return self.linear(x + self.param).clamp(min=0.0, max=1.0) |
| 1038 | + |
| 1039 | + gm = torch.fx.symbolic_trace(MyModule()) |
| 1040 | + interp = Interpreter(gm) |
| 1041 | + env = {} |
| 1042 | + for node in gm.graph.nodes: |
| 1043 | + if node.op == 'call_module' and node.target == 'linear': |
| 1044 | + env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0 |
| 1045 | + break |
| 1046 | + assert len(env) == 1 |
| 1047 | + x = torch.randn(3, 4) |
| 1048 | + result = interp.run(x, initial_env=env) |
| 1049 | + self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0)) |
| 1050 | + |
| 1051 | + def test_interpreter_star_args(self): |
| 1052 | + def with_star_args(x, *args): |
| 1053 | + return x + args[0] |
| 1054 | + |
| 1055 | + gm = torch.fx.symbolic_trace(with_star_args) |
| 1056 | + interp = Interpreter(gm) |
| 1057 | + result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4)) |
| 1058 | + self.assertEqual(result, torch.ones(3, 4) * 2.0) |
| 1059 | + |
| 1060 | + def test_transformer_noop(self): |
| 1061 | + class MyModule(torch.nn.Module): |
| 1062 | + def __init__(self): |
| 1063 | + super().__init__() |
| 1064 | + self.param = torch.nn.Parameter(torch.rand(3, 4)) |
| 1065 | + self.linear = torch.nn.Linear(4, 5) |
| 1066 | + |
| 1067 | + def forward(self, x): |
| 1068 | + return self.linear(x + self.param).clamp(min=0.0, max=1.0) |
| 1069 | + |
| 1070 | + m = MyModule() |
| 1071 | + gm = torch.fx.symbolic_trace(m) |
| 1072 | + |
| 1073 | + new_gm = Transformer(gm).transform() |
| 1074 | + |
| 1075 | + input = torch.randn(3, 4) |
| 1076 | + self.assertEqual(new_gm(input), gm(input)) |
| 1077 | + |
| 1078 | + def test_transformer_op_swap(self): |
| 1079 | + |
| 1080 | + def fn(x): |
| 1081 | + return torch.sigmoid(x).neg() |
| 1082 | + |
| 1083 | + gm = torch.fx.symbolic_trace(fn) |
| 1084 | + |
| 1085 | + class NegSigmSwapXformer(Transformer): |
| 1086 | + def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: |
| 1087 | + if target == torch.sigmoid: |
| 1088 | + return torch.neg(*args, **kwargs) |
| 1089 | + return super().call_function(n) |
| 1090 | + |
| 1091 | + def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: |
| 1092 | + if target == 'neg': |
| 1093 | + call_self, *args_tail = args |
| 1094 | + return call_self.sigmoid(*args_tail, **kwargs) |
| 1095 | + return super().call_method(n) |
| 1096 | + |
| 1097 | + transformed = NegSigmSwapXformer(gm).transform() |
| 1098 | + input = torch.randn(3, 4) |
| 1099 | + self.assertEqual(transformed(input), torch.neg(input).sigmoid()) |
| 1100 | + |
960 | 1101 | def test_fn_type_annotations(self): |
961 | 1102 | class Foo(torch.nn.Module): |
962 | 1103 | def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]: |
|
0 commit comments