
reuse constant from jit #49916

Closed
cccclai wants to merge 33 commits into gh/cccclai/13/base from gh/cccclai/13/head

Conversation

@cccclai
Contributor

@cccclai cccclai commented Dec 29, 2020

Summary

JIT generates constant tensor values, which live in the `constants` folder after unzipping `model.ptl`. The bytecode generated for the lite interpreter also includes constant tensors, which are largely identical to the constant tensor values from JIT. This PR reuses the constant tensors from JIT. The implementation is:

  1. In `export_module.cpp`, store all constant tensor values from JIT in an unordered_map `constants_from_jit`, where each tensor is keyed by its string form as the hash. When writing bytecode (the `writeByteCode` function), the map `constants_from_jit` is also passed in. If a constant in the bytecode exists in `constants_from_jit`, it is replaced by a tuple of the format `('tensor_jit_index', 4)`, where 4 is the index into the tensor table from JIT.
  2. In `import.cpp`, read the constant tensors from the `constants` folder as before, then scan the constants from the bytecode. If a tuple of the format `('tensor_jit_index', 4)` exists under the `constants` field, fetch the constant from the JIT tensor table at that index and substitute it back.
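The two steps above can be sketched in Python (illustrative only; the actual implementation is C++ in `export_module.cpp` / `import.cpp`, and `FakeTensor`, `dedupe_constants`, and `restore_constants` are hypothetical names for this sketch):

```python
# Illustrative sketch of the dedupe/restore scheme; not the actual C++ code.
# FakeTensor stands in for at::Tensor, with str() playing the role of the
# tensor's printed form (which tensor_value_hash hashes).

class FakeTensor:
    def __init__(self, data):
        self.data = data
    def __str__(self):
        return f"tensor({self.data})"

def dedupe_constants(bytecode_constants, jit_tensor_table):
    # Step 1 (export side): key each JIT constant tensor by its string form.
    constants_from_jit = {str(t): i for i, t in enumerate(jit_tensor_table)}
    out = []
    for c in bytecode_constants:
        if isinstance(c, FakeTensor) and str(c) in constants_from_jit:
            # Replace the tensor with a reference into the JIT tensor table.
            out.append(('tensor_jit_index', constants_from_jit[str(c)]))
        else:
            out.append(c)
    return out

def restore_constants(bytecode_constants, jit_tensor_table):
    # Step 2 (import side): resolve references back to real tensors.
    return [jit_tensor_table[c[1]]
            if isinstance(c, tuple) and len(c) == 2 and c[0] == 'tensor_jit_index'
            else c
            for c in bytecode_constants]

jit_table = [FakeTensor([1.0, 2.0]), FakeTensor([3.0])]
deduped = dedupe_constants([0.1, FakeTensor([3.0]), True], jit_table)
print(deduped)  # [0.1, ('tensor_jit_index', 1), True]
```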

Previous `python -m torch.utils.show_pickle bytecode.pkl`:

```
...
   ('constants',
    (0.1,
     1e-05,
     None,
     2,
     1,
     False,
     0,
     True,
     False,
     1,
     2,
     3,
     torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.FloatStorage, '0', 'cpu', 90944),),
       0,
       (1, 116, 28, 28),
       (90944, 784, 28, 1),
       False,
       collections.OrderedDict()),
     torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.FloatStorage, '1', 'cpu', 90944),),
       0,
       (1, 116, 28, 28),
       (90944, 784, 28, 1),
       False,
       collections.OrderedDict()),
     torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.FloatStorage, '2', 'cpu', 90944),),
       0,
       (1, 116, 28, 28),
       (90944, 784, 28, 1),
       False,
       collections.OrderedDict()),
     0.1,
     1e-05,
     None,
     2,
     1,
     False,
...
)
...
```

Now `python -m torch.utils.show_pickle bytecode.pkl`:

```
     ('constants',
    (0.1,
     1e-05,
     None,
     2,
     1,
     False,
     0,
     True,
     False,
     1,
     2,
     3,
     ('tensor_jit_index', 4),
     ('tensor_jit_index', 3),
     ('tensor_jit_index', 2),
     0.1,
     1e-05,
     None,
     2,
     1,
     False,
     0,
     24,
     True,
     58,
     ('tensor_jit_index', 0),
     3,
     ('tensor_jit_index', 1),
     -1,
     ('tensor_jit_index', 12),
     ('tensor_jit_index', 11),
     ('tensor_jit_index', 10),
     ('tensor_jit_index', 9),
     ('tensor_jit_index', 8),
     ('tensor_jit_index', 7),
     ('tensor_jit_index', 6),
     0.1,
     1e-05,
     None,
     2,
     1,
     False,
)
...
```

Test

  1. Build pytorch locally: `MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ USE_CUDA=0 DEBUG=1 MAX_JOBS=16 python setup.py develop`
  2. Run `python save_lite.py`:
```
import torch

# ~/Documents/pytorch/data/dog.jpg
model = torch.hub.load('pytorch/vision:v0.6.0', 'shufflenet_v2_x1_0', pretrained=True)
model.eval()

# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms
import pathlib
import tempfile
import torch.utils.mobile_optimizer

input_image = Image.open(pathlib.Path('~/Documents/pytorch/data/dog.jpg').expanduser())  # expand ~; PIL does not
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
print(torch.nn.functional.softmax(output[0], dim=0))

traced = torch.jit.trace(model, input_batch)
sum(p.numel() * p.element_size() for p in traced.parameters())
tf = pathlib.Path('~/Documents/pytorch/data/data/example_debug_map_with_tensorkey.ptl')

torch.jit.save(traced, tf.name)
print(pathlib.Path(tf.name).stat().st_size)
traced._save_for_lite_interpreter(tf.name)
print(pathlib.Path(tf.name).stat().st_size)
print(tf.name)
```

  3. Run `python test_lite.py`:
```
import torch
from torch.jit.mobile import _load_for_lite_interpreter
# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms

import os  # for expanduser below; PIL and the loader do not expand ~
input_image = Image.open(os.path.expanduser('~/Documents/pytorch/data/dog.jpg'))
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
reload_lite_model = _load_for_lite_interpreter(os.path.expanduser('~/Documents/pytorch/experiment/example_debug_map_with_tensorkey.ptl'))

with torch.no_grad():
    output_lite = reload_lite_model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
print(output_lite[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
print(torch.nn.functional.softmax(output_lite[0], dim=0))
```

  4. Compare the results between pytorch on master and pytorch built locally with this change; the outputs are the same.
  5. The model size was 16.1 MB and becomes 12.9 MB with this change.
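A quick sanity check on those numbers (the assumption that the saving corresponds to tensors no longer duplicated in bytecode.pkl is mine, not stated in the PR):

```python
# Size reduction reported above: 16.1 MB -> 12.9 MB.
old_mb, new_mb = 16.1, 12.9
saved = old_mb - new_mb
print(f"saved {saved:.1f} MB ({saved / old_mb:.0%})")  # saved 3.2 MB (20%)

# For scale: one of the float32 tensors shown in the pickle dump above has
# 90944 elements, i.e. roughly 0.36 MB of storage each.
print(90944 * 4 / 1e6)  # 0.363776
```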

Stack from ghstack:

Differential Revision: D25731596

[ghstack-poisoned]
cccclai added a commit that referenced this pull request Dec 29, 2020
ghstack-source-id: 9624beb
Pull Request resolved: #49916
@facebook-github-bot facebook-github-bot added cla signed oncall: jit Add this issue/PR to JIT oncall triage queue labels Dec 29, 2020
@facebook-github-bot
Contributor

facebook-github-bot commented Dec 29, 2020

💊 CI failures summary and remediations

As of commit 01db6a0 (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-CircleCI failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_macos_10_13_py3_test (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Jan 09 04:32:03 AssertionError: False is not true : Scalars failed to compare as equal! Comparing -6 and 0 gives a difference of 6, but the allowed difference with rtol=0 and atol=0 is only 0!
Jan 09 04:32:03 ----------------------------------------------------------------------
Jan 09 04:32:03 Traceback (most recent call last):
Jan 09 04:32:03   File "/Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/testing/_internal/common_distributed.py", line 280, in wrapper
Jan 09 04:32:03     self._join_processes(fn)
Jan 09 04:32:03   File "/Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/testing/_internal/common_distributed.py", line 397, in _join_processes
Jan 09 04:32:03     self._check_return_codes(elapsed_time)
Jan 09 04:32:03   File "/Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/testing/_internal/common_distributed.py", line 443, in _check_return_codes
Jan 09 04:32:03     i, first_process.exitcode, p.exitcode
Jan 09 04:32:03   File "/Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/testing/_internal/common_utils.py", line 1225, in assertEqual
Jan 09 04:32:03     super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg))
Jan 09 04:32:03 AssertionError: False is not true : Scalars failed to compare as equal! Comparing -6 and 0 gives a difference of 6, but the allowed difference with rtol=0 and atol=0 is only 0!
Jan 09 04:32:03 Expect process 3 exit code to match Process 0 exit code of 0, but got -6
Jan 09 04:32:03 
Jan 09 04:32:03 ----------------------------------------------------------------------
Jan 09 04:32:03 Ran 362 tests in 616.117s
Jan 09 04:32:03 
Jan 09 04:32:03 FAILED (failures=1, skipped=18)
Jan 09 04:32:03 
Jan 09 04:32:03 Generating XML reports...
Jan 09 04:32:03 Generated XML report: test-reports/dist-gloo/TEST-ProcessGroupDdpComparisonTestWithSpawn-20210109042147.xml
Jan 09 04:32:03 Generated XML report: test-reports/dist-gloo/TEST-ProcessGroupDdpUnderDistAutogradTestWithSpawn-20210109042147.xml


cccclai added a commit that referenced this pull request Dec 29, 2020
ghstack-source-id: 59ae48b
Pull Request resolved: #49916
cccclai added a commit that referenced this pull request Dec 30, 2020
ghstack-source-id: 400b946
Pull Request resolved: #49916
@iseeyuan
Contributor

iseeyuan commented Jan 7, 2021

  1. There exist duplicate constant tensors in mobile. I am not sure of the reason yet.
  2. There exist non-tensor types in mobile which don't exist in jit.

@iseeyuan any thoughts about the reason?

I'm not aware of 1. For 2, it's likely that there's constant propagation in jit that removes the constants. Maybe the same reason for 1.

cccclai added a commit that referenced this pull request Jan 7, 2021
ghstack-source-id: f5596a2
Pull Request resolved: #49916
```
const Function& func,
bool save_mobile_debug_info) {
const std::
    unordered_map<at::Tensor, int, tensor_value_hash, tensor_value_equal>&
```
Contributor

Do we need to create the hash by converting the tensor to a string? Can we not use the underlying storage ptr? I don't know if we have aliasing in constants where two constant tensors might refer to the same underlying memory, but we can safeguard against that by using tensor sizes on top of the storage ptr.

Contributor Author

Hmm, I don't think I use a hash with ptr. The customized hash `tensor_value_hash` uses the string and comes from here (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/Formatting.cpp#L230). Here is one example:

```
torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.FloatStorage, '0', 'cpu', 90944),),
  0,
  (1, 116, 28, 28),
  (90944, 784, 28, 1),
  False,
  collections.OrderedDict())
```

Do I misunderstand anything?

Contributor Author

@kimishpatel do you have any other concern for this change?

Contributor

Sorry, on PTO today. What I meant to say is: why are we converting the tensor to a string and using that as the hash? What are the other options? Can we not just do `std::hash<void*>(tensor.data_ptr())`, or mix `tensor.data_ptr()` with `tensor.sizes()`? Stringifying a tensor for hashing seems like overkill.

Contributor Author

Oops, sorry, I didn't notice you are on PTO today. It's actually intended to use the string as the hash. The reason is that the keys of `constants_from_jit` are tensors from the JIT constants, which have different ptrs than the constants from mobile. I used the IValue hash function at the beginning, which uses the ptr as the hash for the tensor type under the hood, instead of this customized hash function. It didn't work, so I switched to the string as the hash input.

Side note: the IValue hash function is defined here (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/ivalue.cpp#L323-L334). For the tensor type, the ptr is part of the hash input.
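The pointer-vs-string distinction can be illustrated with a small Python analogy (`FakeTensor` and `id()` are stand-ins for `at::Tensor` and `data_ptr()`; this is not the actual C++ code):

```python
class FakeTensor:
    """Hypothetical stand-in for at::Tensor; str() mimics the printed form."""
    def __init__(self, data):
        self.data = data
    def __str__(self):
        return f"tensor({self.data})"

# The JIT constant and the mobile constant hold the same value but are
# distinct objects, i.e. they have different underlying "storage pointers".
jit_const = FakeTensor([1.0, 2.0])
mobile_const = FakeTensor([1.0, 2.0])

by_ptr = {id(jit_const): 0}    # analogous to std::hash<void*>(data_ptr())
by_str = {str(jit_const): 0}   # analogous to the string-based tensor_value_hash

print(id(mobile_const) in by_ptr)   # False: pointer-based lookup misses
print(str(mobile_const) in by_str)  # True: value-based lookup matches
```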

Contributor

OK sounds good.

## Summary
Jit will generate constant tensor value, and it locates in the constant folder after unzip model.ptl. Bytecode generated by lite interpreter also includes constant tensor, which are almost the same with the constant tensor value from jit. This pr reuses the constant tensor from jit. The implementation is:

1. In `export_module.cpp`, store all constant tensor value from jit in an `unordered_map constants_from_jit`, where the tensor value use tensor string as a hash. When writing bytecode (`writeByteCode` function), the map `constants_from_jit` will also be passed. If the constant in bytecode exist in `constants_from_jit`, it will be replaced by a tuple in this format `('tensor_jit_index', 4)`, where 4 is the index in tensor table from jit.
2. In `import.cpp`, it will also read from `constants` folder and loads the constant tensor. Then it scans the constants tensor from bytecode. If there exists tuple in this format `('tensor_jit_index', 4)` under `constants` fields, it will fetch the constant from the tensor table from jit and replace with it accordingly.

Previous `python -m torch.utils.show_pickle bytecode.pkl`
```
...
   ('constants',
    (0.1,
     1e-05,
     None,
     2,
     1,
     False,
     0,
     True,
     False,
     1,
     2,
     3,
     torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.FloatStorage, '0', 'cpu', 90944),),
       0,
       (1, 116, 28, 28),
       (90944, 784, 28, 1),
       False,
       collections.OrderedDict()),
     torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.FloatStorage, '1', 'cpu', 90944),),
       0,
       (1, 116, 28, 28),
       (90944, 784, 28, 1),
       False,
       collections.OrderedDict()),
     torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.FloatStorage, '2', 'cpu', 90944),),
       0,
       (1, 116, 28, 28),
       (90944, 784, 28, 1),
       False,
       collections.OrderedDict()),
     0.1,
     1e-05,
     None,
     2,
     1,
     False,
...
)
...
```
Now `python -m torch.utils.show_pickle bytecode.pkl` prints:
```
     ('constants',
    (0.1,
     1e-05,
     None,
     2,
     1,
     False,
     0,
     True,
     False,
     1,
     2,
     3,
     ('tensor_jit_index', 4),
     ('tensor_jit_index', 3),
     ('tensor_jit_index', 2),
     0.1,
     1e-05,
     None,
     2,
     1,
     False,
     0,
     24,
     True,
     58,
     ('tensor_jit_index', 0),
     3,
     ('tensor_jit_index', 1),
     -1,
     ('tensor_jit_index', 12),
     ('tensor_jit_index', 11),
     ('tensor_jit_index', 10),
     ('tensor_jit_index', 9),
     ('tensor_jit_index', 8),
     ('tensor_jit_index', 7),
     ('tensor_jit_index', 6),
     0.1,
     1e-05,
     None,
     2,
     1,
     False,
)
...
```
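The import side (step 2) reverses the substitution. A minimal Python sketch of the resolution step, with illustrative names standing in for the actual `import.cpp` logic:

```python
TENSOR_JIT_INDEX_KEY = 'tensor_jit_index'

def resolve_constants(bytecode_constants, jit_tensor_table):
    """Replace ('tensor_jit_index', i) placeholders with the i-th entry
    of the tensor table loaded from the constants folder."""
    resolved = []
    for const in bytecode_constants:
        if (isinstance(const, tuple) and len(const) == 2
                and const[0] == TENSOR_JIT_INDEX_KEY):
            resolved.append(jit_tensor_table[const[1]])
        else:
            resolved.append(const)
    return resolved

jit_table = ['tensor_x', 'tensor_a', 'tensor_b']
print(resolve_constants([0.1, ('tensor_jit_index', 1), None], jit_table))
# [0.1, 'tensor_a', None]
```

Resolution is the inverse of the export-side substitution, so running it over the deduplicated constants recovers the original constants list.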

## Test 

1. Build PyTorch locally: `MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ USE_CUDA=0 DEBUG=1 MAX_JOBS=16 python setup.py develop`
2. Run `python save_lite.py`
```
import torch

# ~/Documents/pytorch/data/dog.jpg
model = torch.hub.load('pytorch/vision:v0.6.0', 'shufflenet_v2_x1_0', pretrained=True)
model.eval()

# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms
import pathlib
import tempfile
import torch.utils.mobile_optimizer

input_image = Image.open('~/Documents/pytorch/data/dog.jpg')
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
print(torch.nn.functional.softmax(output[0], dim=0))

traced = torch.jit.trace(model, input_batch)
sum(p.numel() * p.element_size() for p in traced.parameters())
tf = pathlib.Path('~/Documents/pytorch/data/data/example_debug_map_with_tensorkey.ptl')

torch.jit.save(traced, tf.name)
print(pathlib.Path(tf.name).stat().st_size)
traced._save_for_lite_interpreter(tf.name)
print(pathlib.Path(tf.name).stat().st_size)
print(tf.name)

```

3. Run `python test_lite.py`
```
import torch
from torch.jit.mobile import _load_for_lite_interpreter
# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms

input_image = Image.open('~/Documents/pytorch/data/dog.jpg')
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
reload_lite_model = _load_for_lite_interpreter('~/Documents/pytorch/experiment/example_debug_map_with_tensorkey.ptl')

with torch.no_grad():
    output_lite = reload_lite_model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
print(output_lite[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
print(torch.nn.functional.softmax(output_lite[0], dim=0))

```
4. Compare the results from PyTorch on master against PyTorch built locally with this change; the outputs are identical.
5. The model size drops from 16.1 MB to 12.9 MB with this change (roughly a 20% reduction).



Differential Revision: [D25731596](https://our.internmc.facebook.com/intern/diff/D25731596)

[ghstack-poisoned]
cccclai added a commit that referenced this pull request Jan 8, 2021
ghstack-source-id: 766fa27
Pull Request resolved: #49916
cccclai added a commit that referenced this pull request Jan 9, 2021
ghstack-source-id: 463e9d9
Pull Request resolved: #49916
@facebook-github-bot
Contributor

@cccclai merged this pull request in d4c1684.

@facebook-github-bot facebook-github-bot deleted the gh/cccclai/13/head branch January 12, 2021 15:14
@facebook-github-bot
Contributor

This pull request has been reverted by e05882d.

cccclai added a commit that referenced this pull request Jan 22, 2021
## Summary
This change was originally introduced in #49916; however, it ran into an issue with one model and was reverted in #50521. This PR re-enables reusing constants from JIT, with a fix. The issue is in the last `else` statement: it should not be an `else`, because the fallback needs to catch every case where any of the `if` conditions fails. The correct code removes the `else` and instead adds a `continue` once the current `const_item` meets all criteria and is pushed to `updated_constant_vals`. This problem was tricky to debug, since the code runs in different threads and sometimes breaks at unrelated places.
Before:
```
    std::vector<IValue> updated_constant_vals;
    for (const auto& const_item : consts_list) {
      if (const_item.isTuple()) {
        const auto& tensor_jit = const_item.toTuple()->elements();
        if (tensor_jit.size() > 1) {
          const auto& tensor_jit_index_key = tensor_jit[0];
          const auto& tensor_jit_index = tensor_jit[1];
          if (tensor_jit_index_key.isString() &&
              tensor_jit_index_key.toString().get()->string() ==
                  mobile::kTensorJitIndex) {
            updated_constant_vals.push_back(
                constant_vals_from_jit[tensor_jit_index.toInt()]);
          }
        }
      } else {
        updated_constant_vals.push_back(const_item);
      }
    }
```
Current:
```
    for (const auto& const_item : consts_list) {
      if (const_item.isTuple()) {
        const auto& tensor_jit = const_item.toTuple()->elements();
        if (tensor_jit.size() > 1) {
          const auto& tensor_jit_index_key = tensor_jit[0];
          const auto& tensor_jit_index = tensor_jit[1];
          if (tensor_jit_index_key.isString() &&
              tensor_jit_index_key.toString().get()->string() ==
                  mobile::kTensorJitIndex) {
            updated_constant_vals.push_back(
                constant_vals_from_jit[tensor_jit_index.toInt()]);
            continue;
          }
        }
      }
      updated_constant_vals.push_back(const_item);
    }
```
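The bug is easy to reproduce in a few lines of Python (a sketch, not the actual C++ code): with the `else` form, a tuple that fails the inner checks is silently dropped instead of being kept as an ordinary constant.

```python
KEY = 'tensor_jit_index'
jit_table = ['tensor_x', 'tensor_a']

def resolve_buggy(consts):
    # "Before": the else only fires when const is not a tuple, so a
    # tuple that fails the inner checks falls through and is dropped.
    out = []
    for c in consts:
        if isinstance(c, tuple):
            if len(c) > 1 and c[0] == KEY:
                out.append(jit_table[c[1]])
        else:
            out.append(c)
    return out

def resolve_fixed(consts):
    # "Current": continue after a successful lookup; otherwise fall
    # through to the unconditional append, so nothing is lost.
    out = []
    for c in consts:
        if isinstance(c, tuple) and len(c) > 1 and c[0] == KEY:
            out.append(jit_table[c[1]])
            continue
        out.append(c)
    return out

consts = [0.1, (KEY, 1), ('some_other_tuple', 7)]
print(resolve_buggy(consts))  # [0.1, 'tensor_a'] -- the plain tuple vanished
print(resolve_fixed(consts))  # [0.1, 'tensor_a', ('some_other_tuple', 7)]
```

Dropping a constant like this shifts every subsequent constant index, which explains why the failure surfaced far from the actual bug.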
## Test plan
In addition to running the tests in #49916:
1. Run PyTorchPlayGroundMac
![image](https://user-images.githubusercontent.com/16430979/105466340-ee478f80-5c48-11eb-9bd6-d13fb639901a.png)
2.  `buck test pp-macos`

### Copy the summary and test plan in #49916 here
## Summary
Jit will generate constant tensor value, and it locates in the constant folder after unzip model.ptl. Bytecode generated by lite interpreter also includes constant tensor, which are almost the same with the constant tensor value from jit. This pr reuses the constant tensor from jit. The implementation is:

1. In `export_module.cpp`, store all constant tensor value from jit in an `unordered_map constants_from_jit`, where the tensor value use tensor string as a hash. When writing bytecode (`writeByteCode` function), the map `constants_from_jit` will also be passed. If the constant in bytecode exist in `constants_from_jit`, it will be replaced by a tuple in this format `('tensor_jit_index', 4)`, where 4 is the index in tensor table from jit.
2. In `import.cpp`, it will also read from `constants` folder and loads the constant tensor. Then it scans the constants tensor from bytecode. If there exists tuple in this format `('tensor_jit_index', 4)` under `constants` fields, it will fetch the constant from the tensor table from jit and replace with it accordingly.

Previous `python -m torch.utils.show_pickle bytecode.pkl`
```
...
   ('constants',
    (0.1,
     1e-05,
     None,
     2,
     1,
     False,
     0,
     True,
     False,
     1,
     2,
     3,
     torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.FloatStorage, '0', 'cpu', 90944),),
       0,
       (1, 116, 28, 28),
       (90944, 784, 28, 1),
       False,
       collections.OrderedDict()),
     torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.FloatStorage, '1', 'cpu', 90944),),
       0,
       (1, 116, 28, 28),
       (90944, 784, 28, 1),
       False,
       collections.OrderedDict()),
     torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.FloatStorage, '2', 'cpu', 90944),),
       0,
       (1, 116, 28, 28),
       (90944, 784, 28, 1),
       False,
       collections.OrderedDict()),
     0.1,
     1e-05,
     None,
     2,
     1,
     False,
...
)
...
```
Now `python -m torch.utils.show_pickle bytecode.pkl`
```
     ('constants',
    (0.1,
     1e-05,
     None,
     2,
     1,
     False,
     0,
     True,
     False,
     1,
     2,
     3,
     ('tensor_jit_index', 4),
     ('tensor_jit_index', 3),
     ('tensor_jit_index', 2),
     0.1,
     1e-05,
     None,
     2,
     1,
     False,
     0,
     24,
     True,
     58,
     ('tensor_jit_index', 0),
     3,
     ('tensor_jit_index', 1),
     -1,
     ('tensor_jit_index', 12),
     ('tensor_jit_index', 11),
     ('tensor_jit_index', 10),
     ('tensor_jit_index', 9),
     ('tensor_jit_index', 8),
     ('tensor_jit_index', 7),
     ('tensor_jit_index', 6),
     0.1,
     1e-05,
     None,
     2,
     1,
     False,
)
...
```
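The import-side rewrite (step 2) is the inverse: each `('tensor_jit_index', i)` tuple is swapped back for the i-th entry of the constant table loaded from the `constants` folder. A minimal Python sketch (the real logic is C++ in `import.cpp`; the function name is hypothetical):

```python
# Illustrative sketch of the import-side resolution (the real code is
# C++ in import.cpp; the function name is hypothetical).

def resolve_jit_references(bytecode_constants, jit_constant_table):
    """Swap each ('tensor_jit_index', i) tuple back for the i-th entry
    of the constant table loaded from the model's `constants` folder."""
    resolved = []
    for c in bytecode_constants:
        if isinstance(c, tuple) and len(c) == 2 and c[0] == 'tensor_jit_index':
            resolved.append(jit_constant_table[c[1]])
        else:
            resolved.append(c)
    return resolved

# Toy example with strings standing in for the loaded tensors:
jit_table = ['w0', 'w1', 'w2']
constants = [0.1, ('tensor_jit_index', 1), True, ('tensor_jit_index', 2)]
print(resolve_jit_references(constants, jit_table))
# [0.1, 'w1', True, 'w2']
```

After this pass the lite interpreter sees exactly the same constant list it did before the change; only the serialized form is smaller.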

## Test 

1. Build pytorch locally. `MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ USE_CUDA=0 DEBUG=1 MAX_JOBS=16 python setup.py develop`
2. Run `python save_lite.py`
```
import torch

# ~/Documents/pytorch/data/dog.jpg
model = torch.hub.load('pytorch/vision:v0.6.0', 'shufflenet_v2_x1_0', pretrained=True)
model.eval()

# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms
import pathlib
import tempfile
import torch.utils.mobile_optimizer

input_image = Image.open(pathlib.Path('~/Documents/pytorch/data/dog.jpg').expanduser())  # PIL does not expand '~'
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
print(torch.nn.functional.softmax(output[0], dim=0))

traced = torch.jit.trace(model, input_batch)
sum(p.numel() * p.element_size() for p in traced.parameters())
tf = pathlib.Path('~/Documents/pytorch/data/data/example_debug_map_with_tensorkey.ptl')  # only tf.name (the bare filename) is used below, so the model is written to the current directory

torch.jit.save(traced, tf.name)
print(pathlib.Path(tf.name).stat().st_size)
traced._save_for_lite_interpreter(tf.name)
print(pathlib.Path(tf.name).stat().st_size)
print(tf.name)

```

3. Run `python test_lite.py`
```
import torch
from torch.jit.mobile import _load_for_lite_interpreter
# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms
import os

input_image = Image.open(os.path.expanduser('~/Documents/pytorch/data/dog.jpg'))  # PIL does not expand '~'
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
reload_lite_model = _load_for_lite_interpreter(os.path.expanduser('~/Documents/pytorch/experiment/example_debug_map_with_tensorkey.ptl'))

with torch.no_grad():
    output_lite = reload_lite_model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
print(output_lite[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
print(torch.nn.functional.softmax(output_lite[0], dim=0))

```
4. Compare the results from PyTorch on master and PyTorch built locally with this change; the outputs are identical.
5. With this change, the model size drops from 16.1 MB to 12.9 MB.
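The output comparison in step 4 can be scripted rather than eyeballed. A minimal stand-in check, written in pure Python so it runs without torch (in practice one would call `torch.allclose` on the two output tensors; the tolerance values below are illustrative):

```python
def allclose(a, b, rtol=1e-4, atol=1e-5):
    """Element-wise closeness check, mirroring torch.allclose semantics:
    |x - y| <= atol + rtol * |y| for every pair of elements."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

# Stand-ins for a few logits from the jit model and the lite model:
output = [0.1234, -2.5001, 7.8899]
output_lite = [0.1234, -2.5001, 7.8899]
print(allclose(output, output_lite))  # True
```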



Differential Revision: [D25982807](https://our.internmc.facebook.com/intern/diff/D25982807)

[ghstack-poisoned]
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary: Pull Request resolved: pytorch#49916

Test Plan: same as the Test section above.
Imported from OSS

Reviewed By: kimishpatel, iseeyuan

Differential Revision: D25731596

Pulled By: cccclai

fbshipit-source-id: 9731ec1e0c1d5dc76cfa374d2ad3d5bb10990cf0

Labels

cla signed Merged oncall: jit Add this issue/PR to JIT oncall triage queue Reverted