Add SanitizeBoundingBoxes transform#7246
Conversation
pmeier
left a comment
There was a problem hiding this comment.
Thanks Nicolas. Left a few high level comments.
| params = dict(valid_indices=valid_indices, labels=labels) | ||
| flat_outputs = [ | ||
| # Even-though it may look like we're transforming all inputs, we don't: | ||
| # _transform() will only care about BoundingBoxes and the labels | ||
| self._transform(inpt, params) | ||
| for inpt in flat_inputs | ||
| ] |
There was a problem hiding this comment.
Not sure if we can do better without other changes, but this looks pretty weird. I mean, we have the bounding box and labels here. All we need to do is to put it at the right place in flat_inputs and we should be good to go without going the extra mile through self._transform.
There was a problem hiding this comment.
All we need to do is to put it at the right place in flat_inputs
yup... and I don't know how to do that easily :)
But if you can find a way to bypass _transforms(), I'm all ears
There was a problem hiding this comment.
I mean for boxes, we can change query_bounding_box to whatever we like. Meaning, we could return the index from there. For the labels the story is different. We can't pass the flat_inputs because we rely on the dict keys. Meaning, users would need to return a spec similar to what tree_flatten produces, but that is bad UX. 🤷
test/test_prototype_transforms.py
Outdated
| assert out["boxes"].shape[0] == out["labels"].shape[0] | ||
|
|
||
| # This works because we conveniently set labels to arange(num_boxes) | ||
| assert out["labels"].tolist() == valid_indices |
There was a problem hiding this comment.
I guess I can also manually check that all valid_boxes are there, and that there's no invalid ones... LMK
pmeier
left a comment
There was a problem hiding this comment.
Looks good so far. We should also test the other label heuristics though.
One thing that came to mind is that only allowing a str key might be a little restrictive. Even for our builtin pipelines we have (PIL.Image.Image, dict(labels=...)). Maybe we can relax that and accept a sequence which items are used for a repeated index? Meaning, if someone passes labels=(1, "labels"), we do
input = input[1]
input = input["labels"]| # TODO: Do we really need to check for out of bounds here? All | ||
| # transforms should be clamping anyway, so this should never happen? |
There was a problem hiding this comment.
I would keep this for now until we are sure about this, i.e. we have tests that guarantee this. Happy to remove if it turns out we don't need it.
I had in mind something similar where we would go as deep as needed in the inputs until we find a |
pmeier
left a comment
There was a problem hiding this comment.
LGTM if CI is green. Thanks Nicolas!
|
Hey @NicolasHug! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py |
|
Thanks a lot for the reviews and for the help with |
Summary: Co-authored-by: Philip Meier <github.pmeier@posteo.de> Reviewed By: vmoens Differential Revision: D44416597 fbshipit-source-id: a4fa3db7daaf5ca5a563935545df17ad36363703
cc @vfdev-5 @bjuncek @pmeier