
Lahaidar/export faster rcnn #1401

Merged
fmassa merged 9 commits into pytorch:master from lara-hdr:lahaidar/export_faster_rcnn
Oct 18, 2019

Conversation

@lara-hdr
Contributor

@lara-hdr lara-hdr commented Oct 1, 2019

In addition to #1329 and #1329, these are the last changes needed to export Faster R-CNN to ONNX.
For now the export only works with a batch size of 1 (Split in ONNX is static, but we are working on supporting dynamic cases, which will allow us to export postprocess_detections in roi_heads with a batch_size > 1).

The test is disabled for now, until the two referenced PRs are merged and bilinear Resize is implemented in opset 11 to match PyTorch's interpolate.

Member

@fmassa fmassa left a comment


The changes are pretty minimal, which is nice. Thanks for the PR @lara-hdr!

I have a few comments and questions; let me know if I'm missing something.

@fmassa
Member

fmassa commented Oct 16, 2019

@lara-hdr can you rebase your PR on top of master? I've issued a fix to CI that should make things work.

Also, I've (temporarily) made torchvision CI depend on PyTorch 1.3, so we might need to disable a few tests when the PyTorch version is lower than 1.4 (like the ones relying on unbind).
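Version-gating a test like this is usually done with a skip decorator. A minimal sketch of the idea, with the version string stubbed out and the test name purely illustrative (not the actual torchvision test suite):

```python
import unittest

# Stand-in for torch.__version__ so the sketch runs without PyTorch.
TORCH_VERSION = "1.3.0"

def version_lt(version, minimum):
    """Compare the major.minor fields of two dotted version strings."""
    to_pair = lambda v: tuple(int(x) for x in v.split(".")[:2])
    return to_pair(version) < to_pair(minimum)

class ONNXExporterTester(unittest.TestCase):
    # unbind only exports correctly on newer PyTorch, so skip otherwise.
    @unittest.skipIf(version_lt(TORCH_VERSION, "1.4"),
                     "requires PyTorch >= 1.4 for unbind export")
    def test_transform_postprocess(self):
        pass  # exporter checks would go here
```

With CI pinned to 1.3, the decorated test is reported as skipped rather than failing.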

@codecov-io

Codecov Report

Merging #1401 into master will decrease coverage by 1.16%.
The diff coverage is 57.14%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1401      +/-   ##
==========================================
- Coverage   64.46%   63.29%   -1.17%     
==========================================
  Files          83       83              
  Lines        6421     6424       +3     
  Branches      982      880     -102     
==========================================
- Hits         4139     4066      -73     
- Misses       1992     2047      +55     
- Partials      290      311      +21
Impacted Files Coverage Δ
torchvision/models/detection/transform.py 66.92% <ø> (ø) ⬆️
torchvision/models/detection/roi_heads.py 55.3% <57.14%> (-0.47%) ⬇️
torchvision/ops/_register_onnx_ops.py 52.17% <0%> (-47.83%) ⬇️
torchvision/ops/poolers.py 80.72% <0%> (-15.67%) ⬇️
torchvision/models/detection/rpn.py 75.22% <0%> (-7.21%) ⬇️
torchvision/datasets/cifar.py 71.26% <0%> (-6.9%) ⬇️
torchvision/datasets/lsun.py 20.4% <0%> (-4.09%) ⬇️
torchvision/datasets/utils.py 59% <0%> (-3.11%) ⬇️
torchvision/datasets/folder.py 79.48% <0%> (-2.57%) ⬇️
torchvision/transforms/transforms.py 78.91% <0%> (-2.13%) ⬇️
... and 2 more

Continue to review full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 22c50fd...0c3667c. Read the comment docs.

@lara-hdr
Contributor Author

@fmassa the roi_heads test also calls transform.postprocess() at the end, so it uses unbind.
Should I keep this test for roi_heads only (and remove the call to transform.postprocess()) so that a test stays enabled for PyTorch 1.3?

Member

@fmassa fmassa left a comment


@lara-hdr we will be moving CI back to PyTorch nightly, so we can remove the tests being disabled in a follow-up PR.

Thanks a lot for the PR!

pred_scores = pred_scores.split(boxes_per_image, 0)
if len(boxes_per_image) == 1:
    # TODO: remove this when ONNX supports dynamic split sizes
    pred_boxes = (pred_boxes,)
Contributor


What is the limitation here exactly?

Member


I think the ONNX exporter doesn't yet support a boxes_per_image that is dynamic, so it gets converted into a constant and thus doesn't work for a different number of boxes.

This is something that will be fixed by the ONNX team, I believe, and this is currently only a workaround (which I think shouldn't affect torchscript?)

Contributor Author


Exactly. split is not yet supported with dynamic sizes, so boxes_per_image will be exported as a constant, which will then be used for any new input. But we are working on supporting this scenario in the exporter.

Contributor


If the length isn't 1, isn't it still dynamic? I'm not getting how this case avoids the problem.

(small change to support in TS, mostly just curious)

Contributor Author


It is, but we currently only support a batch size of 1.
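The control flow being discussed above can be sketched in pure Python, with lists standing in for tensors; split_per_image is a hypothetical helper name, not the actual roi_heads code path:

```python
# Sketch of the batch-of-1 workaround: skip the dynamic split entirely
# when there is a single image, since an ONNX trace would bake the
# split sizes into the graph as constants.
def split_per_image(pred_boxes, boxes_per_image):
    if len(boxes_per_image) == 1:
        # Batch of 1: wrap in a one-element tuple so downstream
        # per-image iteration still works without calling split().
        return (pred_boxes,)
    # General case, mirroring what split(boxes_per_image, 0) does
    # along dimension 0.
    out, start = [], 0
    for n in boxes_per_image:
        out.append(pred_boxes[start:start + n])
        start += n
    return tuple(out)
```

Either branch yields one chunk per image; only the single-image branch is safe to trace until the exporter supports dynamic split sizes.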

