Skip to content

Commit 6ab188f

Browse files
committed
resolve comment
1 parent c2e1439 commit 6ab188f

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

mmdeploy/codebase/mmcls/deploy/classification_model.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,20 +156,20 @@ def get_classes_from_config(model_cfg: Union[str, mmcv.Config]):
156156
data_cfg = model_cfg.data
157157

158158
def _get_class_names(dataset_type: str):
159-
dataset = data_cfg[dataset_type]
160-
if not dataset:
159+
dataset = data_cfg.get(dataset_type, None)
160+
if (not dataset) or (dataset.type not in module_dict):
161161
return None
162+
162163
module = module_dict[dataset.type]
163164
if module.CLASSES is not None:
164165
return module.CLASSES
165166
return module.get_classes(dataset.get('classes', None))
166167

167168
class_names = None
168169
for dataset_type in ['val', 'test', 'train']:
169-
if dataset_type in data_cfg:
170-
class_names = _get_class_names(dataset_type)
171-
if class_names is not None:
172-
break
170+
class_names = _get_class_names(dataset_type)
171+
if class_names is not None:
172+
break
173173

174174
if class_names is None:
175175
logger = get_root_logger()

0 commit comments

Comments
 (0)