Skip to content

Improve interpolate() speed for channels_last CPU videos#90302

Closed
NicolasHug wants to merge 2 commits intogh/NicolasHug/4/basefrom
gh/NicolasHug/4/head
Closed

Improve interpolate() speed for channels_last CPU videos#90302
NicolasHug wants to merge 2 commits intogh/NicolasHug/4/basefrom
gh/NicolasHug/4/head

Conversation

@NicolasHug
Copy link
Copy Markdown
Member

@NicolasHug NicolasHug commented Dec 6, 2022

Stack from ghstack (oldest at bottom):

This is the exact same PR as #86361, but on Videos (3D) instead of images (2D).

For torchvision training use-cases (num_threads=1), the speed-ups range in 1X-2X. When num_threads>1 the speed-ups are a lot higher, up to ~30X

Benchmarks details:

Details
main branch=c6942dbbfbf836450898aa9a0c08aefe437d0765
input shape            output size      mode            dtype     num_threads  speed-up  main   PR
(1, 3, 8, 256, 256) -> (16, 320, 320)  linear          float32    num_threads=1   1.0X  54.7ms vs 55.7ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         float32    num_threads=1   1.7X  40.5ms vs 24.4ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         uint8      num_threads=1   1.4X  33.1ms vs 23.7ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   float32    num_threads=1   2.0X  47.5ms vs 24.3ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   uint8      num_threads=1   1.7X  39.9ms vs 23.7ms

(1, 3, 8, 256, 256) -> (16, 320, 320)  linear          float32    num_threads=2   2.2X  54.6ms vs 25.1ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         float32    num_threads=2   2.3X  21.2ms vs 9.3ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         uint8      num_threads=2   1.4X  16.5ms vs 12.0ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   float32    num_threads=2   2.6X  24.3ms vs 9.3ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   uint8      num_threads=2   1.7X  19.9ms vs 12.0ms

(1, 3, 8, 256, 256) -> (16, 320, 320)  linear          float32    num_threads=12  10X   54.3ms vs 5.4ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         float32    num_threads=12  2.5X  4.1ms vs 1.6ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         uint8      num_threads=12  1.4X  2.9ms vs 2.1ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   float32    num_threads=12  1.7X  4.8ms vs 2.8ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   uint8      num_threads=12  1.7X  3.5ms vs 2.1ms

(1, 3, 8, 256, 256) -> (16, 320, 320)  linear          float32    num_threads=32  20X   54.2ms vs 2.7ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         float32    num_threads=32  1.5X  2.2ms vs 1.5ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         uint8      num_threads=32  1.6X  1.3ms vs 0.8ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   float32    num_threads=32  1.3X  1.8ms vs 1.4ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   uint8      num_threads=32  1.7X  1.3ms vs 0.8ms

(1, 3, 16, 320, 320) -> (8, 256, 256)  linear          float32    num_threads=1   1.0X  15.4ms vs 16.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         float32    num_threads=1   2.0X  12.3ms vs 6.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         uint8      num_threads=1   1.6X  12.0ms vs 7.7ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   float32    num_threads=1   2.2X  13.1ms vs 6.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   uint8      num_threads=1   1.7X  12.8ms vs 7.6ms

(1, 3, 16, 320, 320) -> (8, 256, 256)  linear          float32    num_threads=2   1.9X  15.5ms vs 8.2ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         float32    num_threads=2   2.0X  6.1ms vs 3.1ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         uint8      num_threads=2   1.5X  6.0ms vs 3.9ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   float32    num_threads=2   2.2X  6.6ms vs 3.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   uint8      num_threads=2   1.7X  6.5ms vs 3.9ms

(1, 3, 16, 320, 320) -> (8, 256, 256)  linear          float32    num_threads=12  11X   15.5ms vs 1.4ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         float32    num_threads=12  2.0X  1.1ms vs 0.5ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         uint8      num_threads=12  1.6X  1.1ms vs 0.7ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   float32    num_threads=12  2.1X  1.2ms vs 0.5ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   uint8      num_threads=12  1.5X  1.1ms vs 0.8ms

