Skip to content

Commit bcf6488

Browse files
committed
Put warning behind conditional; test for it at Python level
1 parent e51366f commit bcf6488

2 files changed

Lines changed: 29 additions & 14 deletions

File tree

test/test_image.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import concurrent.futures
2+
import contextlib
23
import glob
34
import io
45
import os
@@ -935,17 +936,29 @@ def test_decode_webp(decode_fun, scripted):
935936

936937

937938
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
938-
def test_decode_webp_grayscale(decode_fun):
939+
def test_decode_webp_grayscale(decode_fun, capfd):
939940
encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".webp")))
940941

941-
# Note that we warn at the C++ layer because dispatching for decode_image
942-
# doesn't happen until we hit C++. The C++ layer does not propagate
943-
# warnings up to Python, so we can't test for them.
944-
img = decode_fun(encoded_bytes, mode=ImageReadMode.GRAY)
945-
946-
# Note that because we do not support grayscale conversions, we expect
947-
# that the number of color channels is still 3.
948-
assert img.shape == (3, 100, 100)
942+
# We warn at the C++ layer because for decode_image(), we don't do the image
943+
# type dispatch until we get to the C++ version of decode_image(). We could
944+
# warn at the Python layer in decode_webp(), but then users would get a
945+
# double wanring: one from the Python layer and one from the C++ layer.
946+
#
947+
# Because we use the TORCH_WARN_ONCE macro, we need to do this dance to
948+
# temporarily always warn so we can test.
949+
@contextlib.contextmanager
950+
def set_always_warn():
951+
torch._C._set_warnAlways(True)
952+
yield
953+
torch._C._set_warnAlways(False)
954+
955+
with set_always_warn():
956+
img = decode_fun(encoded_bytes, mode=ImageReadMode.GRAY)
957+
assert "Webp does not support grayscale conversions" in capfd.readouterr().err
958+
959+
# Note that because we do not support grayscale conversions, we expect
960+
# that the number of color channels is still 3.
961+
assert img.shape == (3, 100, 100)
949962

950963

951964
# This test is skipped by default because it requires webp images that we're not

torchvision/csrc/io/image/cpu/decode_webp.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,13 @@ torch::Tensor decode_webp(
3232
res == VP8_STATUS_OK, "WebPGetFeatures failed with error code ", res);
3333
TORCH_CHECK(
3434
!features.has_animation, "Animated webp files are not supported.");
35-
TORCH_WARN_ONCE(
36-
mode == IMAGE_READ_MODE_GRAY ||
37-
mode == IMAGE_READ_MODE_GRAY_ALPHA,
38-
"Webp does not support grayscale conversions. "
39-
"The returned tensor will be in the colorspace of the original image.");
35+
36+
if (mode == IMAGE_READ_MODE_GRAY ||
37+
mode == IMAGE_READ_MODE_GRAY_ALPHA) {
38+
TORCH_WARN_ONCE(
39+
"Webp does not support grayscale conversions. "
40+
"The returned tensor will be in the colorspace of the original image.");
41+
}
4042

4143
auto return_rgb =
4244
should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(

0 commit comments

Comments
 (0)