Skip to content

[torchax] Fix functional.max_pool#8814

Merged
qihqi merged 1 commit intopytorch:masterfrom
dvhg:maxpool
Mar 12, 2025
Merged

[torchax] Fix functional.max_pool#8814
qihqi merged 1 commit intopytorch:masterfrom
dvhg:maxpool

Conversation

@dvhg
Copy link
Copy Markdown
Contributor

@dvhg dvhg commented Mar 10, 2025

Fix #8086.

Also possibly related to #8241.

To summarize the changes to max pool:

  1. Dilation now correctly accounted for, including in padding calculation following the formula for output shape from here
  2. Default stride is equal to kernel size
  3. Indices are computed per-batch instead of across the entire input
  4. Ties break towards higher indices (i.e. pick the higher index if there is more than 1 max value in a window). This is arbitrary but matches Torch's behavior on CPU.

Copy link
Copy Markdown
Collaborator

@ManfeiBai ManfeiBai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LGTM

@qihqi qihqi merged commit 0aac10e into pytorch:master Mar 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[hard] nn.functional.max_pool2d and nn.functional.max_pool3d

3 participants