Bilinear interpolation behavior inconsistent with TF, CoreML and Caffe #10604

@libfun

Issue description

While comparing and transferring models between Caffe, TF and PyTorch, I found differences in the output of bilinear interpolation across all three. Caffe uses depthwise transposed convolutions instead of a straightforward resize, so its behavior is easy to reimplement in both TF and PyTorch.
However, the outputs of TF and PyTorch differ with align_corners=False, which is the default for both.
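
For readers unfamiliar with the Caffe approach mentioned above, here is a minimal 1D sketch in plain Python. It assumes the commonly used Caffe recipe for an integer factor f (kernel size 2f - f % 2, stride f, padding ceil((f - 1) / 2), bilinear-filled weights); the function names are hypothetical and none of this is taken from the models in this issue. In 2D the same kernel would be applied per channel (depthwise) as an outer product.

```python
import math

def bilinear_kernel_1d(factor):
    """1D bilinear weights for an integer factor (assumed Caffe-style recipe)."""
    size = 2 * factor - factor % 2
    f = math.ceil(size / 2)
    c = (2 * f - 1 - f % 2) / (2 * f)
    return [1 - abs(x / f - c) for x in range(size)]

def upsample_transposed_1d(row, factor):
    """Upsample a 1D signal with a strided transposed convolution."""
    kernel = bilinear_kernel_1d(factor)
    pad = math.ceil((factor - 1) / 2)
    # Scatter every input sample into the output, weighted by the kernel.
    full = [0.0] * ((len(row) - 1) * factor + len(kernel))
    for i, value in enumerate(row):
        for k, weight in enumerate(kernel):
            full[i * factor + k] += value * weight
    # Crop the padding so the output length is len(row) * factor.
    return full[pad:len(full) - pad]

print(bilinear_kernel_1d(2))
# [0.25, 0.75, 0.75, 0.25]
print(upsample_transposed_1d([0.0, 1.0, 0.0, 1.0], 2))
# [0.0, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75]
```

Note that the last output sample is 0.75 rather than 1.0: the transposed convolution implicitly treats samples outside the input as zero, so border values can differ from a resize that clamps at the edge.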

Code example

import cv2
import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn

# Load the image as RGB, resize to 256x256, scale to [0, 1] (NHWC layout).
img = cv2.resize(cv2.imread('./lenna.png')[:, :, ::-1], (256, 256))
nimg = img.reshape(1, 256, 256, 3).astype('float32') / 255.

# TF: upsample 2x with align_corners=True.
tf_img = tf.convert_to_tensor(nimg)
output_size = [512, 512]
output = tf.image.resize_bilinear(tf_img, output_size, align_corners=True)
with tf.Session() as sess:
    values = sess.run([output])
out_tf = values[0].astype('float32')[0]

# PyTorch: same (already normalized) input in NCHW layout, same settings.
nimg = nimg.transpose(0, 3, 1, 2)
out_pt = nn.functional.interpolate(torch.from_numpy(nimg),
                                   scale_factor=2,
                                   mode='bilinear',
                                   align_corners=True)
out_pt = out_pt.numpy().transpose(0, 2, 3, 1)[0]

print(np.max(np.abs(out_pt - out_tf)))
# output 5.6624413e-06

But with align_corners=False:

# Same input as above: RGB, 256x256, scaled to [0, 1] (NHWC layout).
img = cv2.resize(cv2.imread('./lenna.png')[:, :, ::-1], (256, 256))
nimg = img.reshape(1, 256, 256, 3).astype('float32') / 255.

# TF: upsample 2x with align_corners=False.
tf_img = tf.convert_to_tensor(nimg)
output_size = [512, 512]
output = tf.image.resize_bilinear(tf_img, output_size, align_corners=False)
with tf.Session() as sess:
    values = sess.run([output])
out_tf = values[0].astype('float32')[0]

# PyTorch: same (already normalized) input in NCHW layout, align_corners=False.
nimg = nimg.transpose(0, 3, 1, 2)
out_pt = nn.functional.interpolate(torch.from_numpy(nimg),
                                   scale_factor=2,
                                   mode='bilinear',
                                   align_corners=False)
out_pt = out_pt.numpy().transpose(0, 2, 3, 1)[0]

print(np.max(np.abs(out_pt - out_tf)))
# output 0.22745097

Output diff * 10: [image]

The output of CoreML is consistent with TF, so there seems to be a bug in PyTorch's implementation of bilinear interpolation with align_corners=False.

The diff is reproducible on both CPU and CUDA (cuDNN 7.1, CUDA 9.1).
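
One possible explanation, which this issue does not confirm, is that the two libraries map output pixels to input coordinates differently when align_corners=False. The 1D sketch below is plain Python with hypothetical function names. It assumes that TF 1.x `resize_bilinear` uses an "asymmetric" mapping (src = dst * scale) and that PyTorch uses a "half-pixel" mapping (src = (dst + 0.5) * scale - 0.5, clamped at 0). Both assumptions should be checked against the respective library sources.

```python
import math

def src_coord(dst, in_size, out_size, convention):
    """Map an output index to a fractional input coordinate."""
    if convention == "align_corners":
        # Corner samples of input and output coincide.
        return dst * (in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
    scale = in_size / out_size
    if convention == "asymmetric":
        # Assumed TF 1.x behavior with align_corners=False.
        return dst * scale
    if convention == "half_pixel":
        # Assumed PyTorch behavior with align_corners=False.
        return max((dst + 0.5) * scale - 0.5, 0.0)
    raise ValueError(convention)

def resize_linear_1d(row, out_size, convention):
    """Linear interpolation of a 1D signal, clamping at the right edge."""
    in_size = len(row)
    out = []
    for dst in range(out_size):
        x = src_coord(dst, in_size, out_size, convention)
        lo = min(int(math.floor(x)), in_size - 1)
        hi = min(lo + 1, in_size - 1)
        frac = x - lo
        out.append(row[lo] * (1 - frac) + row[hi] * frac)
    return out

row = [0.0, 1.0, 0.0, 1.0]
asym = resize_linear_1d(row, 8, "asymmetric")
half = resize_linear_1d(row, 8, "half_pixel")
print(asym)  # [0.0, 0.5, 1.0, 0.5, 0.0, 0.5, 1.0, 1.0]
print(half)  # [0.0, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 1.0]
print(max(abs(a - b) for a, b in zip(asym, half)))  # 0.25
```

If these assumptions hold, the two conventions differ everywhere except where the signal is flat. That would be consistent with a large diff at align_corners=False and a negligible one at align_corners=True, where both libraries would use the same mapping.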
