Commit 609f76f

James Reed authored and facebook-github-bot committed

[WIP][FX] Add Interpreter and Transformer (#50420)

Summary: Pull Request resolved: #50420
Test Plan: Imported from OSS
Reviewed By: zdevito
Differential Revision: D25880330
Pulled By: jamesr66a
fbshipit-source-id: 27d34888e36e39924821fed891d79f969237a104
1 parent 0831984 commit 609f76f

6 files changed: 548 additions & 45 deletions

docs/source/fx.rst

Lines changed: 6 additions & 0 deletions
@@ -322,3 +322,9 @@ API Reference
   :members:
 
 .. autoclass:: torch.fx.Proxy
+
+.. autoclass:: torch.fx.Interpreter
+  :members:
+
+.. autoclass:: torch.fx.Transformer
+  :members:
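
A minimal sketch of the API these two doc entries cover, pieced together from the tests added in this commit (the new torch/fx/interpreter.py itself is not shown in this excerpt):

import torch
from torch.fx import symbolic_trace, Interpreter, Transformer

def fn(x):
    return torch.sigmoid(x).neg()

gm = symbolic_trace(fn)

# Interpreter executes the traced graph node by node.
out = Interpreter(gm).run(torch.randn(3, 4))

# Transformer is the graph-to-graph variant: with no overrides,
# transform() returns an equivalent GraphModule (cf. test_transformer_noop).
new_gm = Transformer(gm).transform()
assert torch.equal(new_gm(torch.ones(3, 4)), gm(torch.ones(3, 4)))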

test/test_fx.py

Lines changed: 142 additions & 1 deletion
@@ -13,7 +13,8 @@
 from math import sqrt
 from pathlib import Path
 from torch.multiprocessing import Process
-from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph, wrap
+from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap
+from torch.fx.node import Target
 from torch.fx.experimental import shape_prop
 from torch.fx.immutable_collections import immutable_dict, immutable_list
 from copy import deepcopy

@@ -957,6 +958,146 @@ def forward(self, x):
         # Test shape propagation and make sure results match actual
         self.assertEqual(output_shape, ref_out.shape)
 
+    def test_interpreter(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 4))
+                self.linear = torch.nn.Linear(4, 5)
+
+            def forward(self, x):
+                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
+
+        m = MyModule()
+        gm = torch.fx.symbolic_trace(m)
+
+        interpreter = Interpreter(gm)
+        input = torch.randn(3, 4)
+        self.assertEqual(interpreter.run(input), gm(input))
+        self.assertEqual(interpreter.run(input), m(input))
+
+    def test_interpreter_run_node_override(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 4))
+                self.linear = torch.nn.Linear(4, 5)
+
+            def forward(self, x):
+                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
+
+        m = MyModule()
+        gm = torch.fx.symbolic_trace(m)
+
+        class RunNodeInterpreter(Interpreter):
+            def __init__(self, module):
+                super().__init__(module)
+
+            def run_node(self, n : Node) -> Any:
+                result = super().run_node(n)
+                n.cached_value = result
+                return result
+
+        input = torch.randn(3, 4)
+        RunNodeInterpreter(gm).run(input)
+        for node in gm.graph.nodes:
+            assert hasattr(node, 'cached_value')
+
+    def test_interpreter_onthefly_swap(self):
+
+        def fn(x):
+            return torch.sigmoid(x).neg()
+
+        gm = torch.fx.symbolic_trace(fn)
+
+        class NegSigmSwapInterpreter(Interpreter):
+            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
+                if target == torch.sigmoid:
+                    return torch.neg(*args, **kwargs)
+                return super().call_function(target, args, kwargs)
+
+            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
+                if target == 'neg':
+                    call_self, *args_tail = args
+                    return call_self.sigmoid(*args_tail, **kwargs)
+                return super().call_method(target, args, kwargs)
+
+        input = torch.randn(3, 4)
+        result = NegSigmSwapInterpreter(gm).run(input)
+        self.assertEqual(result, torch.neg(input).sigmoid())
+
+    def test_interpreter_partial_eval(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 4))
+                self.linear = torch.nn.Linear(4, 5)
+
+            def forward(self, x):
+                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
+
+        gm = torch.fx.symbolic_trace(MyModule())
+        interp = Interpreter(gm)
+        env = {}
+        for node in gm.graph.nodes:
+            if node.op == 'call_module' and node.target == 'linear':
+                env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0
+                break
+        assert len(env) == 1
+        x = torch.randn(3, 4)
+        result = interp.run(x, initial_env=env)
+        self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0))
+
+    def test_interpreter_star_args(self):
+        def with_star_args(x, *args):
+            return x + args[0]
+
+        gm = torch.fx.symbolic_trace(with_star_args)
+        interp = Interpreter(gm)
+        result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
+        self.assertEqual(result, torch.ones(3, 4) * 2.0)
+
+    def test_transformer_noop(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 4))
+                self.linear = torch.nn.Linear(4, 5)
+
+            def forward(self, x):
+                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
+
+        m = MyModule()
+        gm = torch.fx.symbolic_trace(m)
+
+        new_gm = Transformer(gm).transform()
+
+        input = torch.randn(3, 4)
+        self.assertEqual(new_gm(input), gm(input))
+
+    def test_transformer_op_swap(self):
+
+        def fn(x):
+            return torch.sigmoid(x).neg()
+
+        gm = torch.fx.symbolic_trace(fn)
+
+        class NegSigmSwapXformer(Transformer):
+            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
+                if target == torch.sigmoid:
+                    return torch.neg(*args, **kwargs)
+                return super().call_function(target, args, kwargs)
+
+            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
+                if target == 'neg':
+                    call_self, *args_tail = args
+                    return call_self.sigmoid(*args_tail, **kwargs)
+                return super().call_method(target, args, kwargs)
+
+        transformed = NegSigmSwapXformer(gm).transform()
+        input = torch.randn(3, 4)
+        self.assertEqual(transformed(input), torch.neg(input).sigmoid())
+
     def test_fn_type_annotations(self):
         class Foo(torch.nn.Module):
             def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
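
The partial-evaluation test above doubles as a recipe for pinning a single node's value at run time. A minimal standalone sketch, assuming only the initial_env keyword exercised in test_interpreter_partial_eval:

import torch
from torch.fx import Interpreter, symbolic_trace

def fn(x):
    return torch.sigmoid(x).neg()

gm = symbolic_trace(fn)

# Pre-seed the environment for the sigmoid node; run() then skips
# executing it and feeds the supplied value to the downstream neg().
env = {}
for node in gm.graph.nodes:
    if node.op == 'call_function' and node.target == torch.sigmoid:
        env[node] = torch.ones(3, 4)
        break

result = Interpreter(gm).run(torch.randn(3, 4), initial_env=env)
assert torch.equal(result, torch.ones(3, 4).neg())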

torch/fx/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -82,3 +82,4 @@ def forward(self, x):
 from .graph import Graph
 from .node import Node, map_arg
 from .proxy import Proxy
+from .interpreter import Interpreter as Interpreter, Transformer as Transformer

torch/fx/__init__.pyi

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@ from .graph_module import GraphModule as GraphModule
 from .node import Node as Node, map_arg as map_arg
 from .proxy import Proxy as Proxy
 from .symbolic_trace import Tracer as Tracer, symbolic_trace as symbolic_trace, wrap as wrap
+from .interpreter import Interpreter as Interpreter, Transformer as Transformer
torch/fx/experimental/shape_prop.py

Lines changed: 10 additions & 44 deletions

@@ -1,51 +1,17 @@
 import torch
 import torch.fx
 from torch.fx.node import Node
+from typing import Any
 
-from typing import Dict
+class ShapeProp(torch.fx.Interpreter):
+    def run_node(self, n : Node) -> Any:
+        result = super().run_node(n)
 
-class ShapeProp:
-    def __init__(self, mod):
-        self.mod = mod
-        self.graph = mod.graph
-        self.modules = dict(self.mod.named_modules())
+        if isinstance(result, torch.Tensor):
+            n.shape = result.shape  # type: ignore
+            n.dtype = result.dtype  # type: ignore
 
-    def propagate(self, *args):
-        args_iter = iter(args)
-        env : Dict[str, Node] = {}
-
-        def load_arg(a):
-            return torch.fx.node.map_arg(a, lambda n: env[n.name])
-
-        def fetch_attr(target : str):
-            target_atoms = target.split('.')
-            attr_itr = self.mod
-            for i, atom in enumerate(target_atoms):
-                if not hasattr(attr_itr, atom):
-                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
-                attr_itr = getattr(attr_itr, atom)
-            return attr_itr
-
-        for node in self.graph.nodes:
-            if node.op == 'placeholder':
-                result = next(args_iter)
-            elif node.op == 'get_attr':
-                result = fetch_attr(node.target)
-            elif node.op == 'call_function':
-                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
-            elif node.op == 'call_method':
-                self_obj, *args = load_arg(node.args)
-                kwargs = load_arg(node.kwargs)
-                result = getattr(self_obj, node.target)(*args, **kwargs)
-            elif node.op == 'call_module':
-                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
-            elif node.op == 'output':
-                return load_arg(node.args[0])
+        return result
 
-            if isinstance(result, torch.Tensor):
-                node.shape = result.shape
-                node.dtype = result.dtype
-
-            env[node.name] = result
-
-        return None
+    def propagate(self, *args):
+        return super().run(*args)
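
With this rewrite, shape propagation reduces to the run_node hook above. A usage sketch, assuming the module keeps its torch.fx.experimental.shape_prop path:

import torch
from torch.fx import symbolic_trace
from torch.fx.experimental.shape_prop import ShapeProp

gm = symbolic_trace(torch.nn.Linear(4, 5))
ShapeProp(gm).propagate(torch.randn(3, 4))

# Every node that produced a Tensor now carries shape/dtype annotations.
for node in gm.graph.nodes:
    if hasattr(node, 'shape'):
        print(node.name, node.shape, node.dtype)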
