Remove getitem special handling in the partitioner #87073
IvanYashchuk wants to merge 15 commits into pytorch:master
Conversation
This special handling of getitem unnecessarily splits fusions at functions with tuple outputs.
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/87073
✅ No failures as of commit 9ce171f. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
cc @SherlockNoMad, why was this special handling needed, and why does it split fusions?
```py
# this is a no-op
maybe_merge_partition(self_id, other_id)
```

```py
# post processing to re-assign "getitem" nodes into upstream partition
```
Nitpick: should we remove this part?
This was supposed to fuse getitem into its producer (which is a fusion) before we had getitem disabled in the line you commented on above. With that logic removed, I believe this section would remove a getitem node at the beginning of a fusion if it comes from an unfused node.
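For context, the post-processing step being discussed can be sketched roughly like this. This is a minimal sketch with a hypothetical `reassign_getitems` helper and a stand-in `Node` class, not the real torch.fx API: each getitem node follows its producer's partition, and getitem nodes whose producer is unfused are dropped from any partition.

```python
import operator
from dataclasses import dataclass


# Hypothetical stand-in for an fx node, not the real torch.fx.Node.
@dataclass(frozen=True)
class Node:
    name: str
    target: object
    args: tuple = ()


def reassign_getitems(nodes, assignment):
    """Reassign each getitem node to its producer's partition.

    If the producer has no partition (it is unfused), drop the
    getitem from whatever partition it was accidentally placed in.
    """
    for node in nodes:
        if node.target is operator.getitem:
            producer = node.args[0]
            if producer in assignment:
                assignment[node] = assignment[producer]
            else:
                assignment.pop(node, None)
    return assignment


var_mean = Node("var_mean", "nvprims.var_mean")
getitem = Node("getitem", operator.getitem, (var_mean,))

# getitem was initially placed in a different partition than its producer.
assignment = {var_mean: 0, getitem: 1}
reassign_getitems([var_mean, getitem], assignment)
print(assignment[getitem])  # getitem now follows var_mean into partition 0
```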
Keeping this code works with the example I added to the PR description, but I was hitting the assert in merge_single_node in one benchmark model. I should extract that failing graph portion for testing.
Got ya. merge_single_node might be asserting too aggressively as well.
pytorch/torch/fx/passes/infra/partitioner.py
Line 108 in 8393213
This assert should be removed along with your update to the getitem logic.
Keeping this code, `python -m pytest test/test_fx_passes.py -k "test_partitioner_fn_" -vvv` fails with the assert:

```
Traceback (most recent call last):
  File "/home/iyashchuk/dev/pytorch/master/test/test_fx_passes.py", line 224, in test_partitioner
    partitions = partitioner.propose_partitions()
  File "/home/iyashchuk/dev/pytorch/master/torch/fx/passes/infra/partitioner.py", line 162, in propose_partitions
    merge_single_node(node, id)
  File "/home/iyashchuk/dev/pytorch/master/torch/fx/passes/infra/partitioner.py", line 108, in merge_single_node
    assert node not in assignment
AssertionError
```

Removing this code, the test still fails (as in CI), but the failure is now expected:
```
____ TestFXGraphPasses.test_partitioner_fn_<function TestPartitionFunctions_forward13 at 0x7f08e8fdfd00>_expected_partition_[['add_2', 'add_1', 'add']] ____
Traceback (most recent call last):
  File "/home/iyashchuk/dev/pytorch/master/test/test_fx_passes.py", line 229, in test_partitioner
    assert set(partitions_name[i]) == set(expected_partition[i])
AssertionError: assert {'add', 'getitem_2', 'getitem', 'getitem_1', 'getitem_3', 'add_2', 'add_1'} == {'add_2', 'add_1', 'add'}
Extra items in the left set:
  'getitem_2'
  'getitem_3'
  'getitem_1'
  'getitem'
Full diff:
- {'add_2', 'add_1', 'add'}
+ {'add', 'getitem_2', 'getitem', 'getitem_1', 'getitem_3', 'add_2', 'add_1'}
```

The failing test was added in https://github.com/pytorch/pytorch/pull/86713 😉
I was suggesting keeping this code and removing the assert.
Keeping the code and removing the assert gives the following partitions in that test:
`[{getitem, add_1, add_2, getitem_2, getitem_3, getitem_1, add}, {getitem_3, getitem_2, getitem, getitem_1}]`
Oops, that's surprising to me. 😆
How can we have getitem spanning across multiple partitions? I think there's just a small bug somewhere. It looks like it's attempting to clear getitem from the original partition, but somehow that wasn't done right. I can take a quick look at this afterwards.
torch/_prims/nvfuser_executor.py
Outdated

```py
if "getitem" in node.name:
    # Check if the node unpacks a tuple from a supported node
    node_to_unpack = node.args[0]
    return self.is_node_supported(submodules, node_to_unpack)
```
I'm slightly leaning towards cleaning up getitem at the beginning of the graph in the partitioner, instead of having more complicated logic in the op support query. For now it works fine, since the partitioner goes from consumer to producer and we don't really fuse anything as we propose partition groups. But that implementation could change 😦
It's of course a matter of taste. I think special-casing getitem and processing it differently may cause bugs. There are two situations with getitem: it either unpacks a node that's supported or one that isn't, so we can't unconditionally reject or accept getitem. Previously we relied on the partitioner (always rejecting getitem) to do the right thing, and it didn't.
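To make the trade-off concrete, the op-support-side approach can be sketched like this. Hypothetical `SketchOpSupport` and `Node` stand-ins, not the real `OperatorSupport` API: a getitem node is accepted exactly when the node it unpacks is accepted.

```python
import operator


# Hypothetical stand-in for an fx node, not the real torch.fx.Node.
class Node:
    def __init__(self, target, args=()):
        self.target = target
        self.args = args


# Hypothetical sketch of an op-support query, not the real OperatorSupport API.
class SketchOpSupport:
    def __init__(self, supported_targets):
        self.supported_targets = set(supported_targets)

    def is_node_supported(self, node):
        # getitem has no support status of its own: delegate to the
        # node whose tuple output it unpacks.
        if node.target is operator.getitem:
            return self.is_node_supported(node.args[0])
        return node.target in self.supported_targets


support = SketchOpSupport({"nvprims.var_mean"})
var_mean = Node("nvprims.var_mean")
split = Node("aten.split")
gi_supported = Node(operator.getitem, (var_mean,))
gi_unsupported = Node(operator.getitem, (split,))
print(support.is_node_supported(gi_supported))    # True
print(support.is_node_supported(gi_unsupported))  # False
```

This keeps getitem's status conditional on its producer, which is the property both approaches are trying to guarantee; the difference is only where that logic lives.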
test/test_fx_passes.py
Outdated

```diff
 # 5 getitem special case
-(TestPartitionFunctions.forward13, [["add_2", "add_1", "add"]]),
+(TestPartitionFunctions.forward13, [["add_2", "add_1", "add", "getitem", "getitem_1", "getitem_2", "getitem_3"]]),
```
Hi @IvanYashchuk, I think this is an unexpected change.
A getitem node should always be partitioned together with its producer node.
Across the stack, we have an implicit assumption that a module's inputs and outputs must be tensor-typed. This is why we have the special handling logic in the first place.
This doesn't look right even for this PR. getitem is produced by split, which isn't supported for fusion?

Given the example, getitem_2 and getitem_3 should be partitioned into fused_1.
SherlockNoMad
left a comment
getitem should always be partitioned together with its producer node.
@SherlockNoMad The problem here is that the original logic special-cased getitem and handled it incorrectly. Changes in this PR are supposed to fix that.
I agree with you, but the previous code was doing that too aggressively and incorrectly. Please take a look at the latest changes.
On this topic, could you consider changes like this one? We are merging things as provided by the op support list, and at the end of fusion we clean up getitem. Sorry that my previous refactor leaves the code in an ugly state. 😝
Seems like we are not the sole user of the fx partitioner (#87007). I'm not sure if the discussion here concerns you at all, @wschin.
Getitem patch
Placeholder nodes called "getitem_XXX" were incorrectly dropped from the graph.
@SherlockNoMad could you please take another look at the proposed changes?
```py
assignment[node] = id
if id not in partitions_by_id:
if id is None:
```
Do we really have a case where id is None?
Looking at the call sites for this function, none of them seem to pass None...
If this cannot be None, let's `assert id is not None`.
Yes, this happens naturally when getitem is marked as supported by backends.
Later, during the special handling of getitem, where we merge each getitem call into its producer, we can run into cases where the producer is not supported by backends but we accidentally merged getitem into the fusion. We'll end up pulling these nodes out, and that's when id here would be None.
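A minimal sketch of what `id=None` could mean in that flow. This is a hypothetical helper using plain dicts and sets, not the actual partitioner code: None signals that the node should be pulled out of its current partition rather than assigned to a new one.

```python
def merge_single_node(node, id, assignment, partitions_by_id):
    # Hypothetical sketch, not the real partitioner implementation.
    # Remove the node from whatever partition it currently belongs to.
    if node in assignment:
        partitions_by_id[assignment[node]].discard(node)
    if id is None:
        # id=None: pull the node out entirely instead of reassigning it.
        assignment.pop(node, None)
        return
    assignment[node] = id
    partitions_by_id.setdefault(id, set()).add(node)


# getitem was accidentally merged into partition 1, but its producer
# turned out to be unsupported, so we pull it back out with id=None.
assignment = {"getitem": 1}
partitions_by_id = {1: {"getitem", "add"}}
merge_single_node("getitem", None, assignment, partitions_by_id)
print(assignment, partitions_by_id)  # {} {1: {'add'}}
```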
pytorch/torch/fx/passes/infra/partitioner.py
Line 167 in 7d3fbd7
This is checked in test case forward13
SherlockNoMad
left a comment
LGTM, except for the minor comment.
Thanks to @SherlockNoMad's stamp. The failure seems to be on CI nodes and unrelated. Let's rebase to a stable commit and merge it! 🎉 🎉 🎉
@pytorchbot merge -g

Merge started. Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This special handling of getitem unnecessarily splits fusions at functions with tuple outputs.
Example script:
```py
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport
from torch.fx.experimental.proxy_tensor import make_fx
def func(x):
xx = torch.ops.nvprims.add(x, 1)
var, mean = torch.ops.nvprims.var_mean(x, correction=0)
var_cos = torch.ops.nvprims.cos(var)
mean_sin = torch.ops.nvprims.sin(mean)
return torch.ops.nvprims.add(var_cos, mean_sin)
a = torch.randn(5, 3, 3, device="cuda")
gm = make_fx(func)(a)
gm.graph.print_tabular()
supported_ops = NvfuserPrimOperatorSupport()
partitioner = CapabilityBasedPartitioner(
gm, supported_ops, allows_single_node_partition=False
)
partitions = partitioner.propose_partitions()
print(partitions)
partitioned_graph = partitioner.fuse_partitions(partitions)
partitioned_graph.graph.print_tabular()
```
Output on master:
```py
opcode name target args kwargs
------------- --------- --------------------------- ---------------- -----------------
placeholder x_1 x_1 () {}
call_function add nvprims.add.default (x_1, 1) {}
call_function var_mean nvprims.var_mean.main (x_1, [0, 1, 2]) {'correction': 0}
call_function getitem <built-in function getitem> (var_mean, 0) {}
call_function getitem_1 <built-in function getitem> (var_mean, 1) {}
call_function cos nvprims.cos.default (getitem,) {}
call_function sin nvprims.sin.default (getitem_1,) {}
call_function add_1 nvprims.add.default (cos, sin) {}
output output output (add_1,) {}
[{cos, sin, add_1}, {var_mean, add, getitem, getitem_1}]
opcode name target args kwargs
------------- --------- --------------------------- ---------------------- --------
placeholder x_1 x_1 () {}
call_module fused_1 fused_1 (x_1,) {}
call_function getitem_2 <built-in function getitem> (fused_1, 0) {}
call_function getitem_3 <built-in function getitem> (fused_1, 1) {}
call_module fused_0 fused_0 (getitem_2, getitem_3) {}
output output output (fused_0,) {}
```
Output with this PR:
```
[{var_mean, add_1, cos, sin, add, getitem_1, getitem}]
opcode name target args kwargs
----------- ------- -------- ---------- --------
placeholder x_1 x_1 () {}
call_module fused_0 fused_0 (x_1,) {}
output output output (fused_0,) {}
```
Pull Request resolved: pytorch#87073
Approved by: https://github.com/jjsjann123, https://github.com/SherlockNoMad