
torch.einsum 400x slower than numpy.einsum on a simple contraction #10661

@fritzo


Issue description

torch.einsum is over 400x slower than numpy.einsum on a simple contraction: about 4 s versus 9 ms for 1000 calls in the example below. This is making some Pyro models very slow.

Code example

import torch, numpy, timeit

x = torch.randn((2, 2000))
y = torch.randn((2, 2, 2000))
equation = 'ac,abc->cb'  # sum over a; the result has shape (2000, 2)

# Time 1000 calls of torch.einsum.
time0 = timeit.default_timer()
for _ in range(1000):
    _ = torch.einsum(equation, [x, y])

# Time 1000 calls of numpy.einsum on the same data
# (x.numpy() shares memory with the CPU tensor, so the conversion is cheap).
time1 = timeit.default_timer()
for _ in range(1000):
    _ = numpy.einsum(equation, x.numpy(), y.numpy())

time2 = timeit.default_timer()
print('torch: {}'.format(time1 - time0))
print('numpy: {}'.format(time2 - time1))

Output:

torch: 3.97460007668
numpy: 0.00863790512085
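For reference, the equation 'ac,abc->cb' computes out[c, b] = sum over a of x[a, c] * y[a, b, c], so with the shapes above the result is (2000, 2). The sketch below spells that out two ways, once with NumPy broadcasting and once with explicit loops, and checks both against numpy.einsum. The helper names contract_ac_abc_cb and contract_loops are hypothetical, and smaller sizes are used only so the loop version runs quickly.

```python
import numpy as np

def contract_ac_abc_cb(x, y):
    """Hypothetical helper: 'ac,abc->cb' without einsum.

    out[c, b] = sum_a x[a, c] * y[a, b, c]
    """
    # x[:, None, :] has shape (a, 1, c) and broadcasts against y's (a, b, c).
    prod = x[:, None, :] * y      # shape (a, b, c)
    # Sum over the contracted index a, then reorder (b, c) -> (c, b).
    return prod.sum(axis=0).T

def contract_loops(x, y):
    """Hypothetical reference implementation with explicit loops."""
    A, C = x.shape
    B = y.shape[1]
    out = np.zeros((C, B))
    for c in range(C):
        for b in range(B):
            for a in range(A):
                out[c, b] += x[a, c] * y[a, b, c]
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 50))      # (a, c), smaller c than in the issue
y = rng.standard_normal((2, 2, 50))   # (a, b, c)

ref = np.einsum('ac,abc->cb', x, y)
assert np.allclose(contract_ac_abc_cb(x, y), ref)
assert np.allclose(contract_loops(x, y), ref)
print(ref.shape)  # (50, 2)
```

The same broadcasting expression can be written with torch tensors as (x.unsqueeze(1) * y).sum(0).t(); whether that avoids the slowdown reported here has not been measured.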

System Info

$ python collect_env.py
Collecting environment information...
PyTorch version: 0.4.0
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.13.3
GCC version: Could not collect
CMake version: version 3.12.0

Python version: 2.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip] Could not collect
[conda] torch                     0.4.0                     <pip>
[conda] torchfile                 0.1.0                     <pip>
[conda] torchvision               0.2.1                     <pip>
