Skip to content

Commit eeda31f

Browse files
vfdev-5pytorchmergebot
authored andcommitted
Added antialias flag to interpolate (CUDA, bilinear and bicubic) (#70930)
Summary: Description: - Added antialias flag to interpolate (CUDA) - forward and backward for bicubic mode - added tests Previous PR for CPU bilinear, #65142 Previous PR for CPU bicubic, #68819 ### Benchmarks <details> <summary> Bilinear forward pass, PIL, PTH CPU and PTH CUDA </summary> Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112 ``` Torch version: 1.11.0a0+gitd032369 Torch config: PyTorch built with: - GCC 9.3 - C++ Version: 201402 - OpenMP 201511 (a.k.a. OpenMP 4.5) - CPU capability usage: AVX2 - CUDA Runtime 11.1 - NVCC architecture flags: -gencode;arch=compute_61,code=sm_61 - CuDNN 8.0.5 - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF, Num threads: 8 [----------------------------------- Downsampling (bilinear): torch.Size([1, 3, 906, 438]) -> (320, 196) -----------------------------------] | Reference, PIL 8.4.0, mode: RGB | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 2851.2 | 874.1 | 57.1 channels_last non-contiguous torch.float32 | 2856.1 | 1155.8 | 130.6 Times are in microseconds (us). [----------------------------------- Downsampling (bilinear): torch.Size([1, 3, 906, 438]) -> (460, 220) -----------------------------------] | Reference, PIL 8.4.0, mode: RGB | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 3705.9 | 1005.8 | 66.3 channels_last non-contiguous torch.float32 | 3742.9 | 1332.8 | 143.5 Times are in microseconds (us). [------------------------------------ Downsampling (bilinear): torch.Size([1, 3, 906, 438]) -> (120, 96) -----------------------------------] | Reference, PIL 8.4.0, mode: RGB | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 1768.0 | 725.2 | 77.9 channels_last non-contiguous torch.float32 | 1753.7 | 942.5 | 144.0 Times are in microseconds (us). [----------------------------------- Downsampling (bilinear): torch.Size([1, 3, 906, 438]) -> (1200, 196) ----------------------------------] | Reference, PIL 8.4.0, mode: RGB | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 9522.6 | 2593.8 | 157.8 channels_last non-contiguous torch.float32 | 9513.5 | 3622.7 | 241.5 Times are in microseconds (us). [----------------------------------- Downsampling (bilinear): torch.Size([1, 3, 906, 438]) -> (120, 1200) ----------------------------------] | Reference, PIL 8.4.0, mode: RGB | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 2240.1 | 565.5 | 93.3 channels_last non-contiguous torch.float32 | 2244.2 | 972.7 | 170.8 Times are in microseconds (us). [------------------------- Downsampling (bilinear): torch.Size([1, 1, 906, 438]) -> (320, 196) --------------------------] | Reference, PIL 8.4.0, mode: F | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: --------------------------------------------------------------------------------------------------------------- contiguous torch.float32 | 1441.3 | 386.1 | 22.3 Times are in microseconds (us). [------------------------- Downsampling (bilinear): torch.Size([1, 1, 906, 438]) -> (460, 220) --------------------------] | Reference, PIL 8.4.0, mode: F | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: --------------------------------------------------------------------------------------------------------------- contiguous torch.float32 | 1815.2 | 376.8 | 27.8 Times are in microseconds (us). [-------------------------- Downsampling (bilinear): torch.Size([1, 1, 906, 438]) -> (120, 96) --------------------------] | Reference, PIL 8.4.0, mode: F | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: --------------------------------------------------------------------------------------------------------------- contiguous torch.float32 | 962.3 | 400.0 | 29.4 Times are in microseconds (us). [------------------------- Downsampling (bilinear): torch.Size([1, 1, 906, 438]) -> (1200, 196) -------------------------] | Reference, PIL 8.4.0, mode: F | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: --------------------------------------------------------------------------------------------------------------- contiguous torch.float32 | 4749.7 | 910.1 | 63.7 Times are in microseconds (us). [------------------------- Downsampling (bilinear): torch.Size([1, 1, 906, 438]) -> (120, 1200) -------------------------] | Reference, PIL 8.4.0, mode: F | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: --------------------------------------------------------------------------------------------------------------- contiguous torch.float32 | 1098.1 | 272.0 | 36.4 Times are in microseconds (us). ``` </details> <details> <summary> Bicubic forward pass, PIL, PTH CPU and PTH CUDA </summary> Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112 ``` Torch version: 1.11.0a0+gitd032369 Torch config: PyTorch built with: - GCC 9.3 - C++ Version: 201402 - OpenMP 201511 (a.k.a. OpenMP 4.5) - CPU capability usage: AVX2 - CUDA Runtime 11.1 - NVCC architecture flags: -gencode;arch=compute_61,code=sm_61 - CuDNN 8.0.5 - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF, Num threads: 8 [------------------------------------ Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (320, 196) -----------------------------------] | Reference, PIL 8.4.0, mode: RGB | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 4522.4 | 1406.7 | 170.3 channels_last non-contiguous torch.float32 | 4530.0 | 1435.4 | 242.2 Times are in microseconds (us). [------------------------------------ Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (460, 220) -----------------------------------] | Reference, PIL 8.4.0, mode: RGB | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 5726.4 | 1628.6 | 164.0 channels_last non-contiguous torch.float32 | 5722.6 | 1665.6 | 234.7 Times are in microseconds (us). [------------------------------------ Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (120, 96) ------------------------------------] | Reference, PIL 8.4.0, mode: RGB | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 2909.1 | 1461.5 | 276.9 channels_last non-contiguous torch.float32 | 2892.9 | 1458.7 | 345.1 Times are in microseconds (us). [----------------------------------- Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (1200, 196) -----------------------------------] | Reference, PIL 8.4.0, mode: RGB | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 14699.2 | 4283.9 | 407.1 channels_last non-contiguous torch.float32 | 14711.3 | 4321.1 | 477.0 Times are in microseconds (us). [----------------------------------- Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (120, 1200) -----------------------------------] | Reference, PIL 8.4.0, mode: RGB | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 3467.0 | 980.0 | 339.2 channels_last non-contiguous torch.float32 | 3465.2 | 982.3 | 407.8 Times are in microseconds (us). [-------------------------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (320, 196) --------------------------] | Reference, PIL 8.4.0, mode: F | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: --------------------------------------------------------------------------------------------------------------- contiguous torch.float32 | 2396.7 | 877.8 | 68.1 Times are in microseconds (us). [-------------------------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (460, 220) --------------------------] | Reference, PIL 8.4.0, mode: F | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: --------------------------------------------------------------------------------------------------------------- contiguous torch.float32 | 3068.2 | 777.3 | 64.7 Times are in microseconds (us). [-------------------------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (120, 96) ---------------------------] | Reference, PIL 8.4.0, mode: F | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: --------------------------------------------------------------------------------------------------------------- contiguous torch.float32 | 1540.2 | 829.3 | 100.4 Times are in microseconds (us). [------------------------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (1200, 196) --------------------------] | Reference, PIL 8.4.0, mode: F | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: --------------------------------------------------------------------------------------------------------------- contiguous torch.float32 | 7919.5 | 1467.8 | 151.6 Times are in microseconds (us). [------------------------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (120, 1200) --------------------------] | Reference, PIL 8.4.0, mode: F | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: --------------------------------------------------------------------------------------------------------------- contiguous torch.float32 | 1695.7 | 631.2 | 117.7 Times are in microseconds (us). ``` </details> <details> <summary> Bilinear backward pass, PTH CPU and PTH CUDA </summary> Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112 ``` - Measure only backward op Torch version: 1.11.0a0+gitd032369 Torch config: PyTorch built with: - GCC 9.3 - C++ Version: 201402 - OpenMP 201511 (a.k.a. OpenMP 4.5) - CPU capability usage: AVX2 - CUDA Runtime 11.1 - NVCC architecture flags: -gencode;arch=compute_61,code=sm_61 - CuDNN 8.0.5 - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF, Num threads: 8 [------------- Downsampling backward (bilinear): torch.Size([1, 3, 906, 438]) -> (320, 196) ------------] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 4686.8 | 215.7 channels_last non-contiguous torch.float32 | 5101.1 | 220.5 Times are in microseconds (us). [------------- Downsampling backward (bilinear): torch.Size([1, 3, 906, 438]) -> (460, 220) ------------] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 6011.2 | 204.4 channels_last non-contiguous torch.float32 | 6396.0 | 210.0 Times are in microseconds (us). [------------- Downsampling backward (bilinear): torch.Size([1, 3, 906, 438]) -> (120, 96) -------------] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 2035.6 | 250.2 channels_last non-contiguous torch.float32 | 1589.6 | 252.5 Times are in microseconds (us). [------------ Downsampling backward (bilinear): torch.Size([1, 3, 906, 438]) -> (1200, 196) ------------] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 11392.5 | 256.5 channels_last non-contiguous torch.float32 | 11640.2 | 263.9 Times are in microseconds (us). [------------ Downsampling backward (bilinear): torch.Size([1, 3, 906, 438]) -> (120, 1200) ------------] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 11769.6 | 465.9 channels_last non-contiguous torch.float32 | 12407.0 | 474.4 Times are in microseconds (us). [---- Downsampling backward (bilinear): torch.Size([1, 1, 906, 438]) -> (320, 196) ----] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ----------------------------------------------------------------------------- contiguous torch.float32 | 3931.0 | 133.3 Times are in microseconds (us). [---- Downsampling backward (bilinear): torch.Size([1, 1, 906, 438]) -> (460, 220) ----] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ----------------------------------------------------------------------------- contiguous torch.float32 | 5594.8 | 133.9 Times are in microseconds (us). [---- Downsampling backward (bilinear): torch.Size([1, 1, 906, 438]) -> (120, 96) -----] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ----------------------------------------------------------------------------- contiguous torch.float32 | 1272.6 | 133.0 Times are in microseconds (us). [--- Downsampling backward (bilinear): torch.Size([1, 1, 906, 438]) -> (1200, 196) ----] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ----------------------------------------------------------------------------- contiguous torch.float32 | 10618.1 | 134.0 Times are in microseconds (us). [--- Downsampling backward (bilinear): torch.Size([1, 1, 906, 438]) -> (120, 1200) ----] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ----------------------------------------------------------------------------- contiguous torch.float32 | 11082.2 | 154.6 Times are in microseconds (us). ``` </details> <details> <summary> Bicubic backward pass, PTH CPU and PTH CUDA </summary> Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112 ``` - Measure only backward op Torch version: 1.11.0a0+gitd032369 Torch config: PyTorch built with: - GCC 9.3 - C++ Version: 201402 - OpenMP 201511 (a.k.a. OpenMP 4.5) - CPU capability usage: AVX2 - CUDA Runtime 11.1 - NVCC architecture flags: -gencode;arch=compute_61,code=sm_61 - CuDNN 8.0.5 - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF, Num threads: 8 [------------- Downsampling backward (bicubic): torch.Size([1, 3, 906, 438]) -> (320, 196) -------------] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 6791.2 | 618.9 channels_last non-contiguous torch.float32 | 7125.2 | 622.9 Times are in microseconds (us). [------------- Downsampling backward (bicubic): torch.Size([1, 3, 906, 438]) -> (460, 220) -------------] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 8806.2 | 600.3 channels_last non-contiguous torch.float32 | 9167.6 | 607.5 Times are in microseconds (us). [-------------- Downsampling backward (bicubic): torch.Size([1, 3, 906, 438]) -> (120, 96) -------------] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 3683.6 | 693.8 channels_last non-contiguous torch.float32 | 3617.4 | 695.0 Times are in microseconds (us). [------------- Downsampling backward (bicubic): torch.Size([1, 3, 906, 438]) -> (1200, 196) ------------] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 17548.2 | 779.4 channels_last non-contiguous torch.float32 | 17966.2 | 786.5 Times are in microseconds (us). [------------- Downsampling backward (bicubic): torch.Size([1, 3, 906, 438]) -> (120, 1200) ------------] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ---------------------------------------------------------------------------------------------- channels_first contiguous torch.float32 | 28.4 | 1.6 channels_last non-contiguous torch.float32 | 28.4 | 1.6 Times are in milliseconds (ms). [---- Downsampling backward (bicubic): torch.Size([1, 1, 906, 438]) -> (320, 196) -----] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ----------------------------------------------------------------------------- contiguous torch.float32 | 6266.1 | 208.5 Times are in microseconds (us). [---- Downsampling backward (bicubic): torch.Size([1, 1, 906, 438]) -> (460, 220) -----] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ----------------------------------------------------------------------------- contiguous torch.float32 | 8218.3 | 200.8 Times are in microseconds (us). [----- Downsampling backward (bicubic): torch.Size([1, 1, 906, 438]) -> (120, 96) -----] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ----------------------------------------------------------------------------- contiguous torch.float32 | 3458.9 | 231.9 Times are in microseconds (us). [---- Downsampling backward (bicubic): torch.Size([1, 1, 906, 438]) -> (1200, 196) ----] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ----------------------------------------------------------------------------- contiguous torch.float32 | 15729.3 | 261.6 Times are in microseconds (us). [---- Downsampling backward (bicubic): torch.Size([1, 1, 906, 438]) -> (120, 1200) ----] | 1.11.0a0+gitd032369 cpu | 1.11.0a0+gitd032369 cuda 8 threads: ----------------------------------------------------------------------------- contiguous torch.float32 | 26279.8 | 547.0 Times are in microseconds (us). ``` </details> Code is moved from torchvision: pytorch/vision#4211 and optimized Pull Request resolved: #70930 Reviewed By: zou3519 Differential Revision: D33817902 Pulled By: jbschlosser fbshipit-source-id: d63a620f8972ff36b63841f0bc6c820466f58f69 (cherry picked from commit d358cfd)
1 parent 567c2bb commit eeda31f

4 files changed

Lines changed: 590 additions & 46 deletions

File tree

aten/src/ATen/native/cuda/UpSample.cuh

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ __device__ __forceinline__ static int nearest_neighbor_exact_compute_source_inde
147147
// input_index = round(index_f32)
148148
// Same as Pillow and Scikit-Image/Scipy ndi.zoom
149149
const int src_index =
150-
min(static_cast<int>(floorf((dst_index + 0.5) * scale)), input_size - 1);
150+
min(static_cast<int>(floorf((dst_index + static_cast<float>(0.5)) * scale)), input_size - 1);
151151
return src_index;
152152
}
153153

