
Different behavior when tracing a model #19349

@ofexe


🐛 Bug

When I use torch.jit.trace() to trace a model, the same forward() code produces different results depending on whether the model is being traced.

To Reproduce

I wrote a simple test example:

import torch.nn as nn
import torch


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        print()

    def set_test_name(self, name):
        self.test_name = name

    def forward(self, input):
        print(self.test_name, ':', input.shape[2], '\n')
        output = input
        return output


if __name__ == '__main__':
    model = Net()
    model.eval()
    inputs = torch.ones(1, 1, 32, 160)

    model.set_test_name('predict')
    output = model(inputs)

    model.set_test_name('trace')
    traced = torch.jit.trace(model, inputs)
    traced.save('stm.pt')

Expected behavior

It should print:

predict : 32
trace : 32
trace : 32
trace : 32

Instead, it prints:

C:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py:427: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
predict : 32 

trace : tensor(32) 

  'incorrect results).', category=RuntimeWarning)
trace : tensor(32) 

trace : 32 

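As you can see, the type of input.shape[2] differs: it is a Python int (32) when the model runs normally, but a tensor (tensor(32)) in the first two calls made by torch.jit.trace().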
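The type difference can be reproduced without relying on print output. The following is a minimal sketch, assuming a recent PyTorch release rather than 1.0.1 (it has not been checked against that version); `ShapeRecorder`, `SEEN`, and `record_calls` are hypothetical names used only for illustration. It records the type of `x.shape[2]` on every call to `forward()`. While tracing, PyTorch returns tensor sizes as 0-dim tensors so that the size lookup can be recorded in the graph, which is why `tensor(32)` appears. With the default `check_trace=True`, `torch.jit.trace()` also calls `forward()` again to verify the trace, which would account for the extra `trace` lines; passing `check_trace=False` should leave a single traced call.

```python
import torch
import torch.nn as nn

# Hypothetical module-level log of what forward() sees on each call.
SEEN = []


class ShapeRecorder(nn.Module):
    def forward(self, x):
        # Record the Python type of x.shape[2] instead of printing it.
        SEEN.append(type(x.shape[2]))
        return x * 2


def record_calls(check_trace):
    """Trace a fresh ShapeRecorder and return the types seen by forward()."""
    SEEN.clear()
    model = ShapeRecorder().eval()
    inputs = torch.ones(1, 1, 32, 160)
    torch.jit.trace(model, inputs, check_trace=check_trace)
    return list(SEEN)


if __name__ == "__main__":
    # Plain eager call: x.shape[2] is a Python int.
    SEEN.clear()
    ShapeRecorder().eval()(torch.ones(1, 1, 32, 160))
    print("eager:", SEEN)
    # Traced calls: the first call (the actual trace) sees a tensor.
    print("check_trace=True:", record_calls(True))
    print("check_trace=False:", record_calls(False))
```

After the eager call, `SEEN` should contain only `int`, while both traced runs should start with `torch.Tensor`; only the `check_trace=True` run should record more than one call.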

Environment

  • PyTorch Version (e.g., 1.0): 1.0.1
  • OS (e.g., Linux): Win10
  • How you installed PyTorch (conda, pip, source): conda
  • Python version: 3.7.1

cc @suo


Labels: oncall: jit, triaged
