Skip to content

Commit 6ba44e5

Browse files
author
eellison
committed
Update on "[JIT] Make create autodiff subgraphs do in place updates to aliasDb"
Update alias db in-place instead of having to construct alias db from scratch on each change, causing O(n^2) behavior. Description from #37106 holds pretty well: """ Recomputing the aliasdb on every fusion iteration + in every subblock is hugely expensive. Instead, update it in-place when doing fusion. The graph fuser pass operates by pushing nodes into a fusion group. So we start with `x, y = f(a, b, c)` and end with: ``` x_out, y_out = prim::fusionGroup(a, b, c) x_in, y_in = f(a_in, b_in, c_in) -> x_in, y_in ``` We destroy the x and y Value*s in the process. This operation is easy to express as an update to the aliasDb--x_out just takes on all the aliasing information x used to have. In particular, since we know f and prim::fusionGroup are purely functional, we don't have to mess with any write information. """ The one difficulty here is mapping x, y to x_out, y_out is not trivial in merging nodes into the autodiff subgraph node. There are a few options: - attempt to make all subgraph utils & ir cloning logic update a map - mirror the subgraph utils implementation in create_autodiff_subgraph - uniquely map x, y and x_in, y_in so you can back out the correspondence. I went with the third option. This shouldn't affect the results of the pass at all. LMK if you think there's anything else I should be doing to test, I was thinking about maybe exposing an option to run create autodiff subgraphs without the post processor and check that the alias db was correctly updated. Differential Revision: [D22798377](https://our.internmc.facebook.com/intern/diff/D22798377) [ghstack-poisoned]
1 parent 56194da commit 6ba44e5

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

test/jit/test_autodiff_subgraph_slicing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,11 @@ def fn(x, k):
168168
z = y * k
169169
return z, k
170170

171-
graph = self._perform_ad_subgraph_slicing(fn, 1, 1)
172171

172+
graph = self._perform_ad_subgraph_slicing(fn, 1, 1)
173173
# We should not have combined the two multiplications into
174174
# the same group; they should each be a separate DiffGraph
175-
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
175+
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 3)
176176

177177

178178
def test_merge_respects_aliasing(self):

0 commit comments

Comments
 (0)