Support Exporting RPN to ONNX#1329
Conversation
fmassa
left a comment
I made a few more comments, let me know what you think.
Also, there is a merge conflict it seems.
@@ -298,8 +311,17 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level):
    r = []
    offset = 0
If we make _get_top_n_idx a torchscript function (moving it outside of the body of the class), would it work with ONNX? The function can be made entirely torchscript-ready (we only need to add the annotations).
You are suggesting to keep the old implementation of _get_top_n_idx but as a torchscript function, correct?
The use of the native Python function min() will still be an issue here, as it is not supported by ONNX.
The suggested implementation replaces the native Python min() with torch.min, which ONNX does support.
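To illustrate the point (a minimal sketch; the variable names and values here are hypothetical, not taken from the PR):

```python
import torch

# Python's built-in min() on plain ints is evaluated eagerly, so the
# result gets baked into a traced graph as a constant. torch.min on
# tensors stays a graph op that the ONNX exporter can handle.
pre_nms_top_n = torch.tensor(1000)   # hypothetical pre-NMS limit
num_anchors = torch.tensor(242991)   # hypothetical anchor count

k = torch.min(pre_nms_top_n, num_anchors)  # → tensor(1000)
```

The resulting k is still a tensor, so the comparison is recorded in the trace rather than frozen to whatever value it had at trace time.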
Yes, it's true that it might require some further changes to the way ONNX handles things.
Here is my thinking (maybe it doesn't make sense, please let me know):
AFAIK ONNX only supports tensors, and not numbers.
During conversion from Torch IR to ONNX IR, we upcast all numbers to tensors and replace prim::min with aten::min, so that the same code can be used by both ONNX and TorchScript.
Does this seem feasible?
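For reference, a small sketch of the TorchScript side of this idea (the function name and arguments are hypothetical): with type annotations, the Python builtin min() on ints does compile under torch.jit.script (it becomes prim::min in the IR), and the suggestion above is that the exporter could lower such ops to their aten tensor equivalents.

```python
import torch

# A scripted function using the builtin min() on annotated ints.
# In the TorchScript IR this comparison appears as prim::min.
@torch.jit.script
def bounded_k(pre_nms_top_n: int, num_anchors: int) -> int:
    return min(pre_nms_top_n, num_anchors)

bounded_k(1000, 300)  # → 300
```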
Codecov Report
@@           Coverage Diff            @@
##           master   #1329    +/-    ##
=========================================
+ Coverage   63.99%      64%   +<.01%
=========================================
  Files          80       78       -2
  Lines        6308     6200     -108
  Branches      967      951      -16
=========================================
- Hits         4037     3968      -69
+ Misses       1990     1948      -42
- Partials      281      284       +3
=========================================
Continue to review full report at Codecov.
fmassa
left a comment
Sorry for the delay in replying.
I have a few more comments, let me know what you think.
 # This is not in nn
-class FrozenBatchNorm2d(torch.jit.ScriptModule):
+class FrozenBatchNorm2d(torch.nn.Module):
ScriptModule is sort of a deprecated API at this point; replacing it with torch.nn.Module and removing script_method is good.
ScriptModule does work with tracing; any function that has a script annotation gets compiled and inlined into the trace.
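A minimal sketch of what the plain nn.Module version discussed here might look like (buffer names follow the usual BatchNorm convention and are assumptions on my part, not a quote of the PR's code):

```python
import torch

class FrozenBatchNorm2d(torch.nn.Module):
    """BatchNorm2d with fixed affine parameters and statistics.

    All four quantities are registered as buffers, so nothing is
    learned or updated; the module just applies a fixed per-channel
    affine transform, which traces and exports cleanly.
    """
    def __init__(self, num_features: int):
        super().__init__()
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fold mean/var into a single scale and bias, then reshape to
        # (1, C, 1, 1) so the transform broadcasts over N, H, W.
        scale = self.weight * self.running_var.rsqrt()
        bias = self.bias - self.running_mean * scale
        return x * scale.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)
```

With the default buffers (unit variance, zero mean) this is the identity, which makes it easy to sanity-check after loading pretrained statistics.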
I agree, but given that pytorch/pytorch#22328 is not going to be merged anymore, is there no longer a recommended way to fuse single modules with TorchScript?
…idar/onnx_export_rpn
@lara-hdr I've tried to enable the tests on this PR in fmassa@d2a8d4f, but I get an error (using a PyTorch nightly from yesterday). Do you have an idea what the problem might be?
fmassa
left a comment
One last request: factor out part of the code into a function and get the tests running, and then we should be good to merge. Thanks a lot!
@fmassa does the latest version look good?
fmassa
left a comment
LGTM, thanks a lot Lara!
I'm cutting the release branch of torchvision today, so I'll be merging the PR tomorrow once the 0.4.1 branch is cut.
Thanks a lot Lara!
This PR modifies the code in RPN to make it traceable and exportable to ONNX.
@fmassa, this needs #1325 to be merged in order to be complete (I used run_model() with lists and torchvision._is_tracing(), which are added/modified in the referenced PR), but I'm submitting it for comments and review in the meantime.