
Remove getitem special handling in the partitioner #87073

Closed
IvanYashchuk wants to merge 15 commits into pytorch:master from IvanYashchuk:getitem-partitioner

Conversation

@IvanYashchuk
Collaborator

@IvanYashchuk IvanYashchuk commented Oct 17, 2022

This special handling of getitem unnecessarily splits fusions at functions with tuple outputs.

Example script:

```py
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport
from torch.fx.experimental.proxy_tensor import make_fx

def func(x):
    xx = torch.ops.nvprims.add(x, 1)
    var, mean = torch.ops.nvprims.var_mean(x, correction=0)
    var_cos = torch.ops.nvprims.cos(var)
    mean_sin = torch.ops.nvprims.sin(mean)
    return torch.ops.nvprims.add(var_cos, mean_sin)

a = torch.randn(5, 3, 3, device="cuda")
gm = make_fx(func)(a)
gm.graph.print_tabular()

supported_ops = NvfuserPrimOperatorSupport()
partitioner = CapabilityBasedPartitioner(
    gm, supported_ops, allows_single_node_partition=False
)
partitions = partitioner.propose_partitions()
print(partitions)
partitioned_graph = partitioner.fuse_partitions(partitions)
partitioned_graph.graph.print_tabular()
```

Output on master:

```py
opcode         name       target                       args              kwargs
-------------  ---------  ---------------------------  ----------------  -----------------
placeholder    x_1        x_1                          ()                {}
call_function  add        nvprims.add.default          (x_1, 1)          {}
call_function  var_mean   nvprims.var_mean.main        (x_1, [0, 1, 2])  {'correction': 0}
call_function  getitem    <built-in function getitem>  (var_mean, 0)     {}
call_function  getitem_1  <built-in function getitem>  (var_mean, 1)     {}
call_function  cos        nvprims.cos.default          (getitem,)        {}
call_function  sin        nvprims.sin.default          (getitem_1,)      {}
call_function  add_1      nvprims.add.default          (cos, sin)        {}
output         output     output                       (add_1,)          {}
[{cos, sin, add_1}, {var_mean, add, getitem, getitem_1}]
opcode         name       target                       args                    kwargs
-------------  ---------  ---------------------------  ----------------------  --------
placeholder    x_1        x_1                          ()                      {}
call_module    fused_1    fused_1                      (x_1,)                  {}
call_function  getitem_2  <built-in function getitem>  (fused_1, 0)            {}
call_function  getitem_3  <built-in function getitem>  (fused_1, 1)            {}
call_module    fused_0    fused_0                      (getitem_2, getitem_3)  {}
output         output     output                       (fused_0,)              {}
```

Output with this PR:

```
[{var_mean, add_1, cos, sin, add, getitem_1, getitem}]
opcode       name     target    args        kwargs
-----------  -------  --------  ----------  --------
placeholder  x_1      x_1       ()          {}
call_module  fused_0  fused_0   (x_1,)      {}
output       output   output    (fused_0,)  {}
```

This special handling of getitem unnecessarily splits fusions at functions with tuple outputs.
@IvanYashchuk IvanYashchuk added the module: fx.passes Optimization passes written in FX (don't forget to select a more specific label) label Oct 17, 2022
@pytorch-bot

pytorch-bot bot commented Oct 17, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87073

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9ce171f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: fx release notes category label Oct 17, 2022
@IvanYashchuk IvanYashchuk marked this pull request as draft October 17, 2022 14:04
@ngimel
Collaborator

ngimel commented Oct 17, 2022

cc @SherlockNoMad, why was this special handling needed, and why does it split fusions?

Collaborator

@jjsjann123 jjsjann123 left a comment

LGTM.

# this is a no-op
maybe_merge_partition(self_id, other_id)

# post processing to re-assign "getitem" nodes into upstream partition
Collaborator

nitpick, should we remove this part?

This was supposed to fuse getitem into its producer fusion, before getitem was disabled in the line you commented on above. With that logic removed, I believe this section would remove a getitem node at the beginning of a fusion if it comes from an unfused node.

Collaborator Author

Keeping this code works with the example I added to the PR description, but I was hitting the assert in merge_single_node in one benchmark model. I should extract that failing graph portion for testing.

Collaborator

Got ya. merge_single_node might be asserting too aggressively as well.

assert node not in assignment

This assert should be removed along with your update on the getitem logic.

Collaborator Author

Keeping this code, `python -m pytest test/test_fx_passes.py -k "test_partitioner_fn_" -vvv` fails with the assert:

```
Traceback (most recent call last):
  File "/home/iyashchuk/dev/pytorch/master/test/test_fx_passes.py", line 224, in test_partitioner
    partitions = partitioner.propose_partitions()
  File "/home/iyashchuk/dev/pytorch/master/torch/fx/passes/infra/partitioner.py", line 162, in propose_partitions
    merge_single_node(node, id)
  File "/home/iyashchuk/dev/pytorch/master/torch/fx/passes/infra/partitioner.py", line 108, in merge_single_node
    assert node not in assignment
AssertionError
```

Removing this code, the test still fails (as in CI), but the failure is now expected:

```
____ TestFXGraphPasses.test_partitioner_fn_<function TestPartitionFunctions_forward13 at 0x7f08e8fdfd00>_expected_partition_[['add_2', 'add_1', 'add']] ____
Traceback (most recent call last):
  File "/home/iyashchuk/dev/pytorch/master/test/test_fx_passes.py", line 229, in test_partitioner
    assert set(partitions_name[i]) == set(expected_partition[i])
AssertionError: assert {'add', 'getitem_2', 'getitem', 'getitem_1', 'getitem_3', 'add_2', 'add_1'} == {'add_2', 'add_1', 'add'}
  Extra items in the left set:
  'getitem_2'
  'getitem_3'
  'getitem_1'
  'getitem'
  Full diff:
  - {'add_2', 'add_1', 'add'}
  + {'add', 'getitem_2', 'getitem', 'getitem_1', 'getitem_3', 'add_2', 'add_1'}
```

The failing test was added in https://github.com/pytorch/pytorch/pull/86713 😉

Collaborator

I was suggesting keeping this code and removing the assert.

Collaborator Author

Keeping the code and removing the assert gives the following partitions in that test:

[{getitem, add_1, add_2, getitem_2, getitem_3, getitem_1, add}, {getitem_3, getitem_2, getitem, getitem_1}]

Collaborator

Oops, that's surprising to me. 😆

How can we have getitem spanning across multiple partitions? I think there's just a small bug somewhere. It looks like it's attempting to clear getitem from the original partition... but somehow that wasn't done right. I can take a quick look at this afterwards.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 17, 2022
@IvanYashchuk IvanYashchuk marked this pull request as ready for review October 17, 2022 18:36
```py
if "getitem" in node.name:
    # Check if the node unpacks a tuple from a supported node
    node_to_unpack = node.args[0]
    return self.is_node_supported(submodules, node_to_unpack)
```
Collaborator

I'm slightly leaning towards cleaning up getitem at the beginning of the graph in the partitioner, instead of having more complicated logic in the op support query. For now it works fine, since the partitioner goes from consumer to producer and we don't really fuse anything as we propose partition groups. But that implementation could change 😦

Collaborator Author

It's of course a matter of taste. I think special-casing getitem and processing it differently may cause bugs. There are two situations with getitem: it either unpacks a node that's supported or one that isn't, and we can't unconditionally reject or accept it. Previously we already relied on the partitioner (always rejecting getitem) to do the right thing, and it didn't.
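For illustration, here's a minimal, self-contained sketch of that producer-tied rule; the `Node` class and `SUPPORTED` set are hypothetical stand-ins for `torch.fx.Node` and the backend's `OperatorSupport`, not the actual PyTorch code:

```python
import operator

class Node:
    """Hypothetical stand-in for torch.fx.Node."""
    def __init__(self, name, target, args=()):
        self.name, self.target, self.args = name, target, args

# Stand-in for the backend's operator support table.
SUPPORTED = {"nvprims.var_mean", "nvprims.add"}

def is_node_supported(node):
    # getitem is neither unconditionally accepted nor rejected:
    # it inherits the support status of the node it unpacks.
    if node.target is operator.getitem:
        return is_node_supported(node.args[0])
    return node.target in SUPPORTED

var_mean = Node("var_mean", "nvprims.var_mean")
supported_getitem = Node("getitem", operator.getitem, (var_mean,))
split = Node("split", "aten.split")  # not on the support list
unsupported_getitem = Node("getitem_2", operator.getitem, (split,))

print(is_node_supported(supported_getitem))    # True: its producer is supported
print(is_node_supported(unsupported_getitem))  # False: its producer is not
```

This way the support decision for getitem is always the support decision for its producer, so neither blanket acceptance nor blanket rejection can split a fusion at a tuple unpack.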


# 5 getitem special case
(TestPartitionFunctions.forward13, [["add_2", "add_1", "add"]]),
(TestPartitionFunctions.forward13, [["add_2", "add_1", "add", "getitem", "getitem_1", "getitem_2", "getitem_3"]]),
Contributor

Hi @IvanYashchuk, I think this is an unexpected change.
A getitem node should always be partitioned together with its producer node.
Across the stack, we have an implicit assumption that a module's inputs and outputs must be tensors. This is why we have the special handling logic in the first place.

Collaborator

This doesn't look right even for this PR. getitem is produced by split, which isn't supported by the fusion node?

@SherlockNoMad
Contributor

Given the example, getitem_2, getitem_3 should be partitioned into fused_1.

Contributor

@SherlockNoMad SherlockNoMad left a comment

getitem should always be partitioned together with its producer node.

@jjsjann123
Collaborator

> getitem should always be partitioned together with its producer node.

@SherlockNoMad The problem here is that the original logic special-cased getitem after the fusion partition had been proposed, which resulted in us always segmenting the graph across getitem nodes.
I.e., if you have a var_mean whose output is used by another fusion-supported op, the getitem node that unpacks the tuple output from var_mean stops us from fusing the two nodes.

The changes in this PR are supposed to fix that.

@mikaylagawarecki mikaylagawarecki added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Oct 18, 2022
@IvanYashchuk
Collaborator Author

> getitem should always be partitioned together with its producer node.

I agree with you, but the previous code was doing that too aggressively and incorrectly. getitem should be neither unconditionally accepted nor rejected; it should be tied to its producer when deciding whether it's supported.

Please take a look at the latest changes.

@jjsjann123
Collaborator

> getitem should always be partitioned together with its producer node.
>
> I agree with you, but the previous code was doing that too aggressively and incorrectly. getitem shouldn't be unconditionally accepted nor rejected, it should be tied to its producer when deciding whether it's supported or not.
>
> Please take a look at the latest changes.

On this topic, could you consider changes like this one?
csarofeen@9e9934e

We are merging things as provided by the op support list, and at the end of fusion we clean up getitem nodes and keep them with their producer.

Sorry that my previous refactor leaves the code in an ugly state. 😝
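That commit isn't reproduced here, but the shape of such a post-processing pass can be sketched on plain dicts and sets (node names as strings; `producer_of` is a hypothetical helper mapping each getitem to the node it unpacks, standing in for inspecting `node.args[0]`):

```python
def reassign_getitems(partitions, producer_of):
    """Post-processing sketch: move each getitem node into the partition
    of its producer, so partitions are never split at a bare getitem."""
    # Map every node to its current partition id.
    assignment = {n: pid for pid, nodes in partitions.items() for n in nodes}
    for pid, nodes in partitions.items():
        for n in list(nodes):  # copy, since we mutate the set
            if n.startswith("getitem"):
                target = assignment.get(producer_of[n])
                if target is not None and target != pid:
                    nodes.remove(n)
                    partitions[target].add(n)
                    assignment[n] = target
    return partitions

# A case where the proposed partitions left the getitems away from var_mean:
parts = {0: {"var_mean", "add"}, 1: {"getitem", "getitem_1", "cos", "sin"}}
producer = {"getitem": "var_mean", "getitem_1": "var_mean"}
result = reassign_getitems(parts, producer)
print({pid: sorted(nodes) for pid, nodes in result.items()})
```

With this approach the op-support query stays simple and backend-agnostic, and the getitem bookkeeping lives in one common pass.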

@jjsjann123
Collaborator

> On this topic, could you consider changes like this one? csarofeen@9e9934e

Seems like we are not the sole user of fx partitioner. #87007

If getitem is relied on by multiple backends, maybe it's better to keep the op-support logic for merging getitem simple and have a common post-processing pass to clean it up. I'm shamelessly promoting the patch I have in the commit above again :)

I'm not sure if the discussion here concerns you at all, @wschin: would a getitem node at the beginning of a partition break your parser?

@wschin
Collaborator

wschin commented Oct 18, 2022

> On this topic, could you consider changes like this one? csarofeen@9e9934e
>
> Seems like we are not the sole user of fx partitioner. #87007
>
> if getitem is relied by multiple backends, maybe it's a better that we keep the logic on merging getitem in op support simpler and have a common post processing pass to clean it up. I'm shamelessly promoting the patch I have in the commit above again :)
>
> I'm not sure if the discussion here concerns you at all @wschin, would an getitem node at the beginning of partition break your parser?

Some getitems at the beginning and at the end should be fine. I just hope most of the computation can be partitioned into a single torch.fx.GraphModule. Having non-computation ops before and after the major computation is fine; of course, there should be as few of those non-computation ops as possible. Many thanks!

@IvanYashchuk
Collaborator Author

@SherlockNoMad could you please take another look at the proposed changes?


assignment[node] = id
if id not in partitions_by_id:
if id is None:
Contributor

Do we really have a case where id is None?
Looking at the call site for this function, it doesn't seem to have any None case...

If this cannot be None, let's assert id is not None.

Collaborator

Yes, this happens naturally when getitem is marked as supported by backends.

Later, during the special handling of getitem, where we merge each getitem call into its producer, we can run into cases where the producer is not supported by backends but we accidentally merged getitem into the fusion. We end up pulling these nodes out, and that's when id would be None.

merge_single_node(node, id)

This is checked in test case forward13
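A hypothetical sketch of what that `id is None` path does; the real `merge_single_node` lives inside `propose_partitions` and uses closure state, whereas here the maps are passed explicitly, so this is not the actual PyTorch code:

```python
def merge_single_node(node, id, assignment, partitions_by_id):
    # id=None means "pull this node out of whatever partition it is in",
    # e.g. a getitem that was merged into a fusion whose producer turned
    # out to be unsupported.
    if node in assignment:
        partitions_by_id[assignment[node]].discard(node)
    if id is None:
        assignment.pop(node, None)
    else:
        assignment[node] = id
        partitions_by_id.setdefault(id, set()).add(node)

# A getitem previously assigned to partition 0 gets pulled out:
assignment = {"getitem": 0}
partitions_by_id = {0: {"getitem", "split"}}
merge_single_node("getitem", None, assignment, partitions_by_id)
print(assignment, partitions_by_id)  # {} {0: {'split'}}
```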

Contributor

@SherlockNoMad SherlockNoMad left a comment

LGTM, except for the minor comment.

@jjsjann123
Collaborator

Thanks to @SherlockNoMad for the stamp.

The failure seems to be on the CI nodes and unrelated. Let's rebase to a stable commit and merge it! 🎉 🎉 🎉

@IvanYashchuk
Collaborator Author

@pytorchbot merge -g

@IvanYashchuk IvanYashchuk added the topic: not user facing topic category label Oct 26, 2022
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@IvanYashchuk IvanYashchuk deleted the getitem-partitioner branch October 26, 2022 16:13
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
Pull Request resolved: pytorch#87073
Approved by: https://github.com/jjsjann123, https://github.com/SherlockNoMad
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022

Labels

- ciflow/trunk
- Merged
- module: fx.passes
- open source
- release notes: fx
- topic: not user facing
- triaged
