|
1 | 1 | import concurrent.futures |
| 2 | +import contextlib |
2 | 3 | import glob |
3 | 4 | import io |
4 | 5 | import os |
@@ -935,17 +936,29 @@ def test_decode_webp(decode_fun, scripted): |
935 | 936 |
|
936 | 937 |
|
937 | 938 | @pytest.mark.parametrize("decode_fun", (decode_webp, decode_image)) |
938 | | -def test_decode_webp_grayscale(decode_fun): |
| 939 | +def test_decode_webp_grayscale(decode_fun, capfd): |
939 | 940 | encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".webp"))) |
940 | 941 |
|
941 | | - # Note that we warn at the C++ layer because dispatching for decode_image |
942 | | - # doesn't happen until we hit C++. The C++ layer does not propagate |
943 | | - # warnings up to Python, so we can't test for them. |
944 | | - img = decode_fun(encoded_bytes, mode=ImageReadMode.GRAY) |
945 | | - |
946 | | - # Note that because we do not support grayscale conversions, we expect |
947 | | - # that the number of color channels is still 3. |
948 | | - assert img.shape == (3, 100, 100) |
| 942 | + # We warn at the C++ layer because for decode_image(), we don't do the image |
| 943 | + # type dispatch until we get to the C++ version of decode_image(). We could |
| 944 | + # warn at the Python layer in decode_webp(), but then users would get a |
| 945 | + # double wanring: one from the Python layer and one from the C++ layer. |
| 946 | + # |
| 947 | + # Because we use the TORCH_WARN_ONCE macro, we need to do this dance to |
| 948 | + # temporarily always warn so we can test. |
| 949 | + @contextlib.contextmanager |
| 950 | + def set_always_warn(): |
| 951 | + torch._C._set_warnAlways(True) |
| 952 | + yield |
| 953 | + torch._C._set_warnAlways(False) |
| 954 | + |
| 955 | + with set_always_warn(): |
| 956 | + img = decode_fun(encoded_bytes, mode=ImageReadMode.GRAY) |
| 957 | + assert "Webp does not support grayscale conversions" in capfd.readouterr().err |
| 958 | + |
| 959 | + # Note that because we do not support grayscale conversions, we expect |
| 960 | + # that the number of color channels is still 3. |
| 961 | + assert img.shape == (3, 100, 100) |
949 | 962 |
|
950 | 963 |
|
951 | 964 | # This test is skipped by default because it requires webp images that we're not |
|
0 commit comments