|
53 | 53 | from torchvision.transforms.v2 import functional as F |
54 | 54 | from torchvision.transforms.v2._utils import check_type, is_pure_tensor |
55 | 55 | from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes |
| 56 | +from torchvision.transforms.v2.functional._meta import ( |
| 57 | + _cxcywh_to_xywh, |
| 58 | + _cxcywh_to_xyxy, |
| 59 | + _xywh_to_cxcywh, |
| 60 | + _xywh_to_xyxy, |
| 61 | + _xyxy_to_cxcywh, |
| 62 | + _xyxy_to_xywh, |
| 63 | +) |
56 | 64 | from torchvision.transforms.v2.functional._utils import _get_kernel, _import_cvcuda, _register_kernel_internal |
57 | 65 |
|
58 | 66 |
|
@@ -4314,8 +4322,14 @@ def test_kernel_noop(self, format, inplace): |
4314 | 4322 | assert output._version == input_version |
4315 | 4323 |
|
4316 | 4324 | @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats) |
4317 | | - def test_kernel_inplace(self, old_format, new_format): |
4318 | | - input = make_bounding_boxes(format=old_format).as_subclass(torch.Tensor) |
| 4325 | + @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) |
| 4326 | + def test_kernel_inplace(self, old_format, new_format, dtype): |
| 4327 | + if not dtype.is_floating_point and ( |
| 4328 | + tv_tensors.is_rotated_bounding_format(old_format) or tv_tensors.is_rotated_bounding_format(new_format) |
| 4329 | + ): |
| 4330 | + pytest.xfail("Rotated bounding boxes should be floating point tensors") |
| 4331 | + |
| 4332 | + input = make_bounding_boxes(format=old_format, dtype=dtype).as_subclass(torch.Tensor) |
4319 | 4333 | input_version = input._version |
4320 | 4334 |
|
4321 | 4335 | output_out_of_place = F.convert_bounding_box_format(input, old_format=old_format, new_format=new_format) |
@@ -4412,6 +4426,26 @@ def test_errors(self): |
4412 | 4426 | input_tv_tensor, old_format=input_tv_tensor.format, new_format=input_tv_tensor.format |
4413 | 4427 | ) |
4414 | 4428 |
|
| 4429 | + @pytest.mark.parametrize( |
| 4430 | + "old_format", |
| 4431 | + [tv_tensors.BoundingBoxFormat.XYWH, tv_tensors.BoundingBoxFormat.CXCYWH], |
| 4432 | + ) |
| 4433 | + @pytest.mark.parametrize("dtype", [torch.float32, torch.float64, torch.int32, torch.int64]) |
| 4434 | + @pytest.mark.parametrize("device", cpu_and_cuda()) |
| 4435 | + def test_xywh_cxcywh_direct_conversion_parity(self, old_format, dtype, device): |
| 4436 | + |
| 4437 | + bounding_boxes = make_bounding_boxes(format=old_format, dtype=dtype, device=device) |
| 4438 | + input_tensor = bounding_boxes.as_subclass(torch.Tensor).clone() |
| 4439 | + |
| 4440 | + if old_format == tv_tensors.BoundingBoxFormat.XYWH: |
| 4441 | + actual = _xywh_to_cxcywh(input_tensor.clone(), inplace=False) |
| 4442 | + expected = _xyxy_to_cxcywh(_xywh_to_xyxy(input_tensor.clone(), inplace=False), inplace=False) |
| 4443 | + else: |
| 4444 | + actual = _cxcywh_to_xywh(input_tensor.clone(), inplace=False) |
| 4445 | + expected = _xyxy_to_xywh(_cxcywh_to_xyxy(input_tensor.clone(), inplace=False), inplace=False) |
| 4446 | + |
| 4447 | + torch.testing.assert_close(actual, expected) |
| 4448 | + |
4415 | 4449 | def test_cxcywh_to_xyxy_odd_dimensions(self): |
4416 | 4450 | # Non-regression test for https://github.com/pytorch/vision/issues/8887 |
4417 | 4451 | # Integer bounding boxes with odd width/height produced incorrect results |
|
0 commit comments