
Avoid fallback for avg_pool #6409

Merged
qihqi merged 2 commits into master from qihqi/core_aten_ops
Jan 31, 2024

Conversation

@qihqi
Collaborator

@qihqi qihqi commented Jan 29, 2024

By supporting divisor overrides and the ceil_mode + count_include_pad combination properly.

When count_include_pad is True, the padded elements are also counted in the denominator. But if ceil_mode is also true, then rounding the output size up can introduce extra padding, and **that extra padding is NOT counted in the denominator**. Therefore, when ceil_mode is true, we need to pad manually so we can distinguish padding that should count toward the denominator from padding that should not (i.e. padding introduced by ceil_mode).
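
A minimal PyTorch sketch of that behavior (the shapes here are arbitrary, chosen so that ceil_mode actually adds an extra window):

```python
import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 6, 6)

# count_include_pad=True: the explicit zero padding is counted in the
# denominator, so a full 3x3 window always divides by 9.
floor_out = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1,
                         ceil_mode=False, count_include_pad=True)

# ceil_mode=True rounds the output size up, adding one more window per
# spatial dim. Positions that exist only because of that rounding are
# NOT counted in the denominator, unlike the explicit padding above.
ceil_out = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1,
                        ceil_mode=True, count_include_pad=True)

print(floor_out.shape)  # torch.Size([1, 1, 3, 3])
print(ceil_out.shape)   # torch.Size([1, 1, 4, 4])
```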

@qihqi qihqi requested a review from wonjoo-wj January 29, 2024 22:08
@qihqi qihqi force-pushed the qihqi/core_aten_ops branch 3 times, most recently from 49fb09a to 791fa27 Compare January 30, 2024 00:12
By supporting divisor overrides properly.
@qihqi qihqi force-pushed the qihqi/core_aten_ops branch from 791fa27 to 1edf387 Compare January 30, 2024 00:14
@wonjoo-wj
Collaborator

Changes LGTM, but seems like CI failed:

======================================================================
FAIL: test_aten_avg_pool2d_3 (__main__.AtenOpTest) [torch_xla_diff:0.001]
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/pytorch/xla/test/test_core_aten_ops.py", line 72, in run_export_and_compare
    diff_output(
  File "/tmp/pytorch/xla/test/test_core_aten_ops.py", line 33, in diff_output
    testcase.assertTrue(
AssertionError: False is not true

Maybe just need to adjust rtol/atol?
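
A hypothetical sketch of what that could look like (the args/kwargs below are illustrative, not the actual test inputs, and this assumes `run_export_and_compare` accepts `atol`/`rtol` keyword arguments):

```python
# Hypothetical: loosen the comparison tolerance for this test only.
def test_aten_avg_pool2d_3(self):
    args = (torch.randn(1, 3, 6, 6), [2, 2])  # illustrative inputs
    kwargs = dict(ceil_mode=True, count_include_pad=True)
    run_export_and_compare(self, torch.ops.aten.avg_pool2d, args, kwargs,
                           atol=0.01, rtol=0.01)
```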

@qihqi qihqi force-pushed the qihqi/core_aten_ops branch 3 times, most recently from 39e4282 to 3a15b4a Compare January 31, 2024 05:10
@qihqi qihqi force-pushed the qihqi/core_aten_ops branch from 3a15b4a to 2e1d1b4 Compare January 31, 2024 17:13
@qihqi qihqi merged commit a0bae82 into master Jan 31, 2024
@qihqi qihqi deleted the qihqi/core_aten_ops branch January 31, 2024 22:23
@wonjoo-wj
Collaborator

wonjoo-wj commented Jan 31, 2024

@qihqi, thanks! Could you link the github issues that this PR fixes, if there are any?

yeounoh added commits that referenced this pull request Feb 1, 2024
@qihqi
Collaborator Author

qihqi commented Feb 2, 2024

It looks like the performance of some vision models improved from Jan 31 to Feb 1; this PR might have helped (as it avoids the CPU fallback):

[image: vision model performance dashboard, Jan 31 vs Feb 1]

amithrm pushed a commit to amithrm/xla that referenced this pull request Mar 1, 2024
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
LPanosTT added a commit to tenstorrent/tt-mlir that referenced this pull request Jun 18, 2025
…tir.full` (#3826)

**The changes/additions in this PR pertain to the effort to use xla for compiling PyTorch models in tt-torch**

### Ticket
No tt-mlir issue. There is a metal issue I've filed, but this is not solely a workaround for an op which will not run in tt-metal:
- Metal issue: tenstorrent/tt-metal#23617
- Another metal issue: tenstorrent/tt-metal#23581

### Problem description
torch-xla decomposes avg_pool2d into a sum-pool on the input tensor divided by the size of the window. **However**, the denominator is not a constant. Instead, it is computed as the result of another sum-pool, applied to a constant tensor containing only `1.0`. That sum-pool produces a tensor with the same spatial dimensions as the pooled activation, and an element-wise division is performed to get the same end result:

```
%60 = "ttir.full" (full tensor of 1.0 with shape 56x56)
...
%704 = "ttir.pooling"(%702, %703) <{base_dilations = array<i64: 1, 1, 1, 1>, operandSegmentSizes = array<i32: 1, 1>, padding = array<i64: 0, 0, 0, 0, 0, 0, 0, 0>, pooling_method = #ttir<pooling_method Sum>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 2, 2>}> : (tensor<1x128x56x56xbf16>, tensor<1x128x28x28xbf16>) -> tensor<1x128x28x28xbf16>
%705 = ttir.empty() : tensor<28x28xbf16>
%706 = "ttir.pooling"(%60, %705) <{base_dilations = array<i64: 1, 1>, operandSegmentSizes = array<i32: 1, 1>, padding = array<i64: 0, 0, 0, 0>, pooling_method = #ttir<pooling_method Sum>, window_dilations = array<i64: 1, 1>, window_dimensions = array<i64: 2, 2>, window_strides = array<i64: 2, 2>}> : (tensor<56x56xbf16>, tensor<28x28xbf16>) -> tensor<28x28xbf16>
... reshape the denominator (to unsqueeze)
... broadcast the denominator (so the channel dim is identical)
%712 = "ttir.div" (%704, %706) -> ...
```

The reason torch-xla does this is seemingly to handle an edge case where the kwargs `count_include_pad = True` and `ceil_mode = True` are both set. However, it is actually applied across the board.

- torch-xla PR where this was made: pytorch/xla#6409
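
For illustration, the same decomposition can be reproduced in plain PyTorch (a sketch using `divisor_override=1` to emulate a sum-pool; shapes follow the IR above):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 128, 56, 56)

# Numerator: sum-pool over the activation (divisor_override=1 turns the
# average into a plain window sum).
num = F.avg_pool2d(x, kernel_size=2, stride=2, divisor_override=1)

# Denominator: sum-pool over a tensor of ones with the input's spatial
# shape. Each output element is the number of input positions in that
# window; with zero padding that is simply the constant 2 * 2 = 4.
ones = torch.ones(1, 1, 56, 56)
den = F.avg_pool2d(ones, kernel_size=2, stride=2, divisor_override=1)

# Element-wise division (broadcast over batch/channel) recovers avg_pool2d.
out = num / den
expected = F.avg_pool2d(x, kernel_size=2, stride=2)
assert torch.allclose(out, expected)
```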

**Furthermore**, this sum-pool itself is not even a valid pooling operation in PyTorch or ttnn, as the input tensor is 2D. PyTorch expects at least a channels dim, and ttnn expects both a channel dim and a batch dim. So if we were to instead rely on ttnn to compute this properly and consteval the result, the lowering pattern for this `ttir.pooling` op in `TTIRToTTIRDecomposition` would require reshapes to be placed on the input and output, which isn't necessarily a blocker. However, a future fusing pattern to convert `div(sum_pool, const)` to `avg_pool` would be needlessly more complex if it also needed to match the case where the denominator is the result of another `sum_pool` of a constant with reshapes on its input and output.
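
A quick check of the rank restriction mentioned above (PyTorch's avg_pool2d expects a 3D or 4D tensor, so a bare 2D ones tensor like the one in the decomposition is rejected):

```python
import torch
import torch.nn.functional as F

ones_2d = torch.ones(56, 56)
try:
    F.avg_pool2d(ones_2d, kernel_size=2, stride=2)
except RuntimeError as err:
    # PyTorch requires at least a channel dim: (C, H, W) or (N, C, H, W).
    print("2D input rejected:", err)
```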

### What's changed
Added a TTIRToTTIRDecomposition pattern for `ttir.pooling` that replaces the operation's result with a `ttir.full` containing the correct values. No computation is required, since the result of such a pattern is straightforward: summing a tensor of ones over each window just yields the number of input positions covered by that window.

### Checklist
- [X] New/Existing tests provide coverage for changes