(1, 3, 16, 320, 320) -> (8, 256, 256)  linear          float32    num_threads=32  15X   15.4ms vs 1.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         float32    num_threads=32  1.7X  0.7ms vs 0.4ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         uint8      num_threads=32  1.3X  0.7ms vs 0.5ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   float32    num_threads=32  3X    0.7ms vs 0.2ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   uint8      num_threads=32  2.6X  0.7ms vs 0.3ms

(1, 3, 16, 320, 320) -> (32, 512, 512)  linear          float32    num_threads=1   1.0X  295.6ms vs 304.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         float32    num_threads=1   1.5X  223.2ms vs 144.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         uint8      num_threads=1   1.5X  177.7ms vs 121.0ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   float32    num_threads=1   1.8X  258.6ms vs 145.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   uint8      num_threads=1   1.6X  203.9ms vs 128.6ms

(1, 3, 16, 320, 320) -> (32, 512, 512)  linear          float32    num_threads=2   1.8X  295.4ms vs 160.4ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         float32    num_threads=2   1.5X  119.0ms vs 80.2ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         uint8      num_threads=2   1.4X  84.8ms vs 60.6ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   float32    num_threads=2   1.7X  136.1ms vs 80.1ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   uint8      num_threads=2   1.7X  102.2ms vs 60.5ms

(1, 3, 16, 320, 320) -> (32, 512, 512)  linear          float32    num_threads=12  9X    295.3ms vs 32.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         float32    num_threads=12  1.4X  25.2ms vs 18.7ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         uint8      num_threads=12  1.4X  16.5ms vs 11.9ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   float32    num_threads=12  1.5X  28.1ms vs 18.8ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   uint8      num_threads=12  1.7X  19.4ms vs 11.5ms

(1, 3, 16, 320, 320) -> (32, 512, 512)  linear          float32    num_threads=32  18X   294.7ms vs 16.2ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         float32    num_threads=32  1.2X  14.4ms vs 12.5ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         uint8      num_threads=32  1.2X  5.9ms vs 4.8ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   float32    num_threads=32  1.2X  14.5ms vs 12.5ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   uint8      num_threads=32  1.4X  6.9ms vs 4.8ms

(1, 3, 32, 512, 512) -> (16, 320, 320)  linear          float32    num_threads=1   0.9X  48.6ms vs 55.1ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         float32    num_threads=1   2.0X  38.8ms vs 19.2ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         uint8      num_threads=1   1.6X  37.6ms vs 23.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   float32    num_threads=1   2.1X  41.2ms vs 19.2ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   uint8      num_threads=1   1.7X  39.9ms vs 23.8ms

(1, 3, 32, 512, 512) -> (16, 320, 320)  linear          float32    num_threads=2   1.9X  48.8ms vs 25.3ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         float32    num_threads=2   2.0X  19.2ms vs 9.5ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         uint8      num_threads=2   1.6X  18.8ms vs 12.0ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   float32    num_threads=2   2.2X  20.5ms vs 9.5ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   uint8      num_threads=2   1.7X  20.0ms vs 12.0ms

(1, 3, 32, 512, 512) -> (16, 320, 320)  linear          float32    num_threads=12  11X   48.6ms vs 4.6ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         float32    num_threads=12  2.0X  3.4ms vs 1.7ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         uint8      num_threads=12  1.6X  3.3ms vs 2.1ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   float32    num_threads=12  2.1X  3.6ms vs 1.7ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   uint8      num_threads=12  1.7X  3.5ms vs 2.1ms

(1, 3, 32, 512, 512) -> (16, 320, 320)  linear          float32    num_threads=32  27X   48.3ms vs 1.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         float32    num_threads=32  1.1X  2.2ms vs 2.0ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         uint8      num_threads=32  2.6X  2.1ms vs 0.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   float32    num_threads=32  2.4X  2.3ms vs 0.9ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   uint8      num_threads=32  2.6X  2.2ms vs 0.8ms

