The torch.bmm function raises a runtime error when both of its arguments are tensors created with expand. The following snippet reproduces the error:
import torch
a = torch.randn(256, 256).cuda()
b = torch.randn(32, 256).cuda()
c = a.unsqueeze(0).expand(32, 256, 256)  # view of a, no copy
d = b.unsqueeze(2).expand(32, 256, 20)  # view of b, no copy
torch.bmm(c, d)  # raises RuntimeError
Error message:
RuntimeError: cublas runtime error : an invalid numeric value was used as an argument at /b/wheel/pytorch-src/torch/lib/THC/THCBlas.cu:344
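One plausible explanation, which the report itself does not confirm, is memory layout: expand returns a view whose expanded dimension has stride 0, while repeat allocates a real copy with ordinary strides. In the failing call, d has stride 0 in its last (column) dimension, and cuBLAS matrix arguments generally need a unit stride in one dimension and a valid leading dimension, so a zero stride there would fit an "invalid value" error. The minimal sketch below only inspects strides, so it runs on CPU; the names d_expand and d_repeat are illustrative.

```python
import torch

# CPU tensors are enough to inspect memory layout; shapes match the report.
a = torch.randn(256, 256)
b = torch.randn(32, 256)

c = a.unsqueeze(0).expand(32, 256, 256)        # view: batch dim has stride 0
d_expand = b.unsqueeze(2).expand(32, 256, 20)  # view: last dim has stride 0
d_repeat = b.unsqueeze(2).repeat(1, 1, 20)     # real copy: ordinary strides

print(c.stride())         # (0, 256, 1)
print(d_expand.stride())  # (256, 1, 0)
print(d_repeat.stride())  # (5120, 20, 1)
print(d_expand.is_contiguous(), d_repeat.is_contiguous())  # False True
```

Both versions of d hold the same values; only the layout differs, and the working version keeps the stride-0 batch dimension in c, which suggests the problem is specific to the zero stride inside each matrix of d.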
I was able to work around the error by replacing expand with repeat for the second argument only; the following code runs fine:
import torch
a = torch.randn(256, 256).cuda()
b = torch.randn(32, 256).cuda()
c = a.unsqueeze(0).expand(32, 256, 256)
d = b.unsqueeze(2).repeat(1, 1, 20)  # repeat copies the data instead of returning a view
torch.bmm(c, d)
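A minimal sketch of two other ways to get the same result, assuming the goal is just to compute this particular product. The first materializes d with .contiguous(), which copies it much like repeat does. The second is specific to this repro and is an assumption about the use case: because every column of d[i] equals b[i], each output matrix is a @ b[i] repeated 20 times, so one plain matrix product followed by an expand of the result avoids both the copy and bmm. The variable names out_copy and out_bcast are illustrative; the sketch runs on CPU, and adding .cuda() as in the report should not change the logic.

```python
import torch

a = torch.randn(256, 256)
b = torch.randn(32, 256)

c = a.unsqueeze(0).expand(32, 256, 256)
d = b.unsqueeze(2).expand(32, 256, 20)

# Option 1: copy the expanded operand into a fresh buffer before bmm.
out_copy = torch.bmm(c, d.contiguous())

# Option 2 (only valid because d repeats a single column per batch):
# row i of b @ a.t() equals a @ b[i], so compute it once and expand the output.
out_bcast = torch.mm(b, a.t()).unsqueeze(2).expand(32, 256, 20)

print(out_copy.shape, out_bcast.shape)
print(torch.allclose(out_copy, out_bcast, atol=1e-3))
```

Option 2 returns a stride-0 view as well, so if the result is fed into another bmm call it may need .contiguous() for the same reason.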
My environment: Python 3.6 (Anaconda), CUDA 8.0, PyTorch 0.1.12_2