improve perf on convert_image_dtype and add tests #6795
Conversation
@pytest.mark.parametrize(
    ("info", "args_kwargs"),
    make_info_args_kwargs_params(
        next(info for info in KERNEL_INFOS if info.kernel is F.convert_image_dtype),
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    ),
)
This is rather convoluted to get the sample inputs for a single kernel. I'll refactor later since this is low priority right now.
Another round of benchmarks after the new commits. This round drops the CUDA runs and tests only a single thread. On the flip side, the measurements now run longer to reduce noise.
benchmark script
import pathlib
import pickle
import torch
from torch.utils import benchmark
import functools
from torchvision.prototype.transforms import functional as F
description = "PR" # "main", "PR"
def make_inputs(*, input_dtype, output_dtype, shape=(3, 512, 512)):
if input_dtype.is_floating_point:
image = torch.rand(shape, dtype=input_dtype)
else:
image = torch.randint(0, torch.iinfo(input_dtype).max + 1, shape, dtype=input_dtype)
return image, output_dtype
sub_labels_and_input_fns = [
("float to float", functools.partial(make_inputs, input_dtype=torch.float32, output_dtype=torch.float64)),
("float to int", functools.partial(make_inputs, input_dtype=torch.float32, output_dtype=torch.uint8)),
(" int to float", functools.partial(make_inputs, input_dtype=torch.uint8, output_dtype=torch.float32)),
(" int to int (down)", functools.partial(make_inputs, input_dtype=torch.int32, output_dtype=torch.uint8)),
(" int to int (up)", functools.partial(make_inputs, input_dtype=torch.uint8, output_dtype=torch.int32)),
]
timers = [
benchmark.Timer(
stmt="convert_image_dtype(*inputs)",
globals=dict(
convert_image_dtype=F.convert_image_dtype,
inputs=inputs_fn(),
),
label="convert_image_dtype perf improvements",
sub_label=sub_label,
description=description,
num_threads=1,
)
for sub_label, inputs_fn in sub_labels_and_input_fns
]
measurements = [timer.blocked_autorange(min_run_time=15) for timer in timers]
with open(f"{description}.measurements", "wb") as fh:
pickle.dump(measurements, fh)
measurements = []
for file in pathlib.Path(".").glob("*.measurements"):
with open(file, "rb") as fh:
measurements.extend(pickle.load(fh))
comparison = benchmark.Compare(measurements)
comparison.trim_significant_figures()
comparison.print()

[------ convert_image_dtype perf improvements ------]
                         |  main  |  PR
1 threads: -----------------------------------------
    float to float       |    90  |   83
    float to int         |   380  |  380
    int to float         |   138  |  134
    int to int (down)    |  1100  |  402
    int to int (up)      |   127  |   92

Times are in microseconds (us).
- "float to float", "int to float", "int to int (up)" did not change from the last benchmarks and are still faster
- "int to int (down)" now uses bit shifts and is 3x faster
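The "int to int (down)" speedup comes from replacing an integer division by a power of two with a right shift. A minimal sketch of the idea follows; the helper names `num_value_bits` and `down_convert_via_shift` are illustrative, not the actual kernel, although the TODO comment quoted further below uses the same `num_value_bits` terminology.

```python
import torch


def num_value_bits(dtype: torch.dtype) -> int:
    # Bits that carry magnitude: the sign bit of signed dtypes does not count,
    # e.g. 8 for uint8 but 31 for int32.
    info = torch.iinfo(dtype)
    return info.bits - (0 if info.min == 0 else 1)


def down_convert_via_shift(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # For non-negative integers, dividing by 2**k is the same as shifting
    # right by k bits, and the shift avoids the slow integer division kernel.
    shift = num_value_bits(image.dtype) - num_value_bits(dtype)
    return image.bitwise_right_shift(shift).to(dtype)


image = torch.randint(0, torch.iinfo(torch.int32).max, (3, 4, 4), dtype=torch.int32)
out = down_convert_via_shift(image, torch.uint8)
```

For int32 to uint8 the shift is 31 - 8 = 23 bits, so the int32 maximum 2**31 - 1 maps to 255, exactly as the division by the factor 2**23 would.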
if output_dtype.is_floating_point:
    return value
else:
    return int(decimal.Decimal(value) * torch.iinfo(output_dtype).max)
This gives us arbitrary floating point precision for the intermediate calculations, which is what we want for the reference function. You can see from the xfails I needed to add below that we need this in some cases.
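A small sketch of why `decimal.Decimal` matters here: `float64` cannot represent every `int64`, so a plain float product can round past the target dtype's maximum. The helper name `reference_value` is hypothetical, mirroring the snippet above.

```python
import decimal

import torch


def reference_value(value: float, output_dtype: torch.dtype) -> int:
    # Decimal keeps the intermediate product exact, so truncating to int
    # cannot be off due to float64 rounding.
    return int(decimal.Decimal(value) * torch.iinfo(output_dtype).max)


# Plain float arithmetic rounds 1.0 * (2**63 - 1) up to 2**63, which is one
# past the int64 maximum; the Decimal path preserves the exact maximum.
exact = reference_value(1.0, torch.int64)
rounded = int(1.0 * torch.iinfo(torch.int64).max)
```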
condition=lambda args_kwargs: (
    args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
    and not args_kwargs.kwargs["dtype"].is_floating_point
)
or (
    args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
    and args_kwargs.kwargs["dtype"] == torch.int64
)
or (
    args_kwargs.args[0].dtype in {torch.int32, torch.int64}
    and args_kwargs.kwargs["dtype"] == torch.float16
),
I'm going to open an issue soon detailing what is happening in these cases and how we could mitigate it.
# The bitshift kernel is not vectorized
# https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
# This results in the multiplication actually being faster.
# TODO: If the bitshift kernel is optimized in core, replace the computation below with
# `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
Per comment. The same applies to the bitwise_right_shift kernel in the branch above, but that is still much faster than the division we had before.
Yes, I am. I've added reference tests just to make sure I'm not introducing anything here. If you look at them, I'm actually using the old idiom of multiplying or dividing by the factors there.
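For the up-conversion branch, multiplication by a precomputed power-of-two factor and a left shift are mathematically equivalent; the PR keeps the multiplication only because torch's CPU shift kernel is not vectorized. A quick sketch of that equivalence, assuming the uint8 to int32 case (shift width 31 - 8 = 23):

```python
import torch

# Up-converting uint8 -> int32 scales by 2**(31 - 8) = 2**23. Multiplying by
# that factor and shifting left by 23 bits produce identical tensors; only
# their speed differs.
image = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
factor = 2 ** (31 - 8)

via_mul = image.to(torch.int32) * factor
via_shift = image.to(torch.int32).bitwise_left_shift(31 - 8)
```

Note that 255 * 2**23 = 2139095040 stays below the int32 maximum 2147483647, so neither path can overflow.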
if input_max_value > output_max_value:
    factor = (input_max_value + 1) // (output_max_value + 1)
    return value // factor
else:
    factor = (output_max_value + 1) // (input_max_value + 1)
    return value * factor
Pointer for my comment above.
Summary:
* improve perf on convert_image_dtype and add tests
* add reference tests
* use bitshifts for int to int
* revert bitshifts for int to int upscale
* fix warning ignore

Reviewed By: YosuaMichael

Differential Revision: D40588162

fbshipit-source-id: 4f1c564f94f75ff37979c123a416b043b4c9ec14
The improvements come from using inplace operations where possible.
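A minimal sketch of the inplace pattern, assuming the int to float branch: the tensor freshly allocated by `.to()` is not aliased by anything else, so it can be scaled inplace with `mul_`, saving one intermediate allocation compared to the out-of-place multiply. The exact call sites are an assumption on my part; the scale 1/255 is the standard uint8 normalization factor.

```python
import torch

image = torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8)

# Out-of-place: .to() allocates one tensor, the multiply allocates another.
out_of_place = image.to(torch.float32) * (1.0 / 255)

# Inplace: mul_ reuses the tensor that .to() just allocated.
inplace = image.to(torch.float32).mul_(1.0 / 255)
```

Both variants produce bitwise-identical results; only the allocation count differs.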
The branches that are improved can be seen in the benchmark results above. Of these, int to float is the most interesting for us, since we regularly convert torch.uint8 to torch.float32 before we normalize. With this patch, we get the following diff when profiling with @vfdev-5's benchmark script.

cc @vfdev-5 @datumbox @bjuncek