@@ -40,16 +40,17 @@ def copypaste_collate_fn(batch):
4040 return copypaste (* utils .collate_fn (batch ))
4141
4242
43- def get_dataset (name , image_set , transform , data_path , use_v2 ):
44- paths = {"coco" : (data_path , get_coco , 91 ), "coco_kp" : (data_path , get_coco_kp , 2 )}
45- p , ds_fn , num_classes = paths [name ]
43+ def get_dataset (is_train , args ):
44+ image_set = "train" if is_train else "val"
45+ paths = {"coco" : (args .data_path , get_coco , 91 ), "coco_kp" : (args .data_path , get_coco_kp , 2 )}
46+ p , ds_fn , num_classes = paths [args .dataset ]
4647
47- ds = ds_fn (p , image_set = image_set , transforms = transform , use_v2 = use_v2 )
48+ ds = ds_fn (p , image_set = image_set , transforms = get_transform ( is_train , args ), use_v2 = args . use_v2 )
4849 return ds , num_classes
4950
5051
51- def get_transform (train , args ):
52- if train :
52+ def get_transform (is_train , args ):
53+ if is_train :
5354 return presets .DetectionPresetTrain (
5455 data_augmentation = args .data_augmentation , backend = args .backend , use_v2 = args .use_v2
5556 )
@@ -185,8 +186,8 @@ def main(args):
185186 # Data loading code
186187 print ("Loading data" )
187188
188- dataset , num_classes = get_dataset (args . dataset , "train" , get_transform ( True , args ), args . data_path , args . use_v2 )
189- dataset_test , _ = get_dataset (args . dataset , "val" , get_transform ( False , args ), args . data_path , args . use_v2 )
189+ dataset , num_classes = get_dataset (is_train = True , args = args )
190+ dataset_test , _ = get_dataset (is_train = False , args = args )
190191
191192 print ("Creating data loaders" )
192193 if args .distributed :
0 commit comments