Code:

Details
import operator_benchmark as op_bench
import torch

"""Microbenchmarks for interpolate operator."""


class InterpolateBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, input_size, output_size, channels_last=False, mode='linear', dtype=torch.float):

        input_image = torch.randint(0, 256, size=input_size, dtype=dtype, device='cpu',
                                    requires_grad=self.auto_set())
        if channels_last:
            if input_image.ndim == 4:
                input_image = input_image.contiguous(memory_format=torch.channels_last)
            elif input_image.ndim == 5:
                input_image = input_image.contiguous(memory_format=torch.channels_last_3d)
            else:
                raise ValueError(
                    f"Can not set channels_last to the input of {input_image.ndim} dims"
                )


        align_corners = None if "nearest" in mode else False

        if mode == "linear":
            mode = {
                3: 'linear',
                4: 'bilinear',
                5: 'trilinear',
            }[input_image.ndim]

        self.inputs = {
            "input_image": input_image,
            "output_size": output_size,
            "mode": mode,
            "align_corners": align_corners,
        }

        self.set_module_name("interpolate")

    def forward(self, input_image, output_size, mode, align_corners):
        return torch.nn.functional.interpolate(input_image, size=output_size, mode=mode,
                                               align_corners=align_corners)


def make_config():
    sizes = (
        ((16, 320, 320), (8, 256, 256)),
        ((16, 320, 320), (32, 512, 512)),
    )

    attrs = []
    for (DHW1, DHW2) in sizes:
        attrs.append([(1, 3, *DHW1), DHW2])
        attrs.append([(1, 3, *DHW2), DHW1])


    config = op_bench.config_list(
        attr_names=["input_size", "output_size"],
        attrs=attrs,
        cross_product_configs={
            'channels_last': [True],
            'mode': ["linear", "nearest", "nearest-exact"],
            'dtype': [torch.float, torch.uint8]
        },
        tags=["short"],
    )

    # Need to remove instances with both torch.int and linear
    # Note: this is naaaasty
    def get_mode(l):
        for d in l:
            if "mode" in d:
                return d["mode"]
    def get_dtype(l):
        for d in l:
            if "dtype" in d:
                return d["dtype"]
    config = [l for l in config if not(get_mode(l) == "linear" and get_dtype(l) == torch.uint8)]
    return config

