cublas runtime error when both bmm's arguments have been expanded #2022

@hezyin

Description

torch.bmm raises a runtime error when both of its arguments are expanded views. The following snippet reproduces the error:

import torch

a = torch.randn(256, 256).cuda()
b = torch.randn(32, 256).cuda()

c = a.unsqueeze(0).expand(32, 256, 256)
d = b.unsqueeze(2).expand(32, 256, 20)
torch.bmm(c, d)

Error message:

RuntimeError: cublas runtime error : an invalid numeric value was used as an argument at /b/wheel/pytorch-src/torch/lib/THC/THCBlas.cu:344
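My guess at the cause (an assumption on my part, not confirmed): expand does not copy data; it returns a view whose stride along the expanded dimension is 0, and the cuBLAS batched-gemm path apparently cannot handle such strides. The zero stride is easy to observe:

```python
import torch

b = torch.randn(32, 256)
d = b.unsqueeze(2).expand(32, 256, 20)

# The expanded (last) dimension has stride 0 -- all 20 "columns"
# alias the same underlying memory.
print(d.stride())  # (256, 1, 0)
```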

I was able to get rid of the error by replacing expand with repeat on the second argument; the following code runs fine:

import torch

a = torch.randn(256, 256).cuda()
b = torch.randn(32, 256).cuda()

c = a.unsqueeze(0).expand(32, 256, 256)
d = b.unsqueeze(2).repeat(1, 1, 20)
torch.bmm(c, d)
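Another workaround that should behave the same way (a sketch, shown on CPU for simplicity; the .cuda() case follows the same pattern): call .contiguous() on the expanded tensors before bmm, which materializes the views into ordinary strided storage.

```python
import torch

a = torch.randn(256, 256)
b = torch.randn(32, 256)

c = a.unsqueeze(0).expand(32, 256, 256)
d = b.unsqueeze(2).expand(32, 256, 20)

# .contiguous() copies the expanded views into real memory,
# so bmm sees normal (non-zero) strides.
out = torch.bmm(c.contiguous(), d.contiguous())
print(out.shape)  # torch.Size([32, 256, 20])
```

The trade-off is memory: .contiguous() allocates the full expanded tensor, just as repeat does.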

My environment: Python 3.6 (Anaconda), CUDA 8.0, PyTorch 0.1.12_2
