[jit] Handled cases when input list to cat is mutated after cat using AliasDb#60774
[jit] Handled cases when input list to cat is mutated after cat using AliasDb#60774navahgar wants to merge 7 commits intogh/navahgar/24/basefrom
Conversation
… AliasDb [ghstack-poisoned]
💊 CI failures summary and remediationsAs of commit d06305b (more details on the Dr. CI page and at hud.pytorch.org/pr/60774): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 Preview docs built from this PR This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions to the (internal) Dr. CI Users group. |
|
@navahgar has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
…r cat using AliasDb" Differential Revision: [D29406100](https://our.internmc.facebook.com/intern/diff/D29406100) [ghstack-poisoned]
|
@navahgar has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
| auto aliasDb = getOrCreateAliasDb(); | ||
| for (auto use : list_uses) { | ||
| if (aliasDb->isMutable(use.user)) { | ||
| if (aliasDb->couldMoveBeforeTopologically(use.user, cat)) { |
There was a problem hiding this comment.
I don't know if this is the right logic.
If you have
x = [a, a]
x.append([3])
torch.cat(x)
...
torch.cat(x)
You will not be able to move the append before the second cat, however nonetheless there is a mutation before cat.
You might see something similar if you have a loop as well.
Can I ask what the specific use case is here that we're trying to address?
There was a problem hiding this comment.
I probably had the wrong assumption regarding couldMoveBeforeTopologically then. I thought that since the append is already "before" the second cat, couldMoveBeforeTopologically would return true here. IIUC, you are saying that it would try to move just before that destination point and fail here.
I just want to catch all cases where the list (that is input to this cat) is mutated before this cat. In your example, I want to know that both cat ops are with mutated list.
Another case is:
x = [a, a]
torch.cat(x)
...
x.append([3])
torch.cat(x)
In this example, I want to know that the first cat is not with mutated list, but the second cat uses a mutated list.
I need this so that, this pass only changes those cat ops without mutated list to prim::Concat. The subsequent calls to RemoveListMutation (added in another PR) will remove the mutation involved.
There was a problem hiding this comment.
How about changing that condition to the following?
if (use.user->isBefore(cat) || aliasDb->couldMoveBeforeTopologically(use.user, cat)) {
...
For your example, this should return true (list is modified) for both cat ops since the append is "before" both cat ops.
For my example above, this should return false for the first cat and true for the second one.
IIUC, for nodes in a block isBefore check should suffice. But if we have multiple blocks involved (say with some conditionals), then we might need couldMoveBeforeTopologically.
Am I right about this usage?
There was a problem hiding this comment.
IIUC, you are saying that it would try to move just before that destination point and fail here.
Yes, that is correct.
I don't know that isBefore is a sufficient condition.
li = [x]
for _ in range(4):
print(torch.cat(li)
li.append(x)
So, a couple things we could do here:
we could try to augment alias analysis to add a function to test for this case, basically couldMoveBeforeTopologicallyValid except we are ignoring topological dependencies and just look at write/side effectful ones.
However,
I think this problem we're looking here is better solved by iteratively trying to remove mutation on list and also replace cats with a variadic version. Isnt that we're doing #60776, and after that pass doesn't this problem we're trying to solve just become:
x = [a, a]
torch.cat(x,)
...
x.append([3])
torch.cat(x)
->
torch.cat(a, a)
torch.cat(a, a, 3)
There was a problem hiding this comment.
Good point regarding isBefore.
I think this problem we're looking here is better solved by iteratively trying to remove mutation on list and also replace cats with a variadic version.
It will solve some cases but not all. For example:
x = [a, a]
x.append([a])
torch.cat(x)
...
x.append([a])
torch.cat(x)
After one round of RemoveListMutation, the first append will be eliminated but the second one will still be there.
x = [a, a, a]
torch.cat(x)
...
x.append([a])
torch.cat(x)
Now, we can replace the first cat with prim::Concat but we still can't do the same for the second one. We need a check to restrict such cases. These are the cases I am trying to catch with this check.
There was a problem hiding this comment.
Can we just make the pass iteratively try to remove both appends and cats, so that the pass would remove all instances of x in one pass ?
Even in that case, the pass to remove cats should only remove those that have not been mutated. And that needs the same check again. I still don't see a way to get around that check.
There was a problem hiding this comment.
It doesn't need the same check that those have not been mutated, it just needs the check that you can merge the instantiation of the list with its use
There was a problem hiding this comment.
Sure, that should work as well. How is that any better? Is there a different API to check that?
There was a problem hiding this comment.
let's VC tomorrow to takl about it
There was a problem hiding this comment.
Changed this to check if the input list could be moved before cat, as we discussed. PTAL.
…r cat using AliasDb" Differential Revision: [D29406100](https://our.internmc.facebook.com/intern/diff/D29406100) [ghstack-poisoned]
|
@navahgar has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
…r cat using AliasDb" Differential Revision: [D29406100](https://our.internmc.facebook.com/intern/diff/D29406100) [ghstack-poisoned]
|
@navahgar has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
…r cat using AliasDb" Differential Revision: [D29406100](https://our.internmc.facebook.com/intern/diff/D29406100) [ghstack-poisoned]
|
@navahgar has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
…r cat using AliasDb" Differential Revision: [D29406100](https://our.internmc.facebook.com/intern/diff/D29406100) [ghstack-poisoned]
|
@navahgar has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
…r cat using AliasDb" Differential Revision: [D29406100](https://our.internmc.facebook.com/intern/diff/D29406100) [ghstack-poisoned]
|
@navahgar has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
… AliasDb (pytorch#60774) Summary: Pull Request resolved: pytorch#60774 Test Plan: Imported from OSS Reviewed By: mrshenli Differential Revision: D29406100 Pulled By: navahgar fbshipit-source-id: af6afca65881c18c51b482eb63898a0f1c94d591
Stack from ghstack:
Differential Revision: D29406100