config = make_config()
op_bench.generate_pt_test(config, InterpolateBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
import re
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("f3", nargs="?", default="main")
parser.add_argument("f2", nargs="?", default="new")
args = parser.parse_args()

with open(args.f1) as f:
    main = f.readlines()
with open(args.f2) as f:
    new = f.readlines()

out = []

for main_line, new_line in zip(main, new):
    # num_threads=1  # TODO: remove
    if main_line.startswith("num_threads="):
        num_threads = int(main_line.split("=")[-1])
    if main_line.startswith("# Input"):
        deets = f"{main_line.strip()}, {num_threads=}"
    if main_line.startswith("Forward"):
        main_time = float(main_line.split()[-1])
        new_time = float(new_line.split()[-1])
        ratio = main_time / new_time
        fmt = ".1f" if ratio < 3 else ".0f"
        improv = f"{ratio:{fmt}}X"
        time_fmt = ",.3f" if new_time < 100 else ",.1f"
        deets = deets.strip().replace("# Input: ", "")
        deets = deets.replace(": ", "=")
        deets = deets.replace("input_size=", "")
        deets = deets.replace(", output_size=", " -> ")
        deets = deets.replace("dtype=torch.", "")
        deets = deets.replace("mode=", "")
        deets = deets.replace("channels_last=True, ", "")
        split = deets.split(",")
        size = ','.join(split[:-3])
        mode, dtype, threads = split[-3:]
        deets = f"{size:<30} {mode:<15} {dtype:<10} {threads:<15}"

        l = f"{deets}  {improv:<5} {main_time / 1000:{time_fmt}}ms vs {new_time / 1000:{time_fmt}}ms"
        out.append(l)


def key(s):
    # s = ''.join(s.split()[1:]) # remove "N.nX" part
    num_threads = (int(re.findall(r"num_threads=(\d+)", s)[0]),)

    input_shape, output_shape = re.findall("\(.*?\)", s)
    input_shape = input_shape[1:-1]  # remove parenthesis
    input_HW = tuple(int(x) for x in input_shape.split(",")[-2:])
    input_C = (-int(input_shape.split(",")[1]),)

    output_HW = tuple(int(x) for x in output_shape[1:-1].split(","))
    is_downsample = (output_HW[0] < input_HW[0],)
    if "linear" in s:
        mode = "linear"
    elif "nearest-exact" in s:
        mode = "nearest-exact"
    else:
        assert "nearest" in s
        mode = "nearest"
    mode = (mode,)
    return is_downsample + input_HW + output_HW + num_threads + input_C + mode

for i, l in enumerate(sorted(out, key=key)):
    if i % 5 == 0:
        print()
    # if i % 10 == 0 and i % 40 != 0:
    #     print()
    # if i % 40 == 0:
    #     print("-" * 100)
    print(l)

cc @VitalyFedyunin @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Dec 6, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90302

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f5aaf3b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

NicolasHug added a commit that referenced this pull request Dec 6, 2022
@github-actions github-actions bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Dec 6, 2022
@NicolasHug NicolasHug added triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module release notes: nn release notes category topic: performance topic category topic: not user facing topic category and removed module: cpu CPU specific problem (e.g., perf, algorithm) labels Dec 6, 2022
@NicolasHug NicolasHug requested review from fmassa and vfdev-5 and removed request for fmassa December 6, 2022 18:34
Copy link
Copy Markdown
Contributor

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @NicolasHug
Just a nit to update the comment.

// Similar to _use_vectorized_kernel_cond_2d() but for 3d resampling (e.g. videos)
const Tensor& output,
const Tensor& input) {
return ((input.is_contiguous(at::MemoryFormat::ChannelsLast3d)) && (input.size(1) > 3));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If In this case output size does not matter, maybe, we can update above comment saying that. Otherwise, it seems like output size condition is missing...

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point - I can confirm that this doesn't seem to suffer from the weight pre-computation overhead for small output, unlike the 2D case. I updated the comment to reflect that.

Copy link
Copy Markdown
Member

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

This is the exact same PR as #86361, but on Videos (3D) instead of images (2D).

For torchvision training use-cases (num_threads=1), the speed-ups range in 1X-2X.  When num_threads>1 the speed-ups are a lot higher, up to ~30X

Benchmarks details:
<details >

```
main branch=c6942dbbfbf836450898aa9a0c08aefe437d0765
input shape            output size      mode            dtype     num_threads  speed-up  main   PR
(1, 3, 8, 256, 256) -> (16, 320, 320)  linear          float32    num_threads=1   1.0X  54.7ms vs 55.7ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         float32    num_threads=1   1.7X  40.5ms vs 24.4ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         uint8      num_threads=1   1.4X  33.1ms vs 23.7ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   float32    num_threads=1   2.0X  47.5ms vs 24.3ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   uint8      num_threads=1   1.7X  39.9ms vs 23.7ms

(1, 3, 8, 256, 256) -> (16, 320, 320)  linear          float32    num_threads=2   2.2X  54.6ms vs 25.1ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         float32    num_threads=2   2.3X  21.2ms vs 9.3ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         uint8      num_threads=2   1.4X  16.5ms vs 12.0ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   float32    num_threads=2   2.6X  24.3ms vs 9.3ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   uint8      num_threads=2   1.7X  19.9ms vs 12.0ms

(1, 3, 8, 256, 256) -> (16, 320, 320)  linear          float32    num_threads=12  10X   54.3ms vs 5.4ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         float32    num_threads=12  2.5X  4.1ms vs 1.6ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         uint8      num_threads=12  1.4X  2.9ms vs 2.1ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   float32    num_threads=12  1.7X  4.8ms vs 2.8ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   uint8      num_threads=12  1.7X  3.5ms vs 2.1ms

(1, 3, 8, 256, 256) -> (16, 320, 320)  linear          float32    num_threads=32  20X   54.2ms vs 2.7ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         float32    num_threads=32  1.5X  2.2ms vs 1.5ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest         uint8      num_threads=32  1.6X  1.3ms vs 0.8ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   float32    num_threads=32  1.3X  1.8ms vs 1.4ms
(1, 3, 8, 256, 256) -> (16, 320, 320)  nearest-exact   uint8      num_threads=32  1.7X  1.3ms vs 0.8ms

(1, 3, 16, 320, 320) -> (8, 256, 256)  linear          float32    num_threads=1   1.0X  15.4ms vs 16.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         float32    num_threads=1   2.0X  12.3ms vs 6.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         uint8      num_threads=1   1.6X  12.0ms vs 7.7ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   float32    num_threads=1   2.2X  13.1ms vs 6.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   uint8      num_threads=1   1.7X  12.8ms vs 7.6ms

(1, 3, 16, 320, 320) -> (8, 256, 256)  linear          float32    num_threads=2   1.9X  15.5ms vs 8.2ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         float32    num_threads=2   2.0X  6.1ms vs 3.1ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         uint8      num_threads=2   1.5X  6.0ms vs 3.9ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   float32    num_threads=2   2.2X  6.6ms vs 3.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   uint8      num_threads=2   1.7X  6.5ms vs 3.9ms

(1, 3, 16, 320, 320) -> (8, 256, 256)  linear          float32    num_threads=12  11X   15.5ms vs 1.4ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         float32    num_threads=12  2.0X  1.1ms vs 0.5ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         uint8      num_threads=12  1.6X  1.1ms vs 0.7ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   float32    num_threads=12  2.1X  1.2ms vs 0.5ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   uint8      num_threads=12  1.5X  1.1ms vs 0.8ms

(1, 3, 16, 320, 320) -> (8, 256, 256)  linear          float32    num_threads=32  15X   15.4ms vs 1.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         float32    num_threads=32  1.7X  0.7ms vs 0.4ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest         uint8      num_threads=32  1.3X  0.7ms vs 0.5ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   float32    num_threads=32  3X    0.7ms vs 0.2ms
(1, 3, 16, 320, 320) -> (8, 256, 256)  nearest-exact   uint8      num_threads=32  2.6X  0.7ms vs 0.3ms

(1, 3, 16, 320, 320) -> (32, 512, 512)  linear          float32    num_threads=1   1.0X  295.6ms vs 304.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         float32    num_threads=1   1.5X  223.2ms vs 144.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         uint8      num_threads=1   1.5X  177.7ms vs 121.0ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   float32    num_threads=1   1.8X  258.6ms vs 145.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   uint8      num_threads=1   1.6X  203.9ms vs 128.6ms

(1, 3, 16, 320, 320) -> (32, 512, 512)  linear          float32    num_threads=2   1.8X  295.4ms vs 160.4ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         float32    num_threads=2   1.5X  119.0ms vs 80.2ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         uint8      num_threads=2   1.4X  84.8ms vs 60.6ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   float32    num_threads=2   1.7X  136.1ms vs 80.1ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   uint8      num_threads=2   1.7X  102.2ms vs 60.5ms

(1, 3, 16, 320, 320) -> (32, 512, 512)  linear          float32    num_threads=12  9X    295.3ms vs 32.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         float32    num_threads=12  1.4X  25.2ms vs 18.7ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         uint8      num_threads=12  1.4X  16.5ms vs 11.9ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   float32    num_threads=12  1.5X  28.1ms vs 18.8ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   uint8      num_threads=12  1.7X  19.4ms vs 11.5ms

(1, 3, 16, 320, 320) -> (32, 512, 512)  linear          float32    num_threads=32  18X   294.7ms vs 16.2ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         float32    num_threads=32  1.2X  14.4ms vs 12.5ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest         uint8      num_threads=32  1.2X  5.9ms vs 4.8ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   float32    num_threads=32  1.2X  14.5ms vs 12.5ms
(1, 3, 16, 320, 320) -> (32, 512, 512)  nearest-exact   uint8      num_threads=32  1.4X  6.9ms vs 4.8ms

(1, 3, 32, 512, 512) -> (16, 320, 320)  linear          float32    num_threads=1   0.9X  48.6ms vs 55.1ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         float32    num_threads=1   2.0X  38.8ms vs 19.2ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         uint8      num_threads=1   1.6X  37.6ms vs 23.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   float32    num_threads=1   2.1X  41.2ms vs 19.2ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   uint8      num_threads=1   1.7X  39.9ms vs 23.8ms

(1, 3, 32, 512, 512) -> (16, 320, 320)  linear          float32    num_threads=2   1.9X  48.8ms vs 25.3ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         float32    num_threads=2   2.0X  19.2ms vs 9.5ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         uint8      num_threads=2   1.6X  18.8ms vs 12.0ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   float32    num_threads=2   2.2X  20.5ms vs 9.5ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   uint8      num_threads=2   1.7X  20.0ms vs 12.0ms

(1, 3, 32, 512, 512) -> (16, 320, 320)  linear          float32    num_threads=12  11X   48.6ms vs 4.6ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         float32    num_threads=12  2.0X  3.4ms vs 1.7ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         uint8      num_threads=12  1.6X  3.3ms vs 2.1ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   float32    num_threads=12  2.1X  3.6ms vs 1.7ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   uint8      num_threads=12  1.7X  3.5ms vs 2.1ms

(1, 3, 32, 512, 512) -> (16, 320, 320)  linear          float32    num_threads=32  27X   48.3ms vs 1.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         float32    num_threads=32  1.1X  2.2ms vs 2.0ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest         uint8      num_threads=32  2.6X  2.1ms vs 0.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   float32    num_threads=32  2.4X  2.3ms vs 0.9ms
(1, 3, 32, 512, 512) -> (16, 320, 320)  nearest-exact   uint8      num_threads=32  2.6X  2.2ms vs 0.8ms

```


</details>

Code:

<details>


```py
import operator_benchmark as op_bench
import torch

"""Microbenchmarks for interpolate operator."""


class InterpolateBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, input_size, output_size, channels_last=False, mode='linear', dtype=torch.float):

        input_image = torch.randint(0, 256, size=input_size, dtype=dtype, device='cpu',
                                    requires_grad=self.auto_set())
        if channels_last:
            if input_image.ndim == 4:
                input_image = input_image.contiguous(memory_format=torch.channels_last)
            elif input_image.ndim == 5:
                input_image = input_image.contiguous(memory_format=torch.channels_last_3d)
            else:
                raise ValueError(
                    f"Can not set channels_last to the input of {input_image.ndim} dims"
                )


        align_corners = None if "nearest" in mode else False

        if mode == "linear":
            mode = {
                3: 'linear',
                4: 'bilinear',
                5: 'trilinear',
            }[input_image.ndim]

        self.inputs = {
            "input_image": input_image,
            "output_size": output_size,
            "mode": mode,
            "align_corners": align_corners,
        }

        self.set_module_name("interpolate")

    def forward(self, input_image, output_size, mode, align_corners):
        return torch.nn.functional.interpolate(input_image, size=output_size, mode=mode,
                                               align_corners=align_corners)


def make_config():
    sizes = (
        ((16, 320, 320), (8, 256, 256)),
        ((16, 320, 320), (32, 512, 512)),
    )

    attrs = []
    for (DHW1, DHW2) in sizes:
        attrs.append([(1, 3, *DHW1), DHW2])
        attrs.append([(1, 3, *DHW2), DHW1])


    config = op_bench.config_list(
        attr_names=["input_size", "output_size"],
        attrs=attrs,
        cross_product_configs={
            'channels_last': [True],
            'mode': ["linear", "nearest", "nearest-exact"],
            'dtype': [torch.float, torch.uint8]
        },
        tags=["short"],
    )

    # Need to remove instances with both torch.int and linear
    # Note: this is naaaasty
    def get_mode(l):
        for d in l:
            if "mode" in d:
                return d["mode"]
    def get_dtype(l):
        for d in l:
            if "dtype" in d:
                return d["dtype"]
    config = [l for l in config if not(get_mode(l) == "linear" and get_dtype(l) == torch.uint8)]
    return config

config = make_config()
op_bench.generate_pt_test(config, InterpolateBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
```

```py
import re
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("f3", nargs="?", default="main")
parser.add_argument("f2", nargs="?", default="new")
args = parser.parse_args()

with open(args.f1) as f:
    main = f.readlines()
with open(args.f2) as f:
    new = f.readlines()

out = []

for main_line, new_line in zip(main, new):
    # num_threads=1  # TODO: remove
    if main_line.startswith("num_threads="):
        num_threads = int(main_line.split("=")[-1])
    if main_line.startswith("# Input"):
        deets = f"{main_line.strip()}, {num_threads=}"
    if main_line.startswith("Forward"):
        main_time = float(main_line.split()[-1])
        new_time = float(new_line.split()[-1])
        ratio = main_time / new_time
        fmt = ".1f" if ratio < 3 else ".0f"
        improv = f"{ratio:{fmt}}X"
        time_fmt = ",.3f" if new_time < 100 else ",.1f"
        deets = deets.strip().replace("# Input: ", "")
        deets = deets.replace(": ", "=")
        deets = deets.replace("input_size=", "")
        deets = deets.replace(", output_size=", " -> ")
        deets = deets.replace("dtype=torch.", "")
        deets = deets.replace("mode=", "")
        deets = deets.replace("channels_last=True, ", "")
        split = deets.split(",")
        size = ','.join(split[:-3])
        mode, dtype, threads = split[-3:]
        deets = f"{size:<30} {mode:<15} {dtype:<10} {threads:<15}"

        l = f"{deets}  {improv:<5} {main_time / 1000:{time_fmt}}ms vs {new_time / 1000:{time_fmt}}ms"
        out.append(l)


def key(s):
    # s = ''.join(s.split()[1:]) # remove "N.nX" part
    num_threads = (int(re.findall(r"num_threads=(\d+)", s)[0]),)

    input_shape, output_shape = re.findall("\(.*?\)", s)
    input_shape = input_shape[1:-1]  # remove parenthesis
    input_HW = tuple(int(x) for x in input_shape.split(",")[-2:])
    input_C = (-int(input_shape.split(",")[1]),)

    output_HW = tuple(int(x) for x in output_shape[1:-1].split(","))
    is_downsample = (output_HW[0] < input_HW[0],)
    if "linear" in s:
        mode = "linear"
    elif "nearest-exact" in s:
        mode = "nearest-exact"
    else:
        assert "nearest" in s
        mode = "nearest"
    mode = (mode,)
    return is_downsample + input_HW + output_HW + num_threads + input_C + mode

for i, l in enumerate(sorted(out, key=key)):
    if i % 5 == 0:
        print()
    # if i % 10 == 0 and i % 40 != 0:
    #     print()
    # if i % 40 == 0:
    #     print("-" * 100)
    print(l)
```


</details >

cc VitalyFedyunin jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
NicolasHug added a commit that referenced this pull request Dec 14, 2022
@github-actions github-actions bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Dec 14, 2022
@NicolasHug
Copy link
Copy Markdown
Member Author

@pytorchmergebot merge -g

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 14, 2022
@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/NicolasHug/4/head branch June 8, 2023 14:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged module: cpu CPU specific problem (e.g., perf, algorithm) release notes: nn release notes category topic: not user facing topic category topic: performance topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants