Skip to content

Commit fa034e0

Browse files
authored
fix shufflenetv2 with trt (#645)
* fix shufflenetv2 and pspnet * fix ci * remove print
1 parent ae47e9d commit fa034e0

6 files changed

Lines changed: 22 additions & 23 deletions

File tree

mmdeploy/apis/calibration.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from mmcv.parallel import MMDataParallel
77

88
from mmdeploy.core import patch_model
9-
from mmdeploy.utils import cfg_apply_marks, load_config
9+
from mmdeploy.utils import (IR, cfg_apply_marks, get_backend, get_ir_config,
10+
load_config)
1011
from .core import PIPELINE_MANAGER, no_mp
1112
from .utils import create_calib_input_data as create_calib_input_data_impl
1213

@@ -61,7 +62,10 @@ def create_calib_input_data(calib_file: str,
6162
dataset = task_processor.build_dataset(dataset_cfg, dataset_type)
6263

6364
# patch model
64-
patched_model = patch_model(model, cfg=deploy_cfg)
65+
backend = get_backend(deploy_cfg)
66+
ir = IR.get(get_ir_config(deploy_cfg)['type'])
67+
patched_model = patch_model(
68+
model, cfg=deploy_cfg, backend=backend, ir=ir)
6569

6670
dataloader = task_processor.build_dataloader(
6771
dataset, 1, 1, dist=False, shuffle=False)

mmdeploy/apis/onnx/export.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from mmdeploy.apis.core import PIPELINE_MANAGER
99
from mmdeploy.core import RewriterContext, patch_model
10-
from mmdeploy.utils import Backend, get_root_logger
10+
from mmdeploy.utils import IR, Backend, get_ir_config, get_root_logger
1111
from .optimizer import * # noqa
1212
from .passes import optimize_onnx
1313

@@ -91,20 +91,21 @@ def _add_or_update(cfg: dict, key: str, val: Any):
9191
verbose=verbose,
9292
keep_initializers_as_inputs=keep_initializers_as_inputs)
9393
_add_or_update(deploy_cfg, 'ir_config', ir_config)
94-
94+
ir = IR.get(get_ir_config(deploy_cfg)['type'])
9595
if isinstance(backend, Backend):
9696
backend = backend.value
9797
backend_config = dict(type=backend)
9898
_add_or_update(deploy_cfg, 'backend_config', backend_config)
9999

100100
context_info['cfg'] = deploy_cfg
101+
context_info['ir'] = ir
101102
if 'backend' not in context_info:
102103
context_info['backend'] = backend
103104
if 'opset' not in context_info:
104105
context_info['opset'] = opset_version
105106

106107
# patch model
107-
patched_model = patch_model(model, cfg=deploy_cfg, backend=backend)
108+
patched_model = patch_model(model, cfg=deploy_cfg, backend=backend, ir=ir)
108109

109110
if 'onnx_custom_passes' not in context_info:
110111
onnx_custom_passes = optimize_onnx if optimize else None

mmdeploy/apis/torch_jit/trace.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from packaging.version import parse as version_parse
77

88
from mmdeploy.core import RewriterContext, patch_model
9-
from mmdeploy.utils import IR, Backend, get_root_logger
9+
from mmdeploy.utils import IR, Backend, get_ir_config, get_root_logger
1010
from ..core import PIPELINE_MANAGER
1111

1212

@@ -87,7 +87,8 @@ def _add_or_update(cfg: dict, key: str, val: Any):
8787

8888
# patch model
8989
if isinstance(func, torch.nn.Module):
90-
func = patch_model(func, cfg=deploy_cfg, backend=backend)
90+
ir = IR.get(get_ir_config(deploy_cfg)['type'])
91+
func = patch_model(func, cfg=deploy_cfg, backend=backend, ir=ir)
9192

9293
with RewriterContext(**context_info), torch.no_grad():
9394
# for exporting models with weight that depends on inputs
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from .shufflenet_v2 import shufflenetv2_backbone__forward__ncnn
2+
from .shufflenet_v2 import shufflenetv2_backbone__forward__default
33
from .vision_transformer import visiontransformer__forward__ncnn
44

55
__all__ = [
6-
'shufflenetv2_backbone__forward__ncnn',
6+
'shufflenetv2_backbone__forward__default',
77
'visiontransformer__forward__ncnn',
88
]

mmdeploy/codebase/mmcls/models/backbones/shufflenet_v2.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,16 @@
22
import torch
33

44
from mmdeploy.core import FUNCTION_REWRITER
5-
from mmdeploy.utils import Backend
65

76

8-
# torch.chunk will export dynamic shape slice, which will lead integer input
9-
# on ncnn backend. So the model needs to rewrite.
107
@FUNCTION_REWRITER.register_rewriter(
11-
func_name='mmcls.models.backbones.shufflenet_v2.InvertedResidual.forward',
12-
backend=Backend.NCNN.value)
13-
@FUNCTION_REWRITER.register_rewriter(
14-
func_name='mmcls.models.backbones.shufflenet_v2.InvertedResidual.forward',
15-
backend=Backend.TORCHSCRIPT.value)
16-
def shufflenetv2_backbone__forward__ncnn(ctx, self, x):
17-
"""Rewrite `forward` of InvertedResidual used in shufflenet_v2 for ncnn
18-
backend.
8+
func_name='mmcls.models.backbones.shufflenet_v2.InvertedResidual.forward')
9+
def shufflenetv2_backbone__forward__default(ctx, self, x):
10+
"""Rewrite `forward` of InvertedResidual used in shufflenet_v2.
1911
2012
The chunk in original InvertedResidual.forward will convert to dynamic
21-
`Slice` operator in ONNX, which will raise error in ncnn.
13+
`Slice` operator in ONNX, which will raise error in ncnn, torchscript
14+
and tensorrt. Here we replace `chunk` with `split`.
2215
2316
Args:
2417
ctx (ContextCaller): The context with additional information.

mmdeploy/utils/config_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ def load_config(*args) -> List[mmcv.Config]:
1414
args (str | Sequence[str]): The path to the config file(s).
1515
1616
Returns:
17-
List[mmcv.Config]: The content of config.
17+
List[mmcv.Config | dict]: The content of config.
1818
"""
1919

2020
def _load_config(cfg):
2121
if isinstance(cfg, str):
2222
cfg = mmcv.Config.fromfile(cfg)
23-
if not isinstance(cfg, (mmcv.Config, mmcv.ConfigDict)):
23+
if not isinstance(cfg, (mmcv.Config, mmcv.ConfigDict, dict)):
2424
raise TypeError('deploy_cfg must be a filename or Config object, '
2525
f'but got {type(cfg)}')
2626
return cfg

0 commit comments

Comments
 (0)