Skip to content

Commit 62ab0c3

Browse files
committed
Update
[ghstack-poisoned]
2 parents a9bc045 + 593c31f commit 62ab0c3

5 files changed

Lines changed: 17 additions & 16 deletions

File tree

aten/src/ATen/native/cuda/SortingRadixSelect.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,7 @@ __device__ __forceinline__ void countRadixAggregateCounts(
521521
for (uint32_t i = 0; i < RadixSize; ++i) {
522522
counts[i] = smem[buffer_offset + i];
523523
}
524+
__syncthreads(); // Wait for all threads to finish reading the final counts.
524525
}
525526

526527
// This function counts the distribution of all input values in a

test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward6_cpu

Lines changed: 0 additions & 4 deletions
This file was deleted.

test/inductor/test_torchinductor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10354,7 +10354,7 @@ def fn(a, b, c):
1035410354

1035510355
# From https://github.com/pytorch/pytorch/issues/93384
1035610356
def test_max_pool2d_with_indices_backward6(self):
10357-
# dilation is not 1. Should still generate kernels.
10357+
# dilation is not 1. Should fall back.
1035810358
def fn(a, b, c):
1035910359
return aten.max_pool2d_with_indices_backward(
1036010360
a, b, [3, 2], [2, 1], [1, 1], [1, 2], False, c
@@ -10378,7 +10378,7 @@ def fn(a, b, c):
1037810378
indices,
1037910379
],
1038010380
)
10381-
assertGeneratedKernelCountEqual(self, 1)
10381+
assertGeneratedKernelCountEqual(self, 0)
1038210382

1038310383
def test_issue102546(self):
1038410384
def fn(x):

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ def run(*ex, **kwargs):
219219
"test_max_pool2d_with_indices_backward5_dynamic_shapes": TestFailure(
220220
("cpu", "cuda")
221221
),
222+
"test_max_pool2d_with_indices_backward6_dynamic_shapes": TestFailure(
223+
("cpu", "cuda", "xpu")
224+
),
222225
"test_misaligned_address_issue1_dynamic_shapes": TestFailure(("cpu",)),
223226
"test_mm_views_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
224227
"test_new_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),

torch/_inductor/lowering.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5179,6 +5179,11 @@ def max_pool2d_with_indices_backward(
51795179
is_channels_last = (x_stride is not None and x_stride[1] == 1) or (
51805180
gO_stride is not None and gO_stride[1] == 1
51815181
)
5182+
if any(d != 1 for d in dilation):
5183+
# dilation is not yet implemented (NYI) in this lowering; use the fallback
5184+
return fallback_max_pool2d_with_indices_backward(
5185+
grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
5186+
)
51825187

51835188
*_batch, _height, width = x.get_size()
51845189
*_, pooled_height, pooled_width = grad_output.get_size()
@@ -5187,17 +5192,13 @@ def max_pool2d_with_indices_backward(
51875192
grad_loader = grad_output.make_loader()
51885193
new_size = list(x.get_size())
51895194

5190-
# Effective kernel size accounts for dilation
5191-
effective_kh = (kernel_size[0] - 1) * dilation[0] + 1
5192-
effective_kw = (kernel_size[1] - 1) * dilation[1] + 1
5193-
51945195
h_window_size = max(
5195-
max(FloorDiv(h, stride[0]) - max(0, FloorDiv(h - effective_kh, stride[0])), 1)
5196-
for h in range(effective_kh * 2)
5196+
max(FloorDiv(h, stride[0]) - max(0, FloorDiv(h - kernel_size[0], stride[0])), 1)
5197+
for h in range(kernel_size[0] * 2)
51975198
)
51985199
w_window_size = max(
5199-
max(FloorDiv(w, stride[1]) - max(0, FloorDiv(w - effective_kw, stride[1])), 1)
5200-
for w in range(effective_kw * 2)
5200+
max(FloorDiv(w, stride[1]) - max(0, FloorDiv(w - kernel_size[1], stride[1])), 1)
5201+
for w in range(kernel_size[1] * 2)
52015202
)
52025203

52035204
window_size = h_window_size * w_window_size
@@ -5216,10 +5217,10 @@ def fn(idx):
52165217
h = h + padding[0]
52175218
w = w + padding[1]
52185219
phstart = ops.index_expr(
5219-
FloorDiv(h - effective_kh + stride[0], stride[0]), torch.int32
5220+
FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
52205221
)
52215222
pwstart = ops.index_expr(
5222-
FloorDiv(w - effective_kw + stride[1], stride[1]), torch.int32
5223+
FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
52235224
)
52245225
phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32)
52255226
pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32)

0 commit comments

Comments
 (0)