|
7 | 7 |
|
8 | 8 | from mmdeploy.apis.core import PIPELINE_MANAGER |
9 | 9 | from mmdeploy.core import RewriterContext, patch_model |
10 | | -from mmdeploy.utils import Backend, get_root_logger |
| 10 | +from mmdeploy.utils import IR, Backend, get_ir_config, get_root_logger |
11 | 11 | from .optimizer import * # noqa |
12 | 12 | from .passes import optimize_onnx |
13 | 13 |
|
@@ -91,20 +91,21 @@ def _add_or_update(cfg: dict, key: str, val: Any): |
91 | 91 | verbose=verbose, |
92 | 92 | keep_initializers_as_inputs=keep_initializers_as_inputs) |
93 | 93 | _add_or_update(deploy_cfg, 'ir_config', ir_config) |
94 | | - |
| 94 | + ir = IR.get(get_ir_config(deploy_cfg)['type']) |
95 | 95 | if isinstance(backend, Backend): |
96 | 96 | backend = backend.value |
97 | 97 | backend_config = dict(type=backend) |
98 | 98 | _add_or_update(deploy_cfg, 'backend_config', backend_config) |
99 | 99 |
|
100 | 100 | context_info['cfg'] = deploy_cfg |
| 101 | + context_info['ir'] = ir |
101 | 102 | if 'backend' not in context_info: |
102 | 103 | context_info['backend'] = backend |
103 | 104 | if 'opset' not in context_info: |
104 | 105 | context_info['opset'] = opset_version |
105 | 106 |
|
106 | 107 | # patch model |
107 | | - patched_model = patch_model(model, cfg=deploy_cfg, backend=backend) |
| 108 | + patched_model = patch_model(model, cfg=deploy_cfg, backend=backend, ir=ir) |
108 | 109 |
|
109 | 110 | if 'onnx_custom_passes' not in context_info: |
110 | 111 | onnx_custom_passes = optimize_onnx if optimize else None |
|
0 commit comments