
Commit 233b949

bdhirsh authored and facebook-github-bot committed
fix channels_last bug in upsample kernels (#53535)
Summary: Pull Request resolved: #53535

During the port to structured kernels for upsample kernels, I missed that a subset of them explicitly pass `memory_format` information from the input to the output tensors.

Note 1: I added the logic into the `meta` function of each op, which feels morally correct since this logic affects the output shape/metadata. One consequence is that all backend implementations will get the logic. I synced with fmassa that this seems reasonable.

Note 2: This logic used to happen in the following operators, which this PR fixes:
- upsample_nearest3d
- upsample_trilinear3d
- upsample_nearest2d
- upsample_bilinear2d

I explicitly didn't patch the other upsample kernels, which look like they never forwarded memory_format information:
- `upsample_bicubic2d` (maybe this should though? `UpSampleBicubic2d.cpp` isn't currently written to do anything different for `channels_last` tensors)
- All of the `upsample_{mode}1d` operators, probably because, afaik, channels_last isn't supported for 3d tensors
- The corresponding backwards operator for every upsample op

Note 3: I'm also wondering why memory_format isn't just directly a part of the `Tensor::options()` method, which would cause all ops to universally forward memory_format information from input to output tensors, rather than just the upsample ops. My guess is:
- BC-breakage: I'm not sure whether this would really *break* people, but it's an API change.
- Performance: `Tensor::options()` is called everywhere, and adding a call to `suggest_memory_format()` would probably noticeably hit microbenchmarks. We could probably deal with that by making `memory_format` a precomputed field on the tensor.

Test Plan: Imported from OSS

Reviewed By: H-Huang

Differential Revision: D26891540

Pulled By: bdhirsh

fbshipit-source-id: b3845f4dd5646b88bf738b9e41fe829be6b0e5cf
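To illustrate the decision being forwarded, here is a torch-free Python sketch of a `suggest_memory_format`-style check on a 4D tensor's strides. The helper names are hypothetical simplifications of the ATen concepts, not the real API:

```python
# Hypothetical, torch-free sketch of the "suggest_memory_format" decision
# that the fix forwards from input to output. Helper names are illustrative.

def is_channels_last_2d(shape, strides):
    """True if a 4D (N, C, H, W) tensor's strides describe NHWC layout."""
    n, c, h, w = shape
    # In channels_last, the channel dim is the fastest-varying (stride 1).
    return list(strides) == [h * w * c, 1, w * c, c]

def suggest_memory_format(shape, strides):
    """Return the layout the output allocation should follow."""
    if len(shape) == 4 and is_channels_last_2d(shape, strides):
        return "channels_last"
    return "contiguous"

# An NHWC (4, 3, 8, 8) batch: channel stride is 1, width stride is C.
print(suggest_memory_format((4, 3, 8, 8), (192, 1, 24, 3)))   # -> channels_last
print(suggest_memory_format((4, 3, 8, 8), (192, 64, 8, 1)))   # -> contiguous
```

Before this fix, the structured-kernel port dropped this suggestion on the floor, so a channels_last input produced a contiguous output.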
1 parent a346521 commit 233b949

5 files changed

Lines changed: 25 additions & 4 deletions


aten/src/ATen/native/UpSampleBilinear2d.cpp

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ TORCH_META_FUNC(upsample_bilinear2d) (
       "Non-empty 4D data tensor expected but got a tensor with sizes ",
       input.sizes());

-  set_output(full_output_size, input.options());
+  set_output(full_output_size, input.options().memory_format(input.suggest_memory_format()));
 }

 TORCH_META_FUNC(upsample_bilinear2d_backward) (

aten/src/ATen/native/UpSampleNearest2d.cpp

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ TORCH_META_FUNC(upsample_nearest2d) (
       "Non-empty 4D data tensor expected but got a tensor with sizes ",
       input.sizes());

-  set_output(full_output_size, input.options());
+  set_output(full_output_size, input.options().memory_format(input.suggest_memory_format()));
 }

 TORCH_META_FUNC(upsample_nearest2d_backward) (

aten/src/ATen/native/UpSampleNearest3d.cpp

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ TORCH_META_FUNC(upsample_nearest3d) (
       "Non-empty 5D data tensor expected but got a tensor with sizes ",
       input.sizes());

-  set_output(full_output_size, input.options());
+  set_output(full_output_size, input.options().memory_format(input.suggest_memory_format()));
 }

 TORCH_META_FUNC(upsample_nearest3d_backward) (

aten/src/ATen/native/UpSampleTrilinear3d.cpp

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ TORCH_META_FUNC(upsample_trilinear3d) (
       "Non-empty 5D data tensor expected but got a tensor with sizes ",
       input.sizes());

-  set_output(full_output_size, input.options());
+  set_output(full_output_size, input.options().memory_format(input.suggest_memory_format()));
 }

 TORCH_META_FUNC(upsample_trilinear3d_backward) (
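The concrete effect of the one-line change repeated in each file above can be sketched in torch-free Python (assumed helper names, not the real ATen API): the meta function now allocates the output with strides matching the input's suggested memory format, computed for the new output shape, instead of default contiguous strides.

```python
# Torch-free sketch of what the diffs change in the output allocation.
# Helper names are illustrative assumptions, not ATen functions.

def contiguous_strides(shape):
    """Default row-major strides (what set_output produced before the fix)."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

def channels_last_strides(shape):
    """NHWC strides (what the output now gets for a channels_last input)."""
    n, c, h, w = shape
    return [h * w * c, 1, w * c, c]

# Upsampling an NHWC (4, 3, 8, 8) input to spatial size (16, 16):
out_shape = (4, 3, 16, 16)
print(contiguous_strides(out_shape))     # before the fix: [768, 256, 16, 1]
print(channels_last_strides(out_shape))  # after the fix:  [768, 1, 48, 3]
```

Note that both layouts cover the same 768 elements per batch item; only the ordering in memory differs, which is exactly the metadata the `meta` function is responsible for.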

test/test_torch.py

Lines changed: 21 additions & 0 deletions
@@ -5548,6 +5548,27 @@ def test_memory_format_preserved_after_permute(self, device):
         y = ndhwc.permute(0, 1, 4, 3, 2).permute(0, 1, 4, 3, 2)
         self.assertTrue(y.is_contiguous(memory_format=torch.channels_last_3d))

+    def test_memory_format_preserved_after_upsample(self, device):
+        x = torch.randn(4, 3, 8, 8, device=device)
+        nhwc = x.contiguous(memory_format=torch.channels_last)
+        y = torch._C._nn.upsample_nearest2d(nhwc, (2, 2))
+        self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))
+
+        x = torch.randn(4, 3, 8, 8, device=device)
+        nhwc = x.contiguous(memory_format=torch.channels_last)
+        y = torch._C._nn.upsample_bilinear2d(nhwc, (2, 2), True)
+        self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))
+
+        x = torch.randn(4, 3, 8, 8, 8, device=device)
+        nhwc = x.contiguous(memory_format=torch.channels_last_3d)
+        y = torch._C._nn.upsample_nearest3d(nhwc, (2, 2, 2))
+        self.assertTrue(y.is_contiguous(memory_format=torch.channels_last_3d))
+
+        x = torch.randn(4, 3, 8, 8, 8, device=device)
+        nhwc = x.contiguous(memory_format=torch.channels_last_3d)
+        y = torch._C._nn.upsample_trilinear3d(nhwc, (2, 2, 2), True)
+        self.assertTrue(y.is_contiguous(memory_format=torch.channels_last_3d))
+
     def test_memory_format_propagation_rules(self, device):

         contiguous = torch.rand(10, 3, 5, 5, device=device)

0 commit comments