1- import os
2- import os .path
1+ import pathlib
32from typing import Callable , Optional , Any , Tuple
43
54from PIL import Image
65
7- from .utils import download_and_extract_archive , download_url
6+ from .utils import download_and_extract_archive , download_url , verify_str_arg
87from .vision import VisionDataset
98
109
1110class StanfordCars (VisionDataset ):
1211 """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset
1312
14- .. warning::
13+ The Cars dataset contains 16,185 images of 196 classes of cars. The data is
14+ split into 8,144 training images and 8,041 testing images, where each class
15+ has been split roughly 50-50.
16+
17+ .. note::
1518
1619 This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
1720
1821 Args:
1922 root (string): Root directory of dataset
20- train (bool , optional):If True, creates dataset from training set, otherwise creates from test set
23+ split (string , optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
2124 transform (callable, optional): A function/transform that takes in a PIL image
2225 and returns a transformed version. E.g., ``transforms.RandomCrop``
2326 target_transform (callable, optional): A function/transform that takes in the
@@ -26,30 +29,10 @@ class StanfordCars(VisionDataset):
2629 puts it in root directory. If dataset is already downloaded, it is not
2730 downloaded again."""
2831
29- urls = (
30- "https://ai.stanford.edu/~jkrause/car196/cars_test.tgz" ,
31- "https://ai.stanford.edu/~jkrause/car196/cars_train.tgz" ,
32- ) # test and train image urls
33-
34- md5s = (
35- "4ce7ebf6a94d07f1952d94dd34c4d501" ,
36- "065e5b463ae28d29e77c1b4b166cfe61" ,
37- ) # md5checksum for test and train data
38-
39- annot_urls = (
40- "https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat" ,
41- "https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz" ,
42- ) # annotations and labels for test and train
43-
44- annot_md5s = (
45- "b0a2b23655a3edd16d84508592a98d10" ,
46- "c3b158d763b6e2245038c8ad08e45376" ,
47- ) # md5 checksum for annotations
48-
4932 def __init__ (
5033 self ,
5134 root : str ,
52- train : bool = True ,
35+ split : str = "train" ,
5336 transform : Optional [Callable ] = None ,
5437 target_transform : Optional [Callable ] = None ,
5538 download : bool = False ,
@@ -62,7 +45,16 @@ def __init__(
6245
6346 super ().__init__ (root , transform = transform , target_transform = target_transform )
6447
65- self .train = train
48+ self ._split = verify_str_arg (split , "split" , ("train" , "test" ))
49+ self ._base_folder = pathlib .Path (root ) / "stanford_cars"
50+ devkit = self ._base_folder / "devkit"
51+
52+ if self ._split == "train" :
53+ self ._annotations_mat_path = devkit / "cars_train_annos.mat"
54+ self ._images_base_path = self ._base_folder / "cars_train"
55+ else :
56+ self ._annotations_mat_path = self ._base_folder / "cars_test_annos_withlabels.mat"
57+ self ._images_base_path = self ._base_folder / "cars_test"
6658
6759 if download :
6860 self .download ()
@@ -72,22 +64,13 @@ def __init__(
7264
7365 self ._samples = [
7466 (
75- os .path .join (self .root , f"cars_{ 'train' if self .train else 'test' } " , annotation ["fname" ]),
76- annotation ["class" ] - 1 ,
77- # Beware stanford cars target mapping starts from 1
67+ str (self ._images_base_path / annotation ["fname" ]),
68+ annotation ["class" ] - 1 , # Original target mapping starts from 1, hence -1
7869 )
79- for annotation in sio .loadmat (
80- os .path .join (
81- self .root ,
82- * ["devkit" , "cars_train_annos.mat" ] if self .train else ["cars_test_annos_withlabels.mat" ],
83- ),
84- squeeze_me = True ,
85- )["annotations" ]
70+ for annotation in sio .loadmat (self ._annotations_mat_path , squeeze_me = True )["annotations" ]
8671 ]
8772
88- self .classes = sio .loadmat (os .path .join (self .root , "devkit" , "cars_meta.mat" ), squeeze_me = True )[
89- "class_names"
90- ].tolist ()
73+ self .classes = sio .loadmat (str (devkit / "cars_meta.mat" ), squeeze_me = True )["class_names" ].tolist ()
9174 self .class_to_idx = {cls : i for i , cls in enumerate (self .classes )}
9275
9376 def __len__ (self ) -> int :
@@ -108,20 +91,31 @@ def download(self) -> None:
10891 if self ._check_exists ():
10992 return
11093
111- download_and_extract_archive (url = self .urls [self .train ], download_root = self .root , md5 = self .md5s [self .train ])
112- download_and_extract_archive (url = self .annot_urls [1 ], download_root = self .root , md5 = self .annot_md5s [1 ])
113- if not self .train :
94+ download_and_extract_archive (
95+ url = "https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz" ,
96+ download_root = self ._base_folder ,
97+ md5 = "c3b158d763b6e2245038c8ad08e45376" ,
98+ )
99+ if self ._split == "train" :
100+ download_and_extract_archive (
101+ url = "https://ai.stanford.edu/~jkrause/car196/cars_train.tgz" ,
102+ download_root = self ._base_folder ,
103+ md5 = "065e5b463ae28d29e77c1b4b166cfe61" ,
104+ )
105+ else :
106+ download_and_extract_archive (
107+ url = "https://ai.stanford.edu/~jkrause/car196/cars_test.tgz" ,
108+ download_root = self ._base_folder ,
109+ md5 = "4ce7ebf6a94d07f1952d94dd34c4d501" ,
110+ )
114111 download_url (
115- url = self . annot_urls [ 0 ] ,
116- root = self .root ,
117- md5 = self . annot_md5s [ 0 ] ,
112+ url = "https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat" ,
113+ root = self ._base_folder ,
114+ md5 = "b0a2b23655a3edd16d84508592a98d10" ,
118115 )
119116
120117 def _check_exists (self ) -> bool :
121- return (
122- os .path .exists (os .path .join (self .root , f"cars_{ 'train' if self .train else 'test' } " ))
123- and os .path .isdir (os .path .join (self .root , f"cars_{ 'train' if self .train else 'test' } " ))
124- and os .path .exists (os .path .join (self .root , "devkit" , "cars_meta.mat" ))
125- if self .train
126- else os .path .exists (os .path .join (self .root , "cars_test_annos_withlabels.mat" ))
127- )
118+ if not (self ._base_folder / "devkit" ).is_dir ():
119+ return False
120+
121+ return self ._annotations_mat_path .exists () and self ._images_base_path .is_dir ()
0 commit comments