@@ -1451,6 +1451,41 @@ def test_box_area_jit(self):
14511451 torch .testing .assert_close (scripted_area , expected )
14521452
14531453
1454+ class TestBoxAreaCenter :
1455+ def area_check (self , box , expected , atol = 1e-4 ):
1456+ out = ops .box_area_center (box )
1457+ torch .testing .assert_close (out , expected , rtol = 0.0 , check_dtype = False , atol = atol )
1458+
1459+ @pytest .mark .parametrize ("dtype" , [torch .int8 , torch .int16 , torch .int32 , torch .int64 ])
1460+ def test_int_boxes (self , dtype ):
1461+ box_tensor = ops .box_convert (torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = dtype ),
1462+ in_fmt = "xyxy" , out_fmt = "cxcywh" )
1463+ expected = torch .tensor ([10000 , 0 ], dtype = torch .int32 )
1464+ self .area_check (box_tensor , expected )
1465+
1466+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .float64 ])
1467+ def test_float_boxes (self , dtype ):
1468+ box_tensor = ops .box_convert (torch .tensor (FLOAT_BOXES , dtype = dtype ), in_fmt = "xyxy" , out_fmt = "cxcywh" )
1469+ expected = torch .tensor ([604723.0806 , 600965.4666 , 592761.0085 ], dtype = dtype )
1470+ self .area_check (box_tensor , expected )
1471+
1472+ def test_float16_box (self ):
1473+ box_tensor = ops .box_convert (torch .tensor (
1474+ [[2.825 , 1.8625 , 3.90 , 4.85 ], [2.825 , 4.875 , 19.20 , 5.10 ], [2.925 , 1.80 , 8.90 , 4.90 ]], dtype = torch .float16
1475+ ), in_fmt = "xyxy" , out_fmt = "cxcywh" )
1476+
1477+ expected = torch .tensor ([3.2170 , 3.7108 , 18.5071 ], dtype = torch .float16 )
1478+ self .area_check (box_tensor , expected , atol = 0.01 )
1479+
1480+ def test_box_area_jit (self ):
1481+ box_tensor = ops .box_convert (torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float ),
1482+ in_fmt = "xyxy" , out_fmt = "cxcywh" )
1483+ expected = ops .box_area_center (box_tensor )
1484+ scripted_fn = torch .jit .script (ops .box_area_center )
1485+ scripted_area = scripted_fn (box_tensor )
1486+ torch .testing .assert_close (scripted_area , expected )
1487+
1488+
14541489INT_BOXES = [[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ], [0 , 0 , 25 , 25 ]]
14551490INT_BOXES2 = [[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ]]
14561491FLOAT_BOXES = [
@@ -1459,6 +1494,14 @@ def test_box_area_jit(self):
14591494 [279.2440 , 197.9812 , 1189.4746 , 849.2019 ],
14601495]
14611496
1497+ INT_BOXES_CXCYWH = [[50 , 50 , 100 , 100 ], [25 , 25 , 50 , 50 ], [250 , 250 , 100 , 100 ], [10 , 10 , 20 , 20 ]]
1498+ INT_BOXES2_CXCYWH = [[50 , 50 , 100 , 100 ], [25 , 25 , 50 , 50 ], [250 , 250 , 100 , 100 ]]
1499+ FLOAT_BOXES_CXCYWH = [
1500+ [739.4324 , 518.5154 , 908.1572 , 665.8793 ],
1501+ [738.8228 , 519.9021 , 907.3512 , 662.3295 ],
1502+ [734.3593 , 523.5916 , 910.2306 , 651.2207 ]
1503+ ]
1504+
14621505
14631506def gen_box (size , dtype = torch .float ):
14641507 xy1 = torch .rand ((size , 2 ), dtype = dtype )
@@ -1525,6 +1568,65 @@ def test_iou_cartesian(self):
15251568 self ._run_cartesian_test (ops .box_iou )
15261569
15271570
1571+ class TestIouCenterBase :
1572+ @staticmethod
1573+ def _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected ):
1574+ for dtype in dtypes :
1575+ actual_box1 = torch .tensor (actual_box1 , dtype = dtype )
1576+ actual_box2 = torch .tensor (actual_box2 , dtype = dtype )
1577+ expected_box = torch .tensor (expected )
1578+ out = target_fn (actual_box1 , actual_box2 )
1579+ torch .testing .assert_close (out , expected_box , rtol = 0.0 , check_dtype = False , atol = atol )
1580+
1581+ @staticmethod
1582+ def _run_jit_test (target_fn : Callable , actual_box : List ):
1583+ box_tensor = torch .tensor (actual_box , dtype = torch .float )
1584+ expected = target_fn (box_tensor , box_tensor )
1585+ scripted_fn = torch .jit .script (target_fn )
1586+ scripted_out = scripted_fn (box_tensor , box_tensor )
1587+ torch .testing .assert_close (scripted_out , expected )
1588+
1589+ @staticmethod
1590+ def _cartesian_product (boxes1 , boxes2 , target_fn : Callable ):
1591+ N = boxes1 .size (0 )
1592+ M = boxes2 .size (0 )
1593+ result = torch .zeros ((N , M ))
1594+ for i in range (N ):
1595+ for j in range (M ):
1596+ result [i , j ] = target_fn (boxes1 [i ].unsqueeze (0 ), boxes2 [j ].unsqueeze (0 ))
1597+ return result
1598+
1599+ @staticmethod
1600+ def _run_cartesian_test (target_fn : Callable ):
1601+ boxes1 = ops .box_convert (gen_box (5 ), in_fmt = "xyxy" , out_fmt = "cxcywh" )
1602+ boxes2 = ops .box_convert (gen_box (7 ), in_fmt = "xyxy" , out_fmt = "cxcywh" )
1603+ a = TestIouCenterBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1604+ b = target_fn (boxes1 , boxes2 )
1605+ torch .testing .assert_close (a , b )
1606+
1607+
1608+ class TestBoxIouCenter (TestIouBase ):
1609+ int_expected = [[1.0 , 0.25 , 0.0 ], [0.25 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [0.04 , 0.16 , 0.0 ]]
1610+ float_expected = [[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]]
1611+
1612+ @pytest .mark .parametrize (
1613+ "actual_box1, actual_box2, dtypes, atol, expected" ,
1614+ [
1615+ pytest .param (INT_BOXES_CXCYWH , INT_BOXES2_CXCYWH , [torch .int16 , torch .int32 , torch .int64 ], 1e-4 , int_expected ),
1616+ pytest .param (FLOAT_BOXES_CXCYWH , FLOAT_BOXES_CXCYWH , [torch .float16 ], 0.002 , float_expected ),
1617+ pytest .param (FLOAT_BOXES_CXCYWH , FLOAT_BOXES_CXCYWH , [torch .float32 , torch .float64 ], 1e-3 , float_expected ),
1618+ ],
1619+ )
1620+ def test_iou (self , actual_box1 , actual_box2 , dtypes , atol , expected ):
1621+ self ._run_test (ops .box_iou_center , actual_box1 , actual_box2 , dtypes , atol , expected )
1622+
1623+ def test_iou_jit (self ):
1624+ self ._run_jit_test (ops .box_iou_center , INT_BOXES_CXCYWH )
1625+
1626+ def test_iou_cartesian (self ):
1627+ self ._run_cartesian_test (ops .box_iou_center )
1628+
1629+
15281630class TestGeneralizedBoxIou (TestIouBase ):
15291631 int_expected = [[1.0 , 0.25 , - 0.7778 ], [0.25 , 1.0 , - 0.8611 ], [- 0.7778 , - 0.8611 , 1.0 ], [0.0625 , 0.25 , - 0.8819 ]]
15301632 float_expected = [[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]]
0 commit comments