1818
1919class Partition :
2020 def __init__ (
21- self , id : Optional [int ] = None , nodes : Optional [Iterable [Node ]] = None
21+ self ,
22+ id : Optional [int ] = None ,
23+ nodes : Optional [Iterable [tuple [Node , Optional [int ]]]] = None ,
2224 ):
2325 self .id = id
24- self .nodes = dict . fromkeys (nodes ) if nodes is not None else {}
26+ self .nodes = dict (nodes ) if nodes is not None else {}
2527
2628 def __repr__ (self ) -> str :
2729 return str (self .nodes )
2830
29- def add_node (self , node : Node ):
30- self .nodes .update ({node : None })
31+ def add_node (self , node : Node , node_order : Optional [ int ] = None ):
32+ self .nodes .update ({node : node_order })
3133
3234 def remove_node (self , node : Node ):
3335 del self .nodes [node ]
@@ -172,7 +174,7 @@ def dfs_iter_find_cycle(all_user_nodes: set[Node]):
172174
173175 return merge_id , True
174176
175- def merge_single_node (node : Node , id : Optional [int ]):
177+ def merge_single_node (node : Node , node_order : Optional [ int ], id : Optional [int ]):
176178 def _update_partition_map (node : Node , id : int ):
177179 # Iterate through all the users of this node and update the partition map to indicate
178180 # that there is a path from the partition id of this node to the target partition id.
@@ -189,16 +191,16 @@ def _update_partition_map(node: Node, id: int):
189191 assignment .pop (node )
190192 elif id not in partitions_by_id :
191193 assignment [node ] = id
192- partitions_by_id [id ] = Partition (id = id , nodes = [node ])
194+ partitions_by_id [id ] = Partition (id = id , nodes = [( node , node_order ) ])
193195 partition_users [id ] = set (node .users )
194196 _update_partition_map (node , id )
195197 else :
196198 assignment [node ] = id
197- partitions_by_id [id ].add_node (node )
199+ partitions_by_id [id ].add_node (node , node_order )
198200
199201 logger .debug ("Proposing partitions..." )
200202
201- for node in reversed (self .graph_module .graph .nodes ):
203+ for node_order , node in enumerate ( reversed (self .graph_module .graph .nodes ) ):
202204 # use Dict as an ordered set to ensure deterministic partitioning result, don't care value
203205 merge_candidates : dict [int , None ] = {}
204206
@@ -211,7 +213,7 @@ def _update_partition_map(node: Node, id: int):
211213 partition_id = next (new_partition_id )
212214 nodes_order [node ] = partition_id
213215 partitions_order [partition_id ] = partition_id
214- merge_single_node (node , partition_id )
216+ merge_single_node (node , node_order , partition_id )
215217 merge_candidates [partition_id ] = None
216218
217219 # merge all possible partitions
@@ -228,6 +230,14 @@ def _update_partition_map(node: Node, id: int):
228230 # in the graph, otherwise, this is a no-op
229231 self_id , _ = maybe_merge_partition (self_id , other_id )
230232
233+ # sort partition nodes based on descending node order
234+ for partition in partitions_by_id .values ():
235+ partition .nodes = dict (
236+ sorted (
237+ partition .nodes .items (), key = operator .itemgetter (1 ), reverse = True
238+ )
239+ )
240+
231241 # post processing to re-assign "getitem" nodes into upstream partition
232242 logger .debug ("Reassigning getitem nodes to its producer node's partition..." )
233243 nodes_reassignment : dict [Node , int ] = {}
@@ -248,7 +258,7 @@ def _update_partition_map(node: Node, id: int):
248258 if assignment .get (user , None ) != id : # type: ignore[arg-type]
249259 nodes_reassignment [user ] = id # type: ignore[assignment]
250260 for node , id in nodes_reassignment .items ():
251- merge_single_node (node , id )
261+ merge_single_node (node , None , id )
252262
253263 # filter out single node partitions
254264 if not self .allows_single_node_partition :
0 commit comments