Skip to content

Commit 2d1e923

Browse files
xwu-intelpytorchmergebot
authored andcommitted
Partitioner: Fix to align partition node order with original graph (#157892)
Fixes #157891 Pull Request resolved: #157892 Approved by: https://github.com/ezyang
1 parent 399c89e commit 2d1e923

3 files changed

Lines changed: 35 additions & 14 deletions

File tree

test/fx/test_partitioner_order.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(self, graph_module: torch.fx.GraphModule):
2424
)
2525

2626

27+
# original graph node order is: ['x', 'add', 'add_1', 'output']
2728
class AddModule(torch.nn.Module):
2829
def forward(self, x):
2930
y = torch.add(x, x)
@@ -32,8 +33,18 @@ def forward(self, x):
3233

3334

3435
class TestPartitionerOrder(TestCase):
35-
# partitoner test to check graph node order
36-
def test_partitioner_order(self):
36+
# partitoner test to check graph node order remains the same with the original graph after partitioning
37+
def test_partitioner_graph_node_order(self):
38+
m = AddModule()
39+
traced_m = torch.fx.symbolic_trace(m)
40+
origin_node_order = [n.name for n in traced_m.graph.nodes]
41+
partions = DummyPartitioner(traced_m).propose_partitions()
42+
partion_nodes = [list(partition.nodes) for partition in partions]
43+
partition_node_order = [n.name for n in partion_nodes[0]]
44+
self.assertTrue(partition_node_order == origin_node_order)
45+
46+
# partitoner test to check graph node order remains the same during multiple runs
47+
def test_partitioner_multiple_runs_order(self):
3748
m = AddModule()
3849
traced_m = torch.fx.symbolic_trace(m)
3950
partitions = DummyPartitioner(traced_m).propose_partitions()

torch/fx/passes/infra/partitioner.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,18 @@
1818

1919
class Partition:
2020
def __init__(
21-
self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None
21+
self,
22+
id: Optional[int] = None,
23+
nodes: Optional[Iterable[tuple[Node, Optional[int]]]] = None,
2224
):
2325
self.id = id
24-
self.nodes = dict.fromkeys(nodes) if nodes is not None else {}
26+
self.nodes = dict(nodes) if nodes is not None else {}
2527

2628
def __repr__(self) -> str:
2729
return str(self.nodes)
2830

29-
def add_node(self, node: Node):
30-
self.nodes.update({node: None})
31+
def add_node(self, node: Node, node_order: Optional[int] = None):
32+
self.nodes.update({node: node_order})
3133

3234
def remove_node(self, node: Node):
3335
del self.nodes[node]
@@ -172,7 +174,7 @@ def dfs_iter_find_cycle(all_user_nodes: set[Node]):
172174

173175
return merge_id, True
174176

175-
def merge_single_node(node: Node, id: Optional[int]):
177+
def merge_single_node(node: Node, node_order: Optional[int], id: Optional[int]):
176178
def _update_partition_map(node: Node, id: int):
177179
# Iterate through all the users of this node and update the partition map to indicate
178180
# that there is a path from the partition id of this node to the target partition id.
@@ -189,16 +191,16 @@ def _update_partition_map(node: Node, id: int):
189191
assignment.pop(node)
190192
elif id not in partitions_by_id:
191193
assignment[node] = id
192-
partitions_by_id[id] = Partition(id=id, nodes=[node])
194+
partitions_by_id[id] = Partition(id=id, nodes=[(node, node_order)])
193195
partition_users[id] = set(node.users)
194196
_update_partition_map(node, id)
195197
else:
196198
assignment[node] = id
197-
partitions_by_id[id].add_node(node)
199+
partitions_by_id[id].add_node(node, node_order)
198200

199201
logger.debug("Proposing partitions...")
200202

201-
for node in reversed(self.graph_module.graph.nodes):
203+
for node_order, node in enumerate(reversed(self.graph_module.graph.nodes)):
202204
# use Dict as an ordered set to ensure deterministic partitioning result, don't care value
203205
merge_candidates: dict[int, None] = {}
204206

@@ -211,7 +213,7 @@ def _update_partition_map(node: Node, id: int):
211213
partition_id = next(new_partition_id)
212214
nodes_order[node] = partition_id
213215
partitions_order[partition_id] = partition_id
214-
merge_single_node(node, partition_id)
216+
merge_single_node(node, node_order, partition_id)
215217
merge_candidates[partition_id] = None
216218

217219
# merge all possible partitions
@@ -228,6 +230,14 @@ def _update_partition_map(node: Node, id: int):
228230
# in the graph, otherwise, this is a no-op
229231
self_id, _ = maybe_merge_partition(self_id, other_id)
230232

233+
# sort partition nodes based on descending node order
234+
for partition in partitions_by_id.values():
235+
partition.nodes = dict(
236+
sorted(
237+
partition.nodes.items(), key=operator.itemgetter(1), reverse=True
238+
)
239+
)
240+
231241
# post processing to re-assign "getitem" nodes into upstream partition
232242
logger.debug("Reassigning getitem nodes to its producer node's partition...")
233243
nodes_reassignment: dict[Node, int] = {}
@@ -248,7 +258,7 @@ def _update_partition_map(node: Node, id: int):
248258
if assignment.get(user, None) != id: # type: ignore[arg-type]
249259
nodes_reassignment[user] = id # type: ignore[assignment]
250260
for node, id in nodes_reassignment.items():
251-
merge_single_node(node, id)
261+
merge_single_node(node, None, id)
252262

253263
# filter out single node partitions
254264
if not self.allows_single_node_partition:

torch/fx/passes/utils/fuser_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def fuse_as_graphmodule(
9696
gm: GraphModule,
9797
nodes: NodeList,
9898
module_name: str,
99-
partition_lookup_table: _Optional[dict[Node, None]] = None,
99+
partition_lookup_table: _Optional[dict[Node, _Optional[int]]] = None,
100100
*,
101101
always_return_tuple: bool = False,
102102
) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]:
@@ -249,7 +249,7 @@ def erase_nodes(gm: GraphModule, nodes: NodeList) -> None:
249249
@compatibility(is_backward_compatible=False)
250250
def fuse_by_partitions(
251251
gm: GraphModule,
252-
partitions: list[dict[Node, None]],
252+
partitions: list[dict[Node, _Optional[int]]],
253253
prefix: str = "fused_",
254254
always_return_tuple: bool = False,
255255
) -> GraphModule:

0 commit comments

Comments
 (0)