Skip to content

Commit 0cf8cab

Browse files
raimbekovm, zy1git, and Zhitao Yu
authored
Add direct XYWH-CXCYWH conversion for better performance (#9326)
Co-authored-by: zy1git <zycoding1@gmail.com> Co-authored-by: Zhitao Yu <zhitao@fb.com>
1 parent 4967c64 commit 0cf8cab

2 files changed

Lines changed: 81 additions & 3 deletions

File tree

test/test_transforms_v2.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@
5353
from torchvision.transforms.v2 import functional as F
5454
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
5555
from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes
56+
from torchvision.transforms.v2.functional._meta import (
57+
_cxcywh_to_xywh,
58+
_cxcywh_to_xyxy,
59+
_xywh_to_cxcywh,
60+
_xywh_to_xyxy,
61+
_xyxy_to_cxcywh,
62+
_xyxy_to_xywh,
63+
)
5664
from torchvision.transforms.v2.functional._utils import _get_kernel, _import_cvcuda, _register_kernel_internal
5765

5866

@@ -4314,8 +4322,14 @@ def test_kernel_noop(self, format, inplace):
43144322
assert output._version == input_version
43154323

43164324
@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
4317-
def test_kernel_inplace(self, old_format, new_format):
4318-
input = make_bounding_boxes(format=old_format).as_subclass(torch.Tensor)
4325+
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
4326+
def test_kernel_inplace(self, old_format, new_format, dtype):
4327+
if not dtype.is_floating_point and (
4328+
tv_tensors.is_rotated_bounding_format(old_format) or tv_tensors.is_rotated_bounding_format(new_format)
4329+
):
4330+
pytest.xfail("Rotated bounding boxes should be floating point tensors")
4331+
4332+
input = make_bounding_boxes(format=old_format, dtype=dtype).as_subclass(torch.Tensor)
43194333
input_version = input._version
43204334

43214335
output_out_of_place = F.convert_bounding_box_format(input, old_format=old_format, new_format=new_format)
@@ -4412,6 +4426,26 @@ def test_errors(self):
44124426
input_tv_tensor, old_format=input_tv_tensor.format, new_format=input_tv_tensor.format
44134427
)
44144428

