[CUDA] Large tensor maxpool crash fix#165374

Closed
Isalia20 wants to merge 6 commits into pytorch:main from Isalia20:cuda-maxpool-int64

Conversation

@Isalia20
Collaborator

@Isalia20 Isalia20 commented Oct 13, 2025

@pytorch-bot pytorch-bot bot added the release notes: nn release notes category label Oct 13, 2025
@pytorch-bot

pytorch-bot bot commented Oct 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165374

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 2ec7dd1 with merge base 01738a3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Isalia20 Isalia20 added the topic: bug fixes (topic category) and module: cuda (Related to torch.cuda, and CUDA support in general) labels Oct 13, 2025
@eqy eqy added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 13, 2025

static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) {
return (size + pad < ((kernel - 1) * dilation + 1)) ? 0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1;
}
template <typename T>
Collaborator

nit: index_t

Collaborator Author

updated

int64_t in_stride_n, int64_t in_stride_c,
int64_t in_stride_h, int64_t in_stride_w)
{
const int64_t int_max = std::numeric_limits<int>::max();
Collaborator

constexpr?

Collaborator Author

updated
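For context, the `int_max` constant above exists because the failing tensor in this PR has more elements than a 32-bit signed int can index. A quick plain-Python sanity check (a sketch only; dimensions taken from the regression test, with `INT32_MAX` standing in for `std::numeric_limits<int>::max()`):

```python
# Dimensions from the regression test for issue pytorch/pytorch#165297.
N, C, H, W = 70, 64, 512, 960
numel = N * C * H * W

INT32_MAX = 2**31 - 1  # std::numeric_limits<int>::max()

print(numel)              # 2202009600
print(numel > INT32_MAX)  # True -> 32-bit flat indexing overflows
```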

test/test_nn.py Outdated
# https://github.com/pytorch/pytorch/issues/165297
N, C, H, W = 70, 64, 512, 960 # dims to extend > int32
device = torch.device("cuda")
x_cuda = torch.randn(N, C, H, W, device=device)
Collaborator

Do memory requirements go down if a narrower dtype such as half is used?

Collaborator Author

Yes, the memory requirement decreased, and I updated the test to use float16. Initially I wanted to trigger the same illegal memory access as in the issue with float32, but testing in float16 should be sufficient as well, since we compare against the NCHW-format result for correctness.
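The memory saving from switching the test to float16 can be estimated with simple arithmetic (a rough sketch; only the single CUDA input tensor is counted, not the channels_last copy or the pooling output):

```python
N, C, H, W = 70, 64, 512, 960
numel = N * C * H * W  # 2202009600 elements

GiB = 1024**3
print(f"float32: {numel * 4 / GiB:.2f} GiB")  # 8.20 GiB
print(f"float16: {numel * 2 / GiB:.2f} GiB")  # 4.10 GiB
```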

@pytorch-bot pytorch-bot bot removed the ciflow/trunk (Trigger trunk jobs on your pull request) and ciflow/h100 labels Oct 14, 2025
@eqy eqy added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 14, 2025
@eqy
Collaborator

eqy commented Oct 14, 2025

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Tried to rebase and push PR #165374, but it was already up to date. Try rebasing against main by issuing:
@pytorchbot rebase -b main

@Isalia20
Collaborator Author

@pytorchbot rebase -b main

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased cuda-maxpool-int64 onto refs/remotes/origin/main, please pull locally before adding more changes (for example, via git checkout cuda-maxpool-int64 && git pull --rebase)

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Oct 14, 2025
@Isalia20
Collaborator Author

@pytorchbot merge

@pytorch-bot

pytorch-bot bot commented Oct 15, 2025

Pull workflow has not been scheduled for the PR yet. This could be because the author doesn't have permission to run it, or because skip-checks keywords were added to the PR/commits; aborting merge. Please get/give approval for the workflows and/or remove skip-ci decorators before the next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@Isalia20
Collaborator Author

@eqy Need an approval of the workflows here, and then we can merge, I guess

Comment on lines +41 to +44
template <typename index_t>
__device__ inline index_t dmin(index_t a, index_t b) {
return a <= b ? a : b;
}
Contributor

What's wrong with std::min?

Collaborator Author

updated


template <typename index_t>
static __device__ inline index_t p_start(index_t size, int pad, int kernel, int dilation, int stride) {
const index_t kernel_extent = static_cast<index_t>((kernel - 1) * dilation + 1);
Contributor

Nit

Suggested change
const index_t kernel_extent = static_cast<index_t>((kernel - 1) * dilation + 1);
const auto kernel_extent = static_cast<index_t>((kernel - 1) * dilation + 1);

Collaborator Author

updated
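For readers following the index math, here is a plain-Python transcription of the templated `p_start` above (illustrative only; the CUDA version is authoritative). Given an input position, it returns the first pooled-output index whose window covers that position:

```python
def p_start(size, pad, kernel, dilation, stride):
    # First output index whose pooling window covers input position `size`.
    kernel_extent = (kernel - 1) * dilation + 1
    if size + pad < kernel_extent:
        return 0
    # Floor division matches the C truncation here (operands are non-negative).
    return (size + pad - kernel_extent) // stride + 1

# kernel=3, dilation=1, stride=2, pad=1:
print(p_start(0, 1, 3, 1, 2))  # 0 -> input 0 is covered starting at output 0
print(p_start(4, 1, 3, 1, 2))  # 2 -> input 4 is first covered by output 2
```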

"fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints."):
res = arg_class(*arg_3)

@unittest.skipIf(not TEST_CUDA, "CUDA not available")
Contributor

This makes me sad; we really should use the device argument and the @onlyCUDA decorator

Collaborator Author

updated

Comment on lines +7504 to +7505
device = torch.device("cuda")
x_cuda = torch.randn(N, C, H, W, device=device, dtype=torch.float16)
Contributor

Nit

Suggested change
device = torch.device("cuda")
x_cuda = torch.randn(N, C, H, W, device=device, dtype=torch.float16)
x_cuda = torch.randn(N, C, H, W, device="cuda", dtype=torch.float16)

Collaborator Author

updated
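The int64 stride parameters (`in_stride_n`, `in_stride_c`, `in_stride_h`, `in_stride_w`) are central to the fix: in channels_last (NHWC) layout, the largest flat offset into the test tensor already exceeds `INT_MAX`. A stride-arithmetic sketch in plain Python (no torch required; strides written out by hand for illustration):

```python
N, C, H, W = 70, 64, 512, 960

# channels_last (NHWC) strides, in elements: C is the innermost dimension.
stride_c = 1
stride_w = C
stride_h = W * C
stride_n = H * W * C

# Flat offset of the very last element in the tensor.
max_offset = (N - 1) * stride_n + (H - 1) * stride_h + (W - 1) * stride_w + (C - 1) * stride_c
print(max_offset)              # 2202009599 (== N*C*H*W - 1)
print(max_offset > 2**31 - 1)  # True -> int offsets would overflow
```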

@Isalia20 Isalia20 added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 15, 2025
@Isalia20
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 22, 2025

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: cuda (Related to torch.cuda, and CUDA support in general), open source, release notes: nn (release notes category), topic: bug fixes (topic category)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] MaxPool2d with channels_last + bfloat16 on CUDA produces NaNs for large tensors

5 participants