
Support Exporting RPN to ONNX #1329

Merged
fmassa merged 11 commits into pytorch:master from lara-hdr:lahaidar/onnx_export_rpn
Oct 15, 2019
Conversation

@lara-hdr (Contributor)

This PR modifies the RPN code so that it can be traced and exported to ONNX.

@fmassa, this needs #1325 to be merged in order to be complete (it uses run_model() with lists and torchvision._is_tracing(), which are added/modified in the referenced PR), but I'm submitting it for comments and review in the meantime.

@fmassa (Member) left a comment

I made a few more comments, let me know what you think.

Also, there is a merge conflict it seems.

@@ -298,8 +311,17 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level):
r = []
offset = 0
Member
If we make _get_top_n_idx a TorchScript function (moving it outside of the body of the class), would it work with ONNX? The function can be made entirely TorchScript-ready (we only need to add the type annotations).
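The suggestion above can be sketched as follows. This is a hypothetical sketch based on the shape of `_get_top_n_idx` in rpn.py at the time, not the PR's actual code; `pre_nms_top_n` is passed in explicitly since a free function no longer sees `self`:

```python
import torch
from torch import Tensor
from typing import List


@torch.jit.script
def get_top_n_idx(objectness: Tensor,
                  num_anchors_per_level: List[int],
                  pre_nms_top_n: int) -> Tensor:
    # Split the objectness scores per feature level, take the top-k
    # indices within each level, and offset them so they index into
    # the concatenated tensor.
    r: List[Tensor] = []
    offset = 0
    for ob in objectness.split(num_anchors_per_level, 1):
        num_anchors = ob.shape[1]
        top_n = min(pre_nms_top_n, num_anchors)  # prim::min in TorchScript
        _, top_n_idx = ob.topk(top_n, dim=1)
        r.append(top_n_idx + offset)
        offset += num_anchors
    return torch.cat(r, dim=1)
```

Scripting compiles the Python `min()` into `prim::min`, which is the operator the discussion below turns on.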

Contributor (Author)

You are suggesting keeping the old implementation of _get_top_n_idx, but as a TorchScript function, correct?

The use of the native Python function min() would still be an issue here: it is not supported by ONNX. The suggested implementation replaces the native Python min() with torch.min, which is supported by ONNX.
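A minimal sketch of the distinction (hypothetical helper names, not the PR's actual code): Python's `min()` needs concrete numbers, so under tracing the result is baked into the graph as a constant, while `torch.min` stays a graph op (`aten::min`) that ONNX can export as a `Min` node.

```python
import torch


def clamp_count_python(n: torch.Tensor, limit: int) -> torch.Tensor:
    # int(n) pulls a concrete number out of the traced tensor, so the
    # comparison result is baked into the trace as a constant.
    return torch.tensor(min(int(n), limit))


def clamp_count_torch(n: torch.Tensor, limit: int) -> torch.Tensor:
    # torch.min operates on tensors and remains in the graph as
    # aten::min, which maps to the ONNX Min operator.
    return torch.min(n, torch.tensor(limit, dtype=n.dtype))
```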

Member

Yes, it's true that it might require some further changes to the way ONNX handles things.

Here is my thinking (maybe it doesn't make sense, please let me know):

AFAIK ONNX only supports tensors, and not numbers.
During conversion from TorchIR to ONNX IR, we perform an upcast from all numbers to tensors. We also replace prim::min by aten::min, so that the same code could be used by both ONNX and TorchScript?

Does this seem feasible?

@codecov-io commented Sep 25, 2019

Codecov Report

Merging #1329 into master will increase coverage by <.01%.
The diff coverage is 56.52%.

Impacted file tree graph

@@            Coverage Diff            @@
##           master   #1329      +/-   ##
=========================================
+ Coverage   63.99%     64%   +<.01%     
=========================================
  Files          80      78       -2     
  Lines        6308    6200     -108     
  Branches      967     951      -16     
=========================================
- Hits         4037    3968      -69     
+ Misses       1990    1948      -42     
- Partials      281     284       +3
Impacted Files Coverage Δ
torchvision/models/detection/_utils.py 42.42% <100%> (+0.43%) ⬆️
torchvision/ops/misc.py 49.23% <100%> (+10.76%) ⬆️
torchvision/models/detection/rpn.py 75.79% <33.33%> (-3.53%) ⬇️
torchvision/ops/poolers.py 94.02% <0%> (-2.36%) ⬇️
torchvision/datasets/kinetics.py 34.78% <0%> (-1.22%) ⬇️
torchvision/transforms/transforms.py 80.46% <0%> (-0.59%) ⬇️
torchvision/datasets/ucf101.py 25% <0%> (-0.54%) ⬇️
torchvision/datasets/hmdb51.py 27.65% <0%> (-0.35%) ⬇️
torchvision/transforms/functional.py 70.77% <0%> (-0.29%) ⬇️
torchvision/models/video/resnet.py 79.52% <0%> (-0.16%) ⬇️
... and 5 more

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 76702a0...ce2b804. Read the comment docs.

@fmassa (Member) left a comment

Sorry for the delay in replying.

I have a few more comments, let me know what you think.



 # This is not in nn
-class FrozenBatchNorm2d(torch.jit.ScriptModule):
+class FrozenBatchNorm2d(torch.nn.Module):
Member
@lara-hdr This is fine (and I proposed this initially).

@eellison, is it expected that ScriptModules do not work with tracing? You mentioned that we had a @torch.jit.script_method annotation on forward, but I believe that is the way to go if we have a ScriptModule?

Contributor

ScriptModule is a more or less deprecated API at this point; replacing it with torch.nn.Module and removing script_method is good.

ScriptModule does work with tracing; any function that has a script annotation gets compiled and inlined into the trace.
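The point about script annotations being inlined into a trace can be illustrated with a hypothetical sketch (the names `flip_negative` and `Net` are made up): the scripted function's data-dependent branch survives tracing, whereas plain tracing would have frozen one branch.

```python
import torch


@torch.jit.script
def flip_negative(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent control flow: plain tracing would bake in whichever
    # branch the example input took, but a scripted function is compiled
    # and inlined into the trace with the `if` preserved.
    if bool(x.sum() < 0):
        return -x
    return x


class Net(torch.nn.Module):
    def forward(self, x):
        return flip_negative(x) * 2


# Trace with an all-positive example input; the scripted branch
# still fires correctly for negative inputs afterwards.
traced = torch.jit.trace(Net(), torch.ones(3))
```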

Member

I agree, but given that pytorch/pytorch#22328 is not going to be merged anymore, is there no recommended way of fusing single modules with TorchScript anymore?

@fmassa (Member) commented Oct 4, 2019

@lara-hdr I've tried to enable the tests on this PR in fmassa@d2a8d4f , but I get an error (using a PyTorch nightly from yesterday):

======================================================================
ERROR: test_rpn (__main__.ONNXExporterTester)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_onnx.py", line 165, in test_rpn
    self.run_model(model, [(images, features), (test_images, test_features)])
  File "test/test_onnx.py", line 30, in run_model
    torch.onnx.export(model, inputs_list[0], onnx_io, do_constant_folding=True, opset_version=10)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 143, in export
    strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 66, in export
    dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 382, in _export
    fixed_batch_size=fixed_batch_size)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 249, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 206, in _trace_and_get_graph_from_model
    trace, torch_out, inputs_states = torch.jit.get_trace_graph(model, args, _force_outplace=True, _return_inputs_states=True)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 275, in get_trace_graph
    return LegacyTracedModule(f, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 545, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 329, in forward
    in_vars, in_desc = _flatten(args)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. But got unsupported type int

Do you have an idea what the problem might be?
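The error says the tracer's input flattener only accepts tensors (and tuples/lists of them), so a bare Python int somewhere in the example inputs is rejected. A hypothetical sketch of the usual workaround (the `Scale` module is made up for illustration), wrapping scalar inputs in tensors:

```python
import torch


class Scale(torch.nn.Module):
    def forward(self, x: torch.Tensor, factor: torch.Tensor) -> torch.Tensor:
        return x * factor


# torch.jit.trace flattens its example inputs and accepts only tensors
# (tuples/lists of tensors are fine, dicts/strings are discouraged);
# passing a bare int as `factor` would raise
# "... got unsupported type int". A 0-dim tensor avoids it.
x = torch.randn(2, 2)
factor = torch.tensor(3.0)
traced = torch.jit.trace(Scale(), (x, factor))
```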

@fmassa (Member) left a comment

One last request: factor out part of the code into a function and get the tests running; then we should be good to merge. Thanks a lot!

@lara-hdr (Contributor, Author) commented Oct 8, 2019

@fmassa does the latest version look good?
thanks :)

@fmassa (Member) left a comment

LGTM, thanks a lot Lara!

I'm cutting the release branch of torchvision today, so I'll be merging the PR tomorrow once the 0.4.1 branch is cut.

@lara-hdr (Contributor, Author)

@fmassa once this is merged I'll update #1401.
Thanks!

@fmassa fmassa merged commit 1d6145d into pytorch:master Oct 15, 2019
@fmassa (Member) commented Oct 15, 2019

Thanks a lot Lara!

@bhack bhack mentioned this pull request Jul 2, 2020
6 tasks
@pmeier pmeier mentioned this pull request Nov 3, 2022

4 participants