
improve noop elimination for view #151095

Closed
BoyuanFeng wants to merge 11 commits into main from bf/noop-elimination

Conversation

@BoyuanFeng
Contributor

@BoyuanFeng BoyuanFeng commented Apr 11, 2025

This PR improves noop elimination.

### View Noop

```python
>>> torch.Size([1, 2, 3]) == [1, 2, 3]
False
>>> torch.Size([1, 2, 3]) == (1, 2, 3)
True
```

So we add `tuple(size)` in `view_noop`.
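The issue and the fix can be sketched in plain Python. `Size` below is a hypothetical stand-in for `torch.Size`, which subclasses `tuple` and therefore compares equal to tuples but never to lists:

```python
class Size(tuple):
    """Stand-in for torch.Size (which subclasses tuple)."""

def view_noop(arg_shape, size):
    # Without tuple(), a list-valued `size` would never compare
    # equal to a Size, and real noops would be missed.
    return arg_shape == tuple(size)

s = Size((1, 2, 3))
print(s == [1, 2, 3])           # False: tuple vs. list
print(s == (1, 2, 3))           # True
print(view_noop(s, [1, 2, 3]))  # True after normalizing with tuple()
```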

Example:

```python
import torch

@torch.compile()
def f(x):
    batch_size = x.shape[0]
    x = x.transpose(1, 2)  # (batch_size, 2, 3)
    x = x.reshape(batch_size, 2, 3)  # noop
    return x

x = torch.randn((2, 3, 2))
f(x)

x = torch.randn((4, 3, 2))
f(x)
```

Before:
![image](https://github.com/user-attachments/assets/be488881-6c99-43a9-b088-fa481f675775)

After:
![image](https://github.com/user-attachments/assets/6d93be3d-128b-44d4-ad6a-d3d18e272329)

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

@BoyuanFeng BoyuanFeng added the `ciflow/trunk`, `topic: not user facing`, and `module: inductor` labels Apr 11, 2025
@pytorch-bot

pytorch-bot bot commented Apr 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151095

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Pending, 14 Unrelated Failures

As of commit 1c70eae with merge base 1f29190:

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@BoyuanFeng BoyuanFeng marked this pull request as draft April 11, 2025 16:37
@pytorch-bot pytorch-bot bot added the `oncall: distributed` label Apr 12, 2025
```diff
 self.assertEqual(
     all_gathers[2].res_node.target,
-    torch.ops.aten.view.dtype,
+    torch.ops._c10d_functional.wait_tensor.default,
 )
```
Contributor Author

`view.dtype` is noop-eliminated. cc @weifengpy

Contributor Author

Discussed offline with @weifengpy; the change makes sense.

@BoyuanFeng BoyuanFeng changed the title improve noop elimination improve noop elimination for view Apr 13, 2025
```python
b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device)
# need 2d attn_mask to generate patterns with view op
m_inp_2d = functools.partial(torch.empty, (2, 4), device=device)
```
Contributor Author

#113004 added `_sfdp_pattern_15`, which takes a 4d `m_inp` (shape: `(2, 1, 1, 4)`) as `attn_mask` and calls `(attn_mask == 0).view((bs, 1, 1, k_len))`. After this PR, that view op is noop-eliminated, so the search pattern no longer contains it. In `_test_sdpa_rewriter_15`, however, the graph still takes a 2d mask, so its view op is not noop-eliminated. This mismatch prevents the pattern from matching.

The change here passes a 2d `attn_mask` instead, so the view op is not noop-eliminated in the search pattern either.

cc @Valentine233 @jgong5 @leslie-fang-intel
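The shape reasoning above can be illustrated with a pure-Python predicate (a hypothetical stand-in for Inductor's view-noop check, not the actual implementation):

```python
def is_view_noop(arg_shape, size):
    # A view is a noop exactly when the target size equals the input shape.
    return arg_shape == tuple(size)

bs, k_len = 2, 4

# 4d mask: (2, 1, 1, 4).view((bs, 1, 1, k_len)) keeps the same shape,
# so the view is eliminated from the search pattern.
print(is_view_noop((2, 1, 1, 4), (bs, 1, 1, k_len)))  # True

# 2d mask: (2, 4).view((bs, 1, 1, k_len)) changes the shape,
# so the view op survives.
print(is_view_noop((2, 4), (bs, 1, 1, k_len)))        # False
```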

Collaborator

LGTM. Thanks!

@BoyuanFeng BoyuanFeng marked this pull request as ready for review April 16, 2025 03:51
@BoyuanFeng BoyuanFeng requested review from eellison and zou3519 April 16, 2025 16:42
Contributor

@eellison eellison left a comment


Looks good

Comment on lines +886 to +888

```python
def view_default_noop(arg, size):
    return arg.shape == tuple(size)
```
Contributor

These should use `statically_known_true` here. We should not add a guard on `arg[0] == size[0]` if `arg[1] != size[1]`.

It would be nice to have a recursive API for this, but just

```python
return len(arg.shape) == len(size) and all(
    statically_known_true(a == b) for a, b in zip(arg.shape, size)
)
```

works for now.

cc @laithsakka
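The suggested per-dimension check can be sketched in plain Python. Here `statically_known_true` is stubbed as the identity on concrete booleans; in PyTorch it lives in `torch.fx.experimental.symbolic_shapes` and returns False rather than guarding when a symbolic expression cannot be decided:

```python
def statically_known_true(expr):
    # Stub: with concrete shapes every comparison is already a plain bool.
    return bool(expr)

def view_default_noop(arg_shape, size):
    # Check rank first, then each dimension pair; with symbolic shapes
    # this avoids guarding on later dims once an earlier dim differs.
    return len(arg_shape) == len(size) and all(
        statically_known_true(a == b) for a, b in zip(arg_shape, size)
    )

print(view_default_noop((2, 2, 3), [2, 2, 3]))  # True: same shape, noop
print(view_default_noop((2, 2, 3), [2, 3, 2]))  # False: shapes differ
print(view_default_noop((2, 2, 3), [2, 6]))     # False: rank differs
```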

Contributor

Apparently `statically_known_true(sym_eq(ls1, ls2))` does this. Thank you @laithsakka.

```diff
 @register_noop_decomp(aten.view.default)
 def view_default_noop(arg, size):
-    return arg.shape == size
+    return statically_known_true(sym_eq(arg.shape, tuple(size)))
```
Contributor Author

`arg.shape` is a tuple, not a list. `sym_eq` returns False if one input is a tuple and the other is a list.
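This mirrors standard Python sequence semantics: a tuple never compares equal to a list even with identical elements, which is why `tuple(size)` normalizes the comparison:

```python
shape = (1, 2, 3)  # arg.shape is tuple-like
size = [1, 2, 3]   # the view's size argument may arrive as a list

print(shape == size)          # False: different sequence types
print(shape == tuple(size))   # True after normalization
```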

Contributor

Did this used to specialize on size?

@BoyuanFeng
Contributor Author

@pytorchbot merge -f "skip unrelated export failure"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This allows currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
Pull Request resolved: pytorch#151095
Approved by: https://github.com/eellison
@github-actions github-actions bot deleted the bf/noop-elimination branch May 28, 2025 02:18

Labels

`ciflow/inductor`, `ciflow/trunk`, `Merged`, `module: inductor`, `oncall: distributed`, `topic: not user facing`


5 participants