Errors exporting model from PyTorch to Caffe2 #7650

@peastman

Description

System Info

PyTorch 0.4.0 (installed with conda)
Caffe2 0.8.dev (installed with conda)
onnx-caffe2 1.0.0 (installed with pip)
macOS 10.13.4
Python 3.6.5
no CUDA

Issue description

I'm trying to build a model with PyTorch, export it as a Caffe2 model, then use it in a C++ program. I'm fairly confident the C++ code is correct, since it works correctly when I use a model built directly with Caffe2. With a PyTorch model, however, I run into several problems. Here is the code I use to generate the model:

import torch
import torch.nn as nn

class Compute(nn.Module):
    def forward(self, x):
        return torch.sum(x**2)

x = torch.rand(10, 3)
torch.onnx.export(Compute(), x, "test.onnx", verbose=True, input_names=['positions'], output_names=['energy'])

I convert it to a Caffe2 model with convert-onnx-to-caffe2, then try to execute it in my C++ program. It fails with this error:

exception: [enforce fail at tensor.h:495] IsType<T>(). Tensor type mismatch, caller expects elements to be float while tensor contains long long Error from operator: 
input: "positions" input: "1" output: "2" name: "" type: "Pow" device_option { device_type: 0 cuda_gpu_id: 0 }
** while accessing input: 1

The problem appears to be the operation x**2. PyTorch is recording the exponent as being a long long, but Caffe2 insists it must be a float.
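A possible variation, offered as an untested assumption rather than a confirmed fix, is to write the exponent as a float literal so that the traced constant is less likely to be recorded as an integer. Whether the exporter then emits a float constant that Caffe2's Pow operator accepts has not been verified here. The class name `ComputeFloatExp` is hypothetical, and the sketch below only checks that the eager-mode result is unchanged:

```python
import torch
import torch.nn as nn

class ComputeFloatExp(nn.Module):  # hypothetical name, not from the original report
    def forward(self, x):
        # Float literal exponent. Assumption: the exporter may record it as a
        # floating-point constant instead of an integer one (not verified).
        return torch.sum(x ** 2.0)

x = torch.rand(10, 3)
float_exp_result = ComputeFloatExp()(x)
reference = torch.sum(x * x)

# In eager mode the two formulations should agree up to rounding.
print(torch.allclose(float_exp_result, reference))
```

If the eager results agree, the same torch.onnx.export call shown earlier could be repeated with this module to see whether the type mismatch persists.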

As a temporary workaround, I tried eliminating the power operation by changing the line to return torch.sum(x*x). With that change I can run the model, but when I query the "energy" output, the result is wrong. It ought to be a scalar containing the sum of squares of all the input elements. Instead, it comes out as a (10, 3) matrix containing the square of each element. In other words, the sum operation apparently never runs.
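The difference between the expected and the reported output can be illustrated without PyTorch or Caffe2. The following is a minimal pure-Python sketch; the helper names `sum_of_squares` and `elementwise_squares` are hypothetical and exist only to contrast the two shapes described above:

```python
import math
import random

def sum_of_squares(rows):
    """Expected 'energy': a single scalar, the sum of squares of all elements."""
    return sum(v * v for row in rows for v in row)

def elementwise_squares(rows):
    """Shape of the reported output: one squared value per input element."""
    return [[v * v for v in row] for row in rows]

# Stand-in for the (10, 3) 'positions' input.
positions = [[random.random() for _ in range(3)] for _ in range(10)]

expected_energy = sum_of_squares(positions)     # scalar
observed_like = elementwise_squares(positions)  # 10 rows of 3 values

# Applying the missing reduction to the elementwise result recovers the scalar.
total = sum(sum(row) for row in observed_like)
print(len(observed_like), len(observed_like[0]))
print(math.isclose(total, expected_energy, rel_tol=1e-9))
```

A check like this can serve as a reference value when comparing the converted model's output against what the PyTorch module returns.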
