@@ -5179,6 +5179,11 @@ def max_pool2d_with_indices_backward(
51795179 is_channels_last = (x_stride is not None and x_stride [1 ] == 1 ) or (
51805180 gO_stride is not None and gO_stride [1 ] == 1
51815181 )
5182+ if any (d != 1 for d in dilation ):
5183+ # dilation != 1 is not yet implemented (NYI) in this lowering; defer to the fallback implementation
5184+ return fallback_max_pool2d_with_indices_backward (
5185+ grad_output , x , kernel_size , stride , padding , dilation , ceil_mode , indices
5186+ )
51825187
51835188 * _batch , _height , width = x .get_size ()
51845189 * _ , pooled_height , pooled_width = grad_output .get_size ()
@@ -5187,17 +5192,13 @@ def max_pool2d_with_indices_backward(
51875192 grad_loader = grad_output .make_loader ()
51885193 new_size = list (x .get_size ())
51895194
5190- # Effective kernel size accounts for dilation
5191- effective_kh = (kernel_size [0 ] - 1 ) * dilation [0 ] + 1
5192- effective_kw = (kernel_size [1 ] - 1 ) * dilation [1 ] + 1
5193-
51945195 h_window_size = max (
5195- max (FloorDiv (h , stride [0 ]) - max (0 , FloorDiv (h - effective_kh , stride [0 ])), 1 )
5196- for h in range (effective_kh * 2 )
5196+ max (FloorDiv (h , stride [0 ]) - max (0 , FloorDiv (h - kernel_size [ 0 ] , stride [0 ])), 1 )
5197+ for h in range (kernel_size [ 0 ] * 2 )
51975198 )
51985199 w_window_size = max (
5199- max (FloorDiv (w , stride [1 ]) - max (0 , FloorDiv (w - effective_kw , stride [1 ])), 1 )
5200- for w in range (effective_kw * 2 )
5200+ max (FloorDiv (w , stride [1 ]) - max (0 , FloorDiv (w - kernel_size [ 1 ] , stride [1 ])), 1 )
5201+ for w in range (kernel_size [ 1 ] * 2 )
52015202 )
52025203
52035204 window_size = h_window_size * w_window_size
@@ -5216,10 +5217,10 @@ def fn(idx):
52165217 h = h + padding [0 ]
52175218 w = w + padding [1 ]
52185219 phstart = ops .index_expr (
5219- FloorDiv (h - effective_kh + stride [0 ], stride [0 ]), torch .int32
5220+ FloorDiv (h - kernel_size [ 0 ] + stride [0 ], stride [0 ]), torch .int32
52205221 )
52215222 pwstart = ops .index_expr (
5222- FloorDiv (w - effective_kw + stride [1 ], stride [1 ]), torch .int32
5223+ FloorDiv (w - kernel_size [ 1 ] + stride [1 ], stride [1 ]), torch .int32
52235224 )
52245225 phend = ops .index_expr (FloorDiv (h , stride [0 ]) + 1 , torch .int32 )
52255226 pwend = ops .index_expr (FloorDiv (w , stride [1 ]) + 1 , torch .int32 )
0 commit comments