@@ -171,7 +171,7 @@ __device__ __forceinline__ static int nearest_neighbor_exact_bw_compute_source_i
171171
int output_size) {
172172
// Equivalent to Pillow and Scikit-Image/Scipy ndi.zoom
173173
const int src_index =
174-
min(static_cast<int>(ceilf(dst_index * scale - 0.5)), output_size);
174+
min(static_cast<int>(ceilf(dst_index * scale - static_cast<float>(0.5))), output_size);
175175
return src_index;
176176
}
177177

@@ -255,5 +255,111 @@ __device__ __forceinline__ static accscalar_t cubic_interp1d(
255255
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
256256
}
257257

258+
namespace upsample_antialias {
259+
260+
// taken from
261+
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
262+
// src/libImaging/Resample.c#L20-L29
263+
struct BilinearFilterFunctor {
264+
265+
template <typename accscalar_t>
266+
__device__ accscalar_t operator()(accscalar_t x) const {
267+
if (x < 0) {
268+
x = -x;
269+
}
270+
if (x < 1) {
271+
return 1 - x;
272+
}
273+
return 0;
274+
}
275+
276+
static const int size = 2;
277+
};
278+
279+
// taken from
280+
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
281+
// src/libImaging/Resample.c#L46-L62
282+
struct BicubicFilterFunctor {
283+
284+
template <typename accscalar_t>
285+
__device__ accscalar_t operator()(accscalar_t x) const {
286+
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
287+
const accscalar_t a = -0.5;
288+
if (x < 0) {
289+
x = -x;
290+
}
291+
if (x < 1) {
292+
return ((a + 2) * x - (a + 3)) * x * x + 1;
293+
}
294+
if (x < 2) {
295+
return (((x - 5) * x + 8) * x - 4) * a;
296+
}
297+
return 0;
298+
}
299+
300+
static const int size = 4;
301+
};
302+
303+
template <typename accscalar_t>
304+
__device__ __forceinline__ static void _compute_weights_span(
305+
const int i,
306+
const int input_size,
307+
const accscalar_t scale,
308+
const accscalar_t support,
309+
int& xmin,
310+
int& xsize,
311+
accscalar_t& center) {
312+
center = scale * (i + static_cast<accscalar_t>(0.5));
313+
xmin = max(static_cast<int>(center - support + static_cast<accscalar_t>(0.5)), static_cast<int>(0));
314+
xsize = min(static_cast<int>(center + support + static_cast<accscalar_t>(0.5)), input_size) - xmin;
315+
}
316+
317+
template <typename scalar_t, typename accscalar_t, typename interp_filter_t>
318+
__device__ __forceinline__ static void _compute_weights(
319+
scalar_t* wt_ptr,
320+
const accscalar_t scale,
321+
int interp_size,
322+
const interp_filter_t& interp_filter,
323+
accscalar_t xmin_m_center,
324+
int xsize) {
325+
326+
accscalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
327+
accscalar_t total_w = 0.0;
328+
int j = 0;
329+
for (j = 0; j < xsize; j++) {
330+
accscalar_t w = interp_filter((j + xmin_m_center + static_cast<accscalar_t>(0.5)) * invscale);
331+
wt_ptr[j] = static_cast<scalar_t>(w);
332+
total_w += w;
333+
}
334+
for (j = 0; j < xsize; j++) {
335+
if (total_w != 0.0) {
336+
wt_ptr[j] /= total_w;
337+
}
338+
}
339+
for (; j < interp_size; j++) {
340+
wt_ptr[j] = static_cast<scalar_t>(0.0);
341+
}
342+
}
343+
344+
template <typename scalar_t, typename accscalar_t>
345+
__device__ __forceinline__ static accscalar_t interpolate_aa_single_dim(
346+
const scalar_t* src,
347+
const scalar_t* weights,
348+
int size) {
349+
scalar_t t = static_cast<accscalar_t>(*src);
350+
scalar_t wts = static_cast<accscalar_t>(weights[0]);
351+
accscalar_t output = t * wts;
352+
353+
int j = 1;
354+
for (; j < size; j++) {
355+
wts = static_cast<accscalar_t>(weights[j]);
356+
t = static_cast<accscalar_t>(*(src + j));
357+
output += t * wts;
358+
}
359+
return output;
360+
}
361+
362+
}
363+
258364
} // namespace native
259365
} // namespace at

0 commit comments

Comments
 (0)