Improve interpolate() speed for channels_last CPU videos #90302

NicolasHug wants to merge 2 commits into gh/NicolasHug/4/base from
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90302
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit f5aaf3b
This comment was automatically generated by Dr. CI and updates every 15 minutes.
vfdev-5
left a comment
LGTM, thanks @NicolasHug
Just a nit to update the comment.
```cpp
// Similar to _use_vectorized_kernel_cond_2d() but for 3d resampling (e.g. videos)
    const Tensor& output,
    const Tensor& input) {
  return ((input.is_contiguous(at::MemoryFormat::ChannelsLast3d)) && (input.size(1) > 3));
```
If the output size does not matter in this case, maybe we can update the comment above to say so. Otherwise, it seems like the output size condition is missing...
Good point - I can confirm that this doesn't seem to suffer from the weight pre-computation overhead for small output, unlike the 2D case. I updated the comment to reflect that.
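For context on the condition being discussed: `torch.channels_last_3d` means a 5D (N, C, D, H, W) tensor is stored NDHWC in memory, with channels fastest-moving. The expected strides for that layout can be computed by hand; this is an illustrative sketch, not code from the PR, and the helper name is made up:

```python
def channels_last_3d_strides(shape):
    """Element strides of a 5D (N, C, D, H, W) tensor laid out NDHWC,
    i.e. torch.channels_last_3d: C moves fastest, then W, H, D, N."""
    n, c, d, h, w = shape
    return (d * h * w * c, 1, h * w * c, w * c, c)

# The video shape used in the benchmarks below:
print(channels_last_3d_strides((1, 3, 8, 256, 256)))
```

An input is contiguous in `ChannelsLast3d` exactly when its strides match this pattern, which is the first half of the condition in the snippet above.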
This is the exact same PR as #86361, but on Videos (3D) instead of images (2D).

For torchvision training use-cases (num_threads=1), the speed-ups range from 1X to 2X. When num_threads>1 the speed-ups are a lot higher, up to ~30X.

Benchmarks details:

<details>

```
main branch=c6942dbbfbf836450898aa9a0c08aefe437d0765

input shape -> output size                 mode            dtype     num_threads      speed-up  main vs PR

(1, 3, 8, 256, 256) -> (16, 320, 320)      linear          float32   num_threads=1    1.0X   54.7ms vs 55.7ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest         float32   num_threads=1    1.7X   40.5ms vs 24.4ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest         uint8     num_threads=1    1.4X   33.1ms vs 23.7ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest-exact   float32   num_threads=1    2.0X   47.5ms vs 24.3ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest-exact   uint8     num_threads=1    1.7X   39.9ms vs 23.7ms

(1, 3, 8, 256, 256) -> (16, 320, 320)      linear          float32   num_threads=2    2.2X   54.6ms vs 25.1ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest         float32   num_threads=2    2.3X   21.2ms vs 9.3ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest         uint8     num_threads=2    1.4X   16.5ms vs 12.0ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest-exact   float32   num_threads=2    2.6X   24.3ms vs 9.3ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest-exact   uint8     num_threads=2    1.7X   19.9ms vs 12.0ms

(1, 3, 8, 256, 256) -> (16, 320, 320)      linear          float32   num_threads=12   10X    54.3ms vs 5.4ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest         float32   num_threads=12   2.5X   4.1ms vs 1.6ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest         uint8     num_threads=12   1.4X   2.9ms vs 2.1ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest-exact   float32   num_threads=12   1.7X   4.8ms vs 2.8ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest-exact   uint8     num_threads=12   1.7X   3.5ms vs 2.1ms

(1, 3, 8, 256, 256) -> (16, 320, 320)      linear          float32   num_threads=32   20X    54.2ms vs 2.7ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest         float32   num_threads=32   1.5X   2.2ms vs 1.5ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest         uint8     num_threads=32   1.6X   1.3ms vs 0.8ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest-exact   float32   num_threads=32   1.3X   1.8ms vs 1.4ms
(1, 3, 8, 256, 256) -> (16, 320, 320)      nearest-exact   uint8     num_threads=32   1.7X   1.3ms vs 0.8ms

(1, 3, 16, 320, 320) -> (8, 256, 256)      linear          float32   num_threads=1    1.0X   15.4ms vs 16.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest         float32   num_threads=1    2.0X   12.3ms vs 6.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest         uint8     num_threads=1    1.6X   12.0ms vs 7.7ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest-exact   float32   num_threads=1    2.2X   13.1ms vs 6.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest-exact   uint8     num_threads=1    1.7X   12.8ms vs 7.6ms

(1, 3, 16, 320, 320) -> (8, 256, 256)      linear          float32   num_threads=2    1.9X   15.5ms vs 8.2ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest         float32   num_threads=2    2.0X   6.1ms vs 3.1ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest         uint8     num_threads=2    1.5X   6.0ms vs 3.9ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest-exact   float32   num_threads=2    2.2X   6.6ms vs 3.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest-exact   uint8     num_threads=2    1.7X   6.5ms vs 3.9ms

(1, 3, 16, 320, 320) -> (8, 256, 256)      linear          float32   num_threads=12   11X    15.5ms vs 1.4ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest         float32   num_threads=12   2.0X   1.1ms vs 0.5ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest         uint8     num_threads=12   1.6X   1.1ms vs 0.7ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest-exact   float32   num_threads=12   2.1X   1.2ms vs 0.5ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest-exact   uint8     num_threads=12   1.5X   1.1ms vs 0.8ms

(1, 3, 16, 320, 320) -> (8, 256, 256)      linear          float32   num_threads=32   15X    15.4ms vs 1.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest         float32   num_threads=32   1.7X   0.7ms vs 0.4ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest         uint8     num_threads=32   1.3X   0.7ms vs 0.5ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest-exact   float32   num_threads=32   3X     0.7ms vs 0.2ms
(1, 3, 16, 320, 320) -> (8, 256, 256)      nearest-exact   uint8     num_threads=32   2.6X   0.7ms vs 0.3ms

(1, 3, 16, 320, 320) -> (32, 512, 512)     linear          float32   num_threads=1    1.0X   295.6ms vs 304.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest         float32   num_threads=1    1.5X   223.2ms vs 144.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest         uint8     num_threads=1    1.5X   177.7ms vs 121.0ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest-exact   float32   num_threads=1    1.8X   258.6ms vs 145.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest-exact   uint8     num_threads=1    1.6X   203.9ms vs 128.6ms

(1, 3, 16, 320, 320) -> (32, 512, 512)     linear          float32   num_threads=2    1.8X   295.4ms vs 160.4ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest         float32   num_threads=2    1.5X   119.0ms vs 80.2ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest         uint8     num_threads=2    1.4X   84.8ms vs 60.6ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest-exact   float32   num_threads=2    1.7X   136.1ms vs 80.1ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest-exact   uint8     num_threads=2    1.7X   102.2ms vs 60.5ms

(1, 3, 16, 320, 320) -> (32, 512, 512)     linear          float32   num_threads=12   9X     295.3ms vs 32.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest         float32   num_threads=12   1.4X   25.2ms vs 18.7ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest         uint8     num_threads=12   1.4X   16.5ms vs 11.9ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest-exact   float32   num_threads=12   1.5X   28.1ms vs 18.8ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest-exact   uint8     num_threads=12   1.7X   19.4ms vs 11.5ms

(1, 3, 16, 320, 320) -> (32, 512, 512)     linear          float32   num_threads=32   18X    294.7ms vs 16.2ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest         float32   num_threads=32   1.2X   14.4ms vs 12.5ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest         uint8     num_threads=32   1.2X   5.9ms vs 4.8ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest-exact   float32   num_threads=32   1.2X   14.5ms vs 12.5ms
(1, 3, 16, 320, 320) -> (32, 512, 512)     nearest-exact   uint8     num_threads=32   1.4X   6.9ms vs 4.8ms

(1, 3, 32, 512, 512) -> (16, 320, 320)     linear          float32   num_threads=1    0.9X   48.6ms vs 55.1ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest         float32   num_threads=1    2.0X   38.8ms vs 19.2ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest         uint8     num_threads=1    1.6X   37.6ms vs 23.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest-exact   float32   num_threads=1    2.1X   41.2ms vs 19.2ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest-exact   uint8     num_threads=1    1.7X   39.9ms vs 23.8ms

(1, 3, 32, 512, 512) -> (16, 320, 320)     linear          float32   num_threads=2    1.9X   48.8ms vs 25.3ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest         float32   num_threads=2    2.0X   19.2ms vs 9.5ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest         uint8     num_threads=2    1.6X   18.8ms vs 12.0ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest-exact   float32   num_threads=2    2.2X   20.5ms vs 9.5ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest-exact   uint8     num_threads=2    1.7X   20.0ms vs 12.0ms

(1, 3, 32, 512, 512) -> (16, 320, 320)     linear          float32   num_threads=12   11X    48.6ms vs 4.6ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest         float32   num_threads=12   2.0X   3.4ms vs 1.7ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest         uint8     num_threads=12   1.6X   3.3ms vs 2.1ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest-exact   float32   num_threads=12   2.1X   3.6ms vs 1.7ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest-exact   uint8     num_threads=12   1.7X   3.5ms vs 2.1ms

(1, 3, 32, 512, 512) -> (16, 320, 320)     linear          float32   num_threads=32   27X    48.3ms vs 1.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest         float32   num_threads=32   1.1X   2.2ms vs 2.0ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest         uint8     num_threads=32   2.6X   2.1ms vs 0.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest-exact   float32   num_threads=32   2.4X   2.3ms vs 0.9ms
(1, 3, 32, 512, 512) -> (16, 320, 320)     nearest-exact   uint8     num_threads=32   2.6X   2.2ms vs 0.8ms
```

</details>

Code:

<details>

```py
import operator_benchmark as op_bench
import torch

"""Microbenchmarks for interpolate operator."""


class InterpolateBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, input_size, output_size, channels_last=False, mode='linear', dtype=torch.float):
        input_image = torch.randint(0, 256, size=input_size, dtype=dtype, device='cpu',
                                    requires_grad=self.auto_set())
        if channels_last:
            if input_image.ndim == 4:
                input_image = input_image.contiguous(memory_format=torch.channels_last)
            elif input_image.ndim == 5:
                input_image = input_image.contiguous(memory_format=torch.channels_last_3d)
            else:
                raise ValueError(
                    f"Can not set channels_last to the input of {input_image.ndim} dims"
                )

        align_corners = None if "nearest" in mode else False

        if mode == "linear":
            mode = {
                3: 'linear',
                4: 'bilinear',
                5: 'trilinear',
            }[input_image.ndim]

        self.inputs = {
            "input_image": input_image,
            "output_size": output_size,
            "mode": mode,
            "align_corners": align_corners,
        }
        self.set_module_name("interpolate")

    def forward(self, input_image, output_size, mode, align_corners):
        return torch.nn.functional.interpolate(input_image, size=output_size, mode=mode,
                                               align_corners=align_corners)


def make_config():
    sizes = (
        ((16, 320, 320), (8, 256, 256)),
        ((16, 320, 320), (32, 512, 512)),
    )

    attrs = []
    for (DHW1, DHW2) in sizes:
        attrs.append([(1, 3, *DHW1), DHW2])
        attrs.append([(1, 3, *DHW2), DHW1])

    config = op_bench.config_list(
        attr_names=["input_size", "output_size"],
        attrs=attrs,
        cross_product_configs={
            'channels_last': [True],
            'mode': ["linear", "nearest", "nearest-exact"],
            'dtype': [torch.float, torch.uint8]
        },
        tags=["short"],
    )

    # Need to remove instances with both torch.uint8 and linear
    # Note: this is naaaasty
    def get_mode(l):
        for d in l:
            if "mode" in d:
                return d["mode"]

    def get_dtype(l):
        for d in l:
            if "dtype" in d:
                return d["dtype"]

    config = [l for l in config if not (get_mode(l) == "linear" and get_dtype(l) == torch.uint8)]
    return config


config = make_config()
op_bench.generate_pt_test(config, InterpolateBenchmark)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
```

```py
import re
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("f1", nargs="?", default="main")
parser.add_argument("f2", nargs="?", default="new")
args = parser.parse_args()

with open(args.f1) as f:
    main = f.readlines()
with open(args.f2) as f:
    new = f.readlines()

out = []
for main_line, new_line in zip(main, new):
    # num_threads=1  # TODO: remove
    if main_line.startswith("num_threads="):
        num_threads = int(main_line.split("=")[-1])
    if main_line.startswith("# Input"):
        deets = f"{main_line.strip()}, {num_threads=}"
    if main_line.startswith("Forward"):
        main_time = float(main_line.split()[-1])
        new_time = float(new_line.split()[-1])
        ratio = main_time / new_time
        fmt = ".1f" if ratio < 3 else ".0f"
        improv = f"{ratio:{fmt}}X"
        time_fmt = ",.3f" if new_time < 100 else ",.1f"
        deets = deets.strip().replace("# Input: ", "")
        deets = deets.replace(": ", "=")
        deets = deets.replace("input_size=", "")
        deets = deets.replace(", output_size=", " -> ")
        deets = deets.replace("dtype=torch.", "")
        deets = deets.replace("mode=", "")
        deets = deets.replace("channels_last=True, ", "")
        split = deets.split(",")
        size = ','.join(split[:-3])
        mode, dtype, threads = split[-3:]
        deets = f"{size:<30} {mode:<15} {dtype:<10} {threads:<15}"
        l = f"{deets} {improv:<5} {main_time / 1000:{time_fmt}}ms vs {new_time / 1000:{time_fmt}}ms"
        out.append(l)


def key(s):
    # s = ''.join(s.split()[1:])  # remove "N.nX" part
    num_threads = (int(re.findall(r"num_threads=(\d+)", s)[0]),)
    input_shape, output_shape = re.findall(r"\(.*?\)", s)
    input_shape = input_shape[1:-1]  # remove parenthesis
    input_HW = tuple(int(x) for x in input_shape.split(",")[-2:])
    input_C = (-int(input_shape.split(",")[1]),)
    output_HW = tuple(int(x) for x in output_shape[1:-1].split(","))
    is_downsample = (output_HW[0] < input_HW[0],)
    if "linear" in s:
        mode = "linear"
    elif "nearest-exact" in s:
        mode = "nearest-exact"
    else:
        assert "nearest" in s
        mode = "nearest"
    mode = (mode,)
    return is_downsample + input_HW + output_HW + num_threads + input_C + mode


for i, l in enumerate(sorted(out, key=key)):
    if i % 5 == 0:
        print()
    # if i % 10 == 0 and i % 40 != 0:
    #     print()
    # if i % 40 == 0:
    #     print("-" * 100)
    print(l)
```

</details>

cc VitalyFedyunin jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
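As a side note on the two nearest modes that dominate the benchmark table: `nearest` and `nearest-exact` differ only in how an output coordinate is mapped back to a source index. A minimal pure-Python sketch of that mapping along one dimension (the function name is made up for illustration; this is not code from the PR):

```python
def nearest_indices(out_size, in_size, exact=False):
    """Source index chosen for each output index along one dimension.
    Legacy 'nearest' truncates dst * scale; 'nearest-exact' samples at
    pixel centers, which is closer to what PIL and scikit-image do."""
    scale = in_size / out_size
    if exact:
        return [min(int((d + 0.5) * scale), in_size - 1) for d in range(out_size)]
    return [min(int(d * scale), in_size - 1) for d in range(out_size)]

# Upsampling 2 -> 5 shows the off-by-half difference between the two modes:
print(nearest_indices(5, 2))              # nearest
print(nearest_indices(5, 2, exact=True))  # nearest-exact
```

Both modes read the same number of source elements, which is why their timings in the table are close; the speed-ups in this PR come from the channels-last vectorization, not from the index formula.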
@pytorchmergebot merge -g
Merge started
Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.