4429+
@pytest.mark.parametrize(
    "old_format",
    [tv_tensors.BoundingBoxFormat.XYWH, tv_tensors.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64, torch.int32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_xywh_cxcywh_direct_conversion_parity(self, old_format, dtype, device):
    """Check that the direct XYWH <-> CXCYWH kernels agree with the XYXY detour.

    For each source format, dtype and device, the one-step conversion is
    compared against the two-step conversion that goes through XYXY.
    """
    bounding_boxes = make_bounding_boxes(format=old_format, dtype=dtype, device=device)
    input_tensor = bounding_boxes.as_subclass(torch.Tensor).clone()

    # Every call receives its own clone so no conversion can observe changes
    # made by another one.
    if old_format == tv_tensors.BoundingBoxFormat.XYWH:
        actual = _xywh_to_cxcywh(input_tensor.clone(), inplace=False)
        via_xyxy = _xywh_to_xyxy(input_tensor.clone(), inplace=False)
        expected = _xyxy_to_cxcywh(via_xyxy, inplace=False)
    else:
        actual = _cxcywh_to_xywh(input_tensor.clone(), inplace=False)
        via_xyxy = _cxcywh_to_xyxy(input_tensor.clone(), inplace=False)
        expected = _xyxy_to_xywh(via_xyxy, inplace=False)

    torch.testing.assert_close(actual, expected)
4448+
44154449
def test_cxcywh_to_xyxy_odd_dimensions(self):
44164450
# Non-regression test for https://github.com/pytorch/vision/issues/8887
44174451
# Integer bounding boxes with odd width/height produced incorrect results

torchvision/transforms/v2/functional/_meta.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,46 @@ def _xyxy_to_xywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
176176
return xywh
177177

178178

179+
def _xywh_to_cxcywh(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
    """Convert boxes from XYWH to CXCYWH directly, without an XYXY intermediate.

    Args:
        xywh: Tensor whose last dimension is laid out as ``(x, y, w, h)``.
        inplace: If ``True``, ``xywh`` itself is modified and returned.
            Otherwise a clone is modified and returned.

    Returns:
        Tensor with the same shape and dtype as the input, laid out as
        ``(cx, cy, w, h)``.
    """
    if not inplace:
        xywh = xywh.clone()

    # Integer tensors cannot store a fractional half extent, so it is floored;
    # floating point tensors use true division.
    rounding_mode = None if xywh.is_floating_point() else "floor"
    half_wh = xywh[..., 2:].div(2, rounding_mode=rounding_mode)

    # cx = x + width / 2, cy = y + height / 2, width and height stay the same
    xywh[..., :2].add_(half_wh)

    return xywh
187+
188+
189+
def _cxcywh_to_xywh(cxcywh: torch.Tensor, inplace: bool) -> torch.Tensor:
    """Convert boxes from CXCYWH to XYWH directly, without an XYXY intermediate.

    Args:
        cxcywh: Tensor whose last dimension is laid out as ``(cx, cy, w, h)``.
        inplace: If ``True``, the input tensor is overwritten with the result
            and returned. Otherwise the input is left untouched and a new
            tensor is returned.

    Returns:
        Tensor with the same shape and dtype as the input, laid out as
        ``(x, y, w, h)``.
    """
    # For integer tensors, use float arithmetic to match the behavior of the
    # two-step conversion CXCYWH -> XYXY -> XYWH (where _cxcywh_to_xyxy uses
    # float arithmetic, see PR #9322).
    original = cxcywh
    dtype = cxcywh.dtype
    need_cast = not cxcywh.is_floating_point()

    if need_cast:
        # The cast already yields a separate float tensor, so no clone is needed.
        cxcywh = cxcywh.float()
    elif not inplace:
        cxcywh = cxcywh.clone()

    # x = cx - w/2, y = cy - h/2
    half_wh = cxcywh[..., 2:] / 2
    cxcywh[..., :2].sub_(half_wh)

    if not need_cast:
        return cxcywh

    # For integer types, truncation of x1/y1 and x2/y2 (= x1 + w, y1 + h) can change
    # the effective width/height. Recompute w/h to match the two-step path.
    x2y2 = (cxcywh[..., :2] + cxcywh[..., 2:]).to(dtype)
    cxcywh = cxcywh.to(dtype)
    cxcywh[..., 2:] = x2y2 - cxcywh[..., :2]

    if inplace:
        # The arithmetic ran on a float copy; write the result back into the
        # caller's tensor to honor the in-place contract.
        original[:] = cxcywh
        return original

    return cxcywh
217+
218+
179219
def _cxcywh_to_xyxy(cxcywh: torch.Tensor, inplace: bool) -> torch.Tensor:
180220
# For integer tensors, use float arithmetic to match the behavior of
181221
# `torchvision.ops._box_convert._box_cxcywh_to_xyxy`.
@@ -316,7 +356,11 @@ def _convert_bounding_box_format(
316356
if tv_tensors.is_rotated_bounding_format(old_format) ^ tv_tensors.is_rotated_bounding_format(new_format):
317357
raise ValueError("Cannot convert between rotated and unrotated bounding boxes.")
318358

319-
# TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance
359+
if old_format == BoundingBoxFormat.XYWH and new_format == BoundingBoxFormat.CXCYWH:
360+
return _xywh_to_cxcywh(bounding_boxes, inplace)
361+
if old_format == BoundingBoxFormat.CXCYWH and new_format == BoundingBoxFormat.XYWH:
362+
return _cxcywh_to_xywh(bounding_boxes, inplace)
363+
320364
if old_format == BoundingBoxFormat.XYWH:
321365
bounding_boxes = _xywh_to_xyxy(bounding_boxes, inplace)
322366
elif old_format == BoundingBoxFormat.CXCYWH:

0 commit comments

Comments
 (0)