
Trained an improved network with mmdetection, but deploying it with mmdeploy fails. How can mmdeploy support a custom network structure? #210

@Bin-ze

Description


Thanks for your bug report. We appreciate it a lot.

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. I have read the FAQ documentation but cannot get the expected help.
  3. The bug has not been fixed in the latest version.

Describe the bug
I modified the YOLOX network. Both training and testing work without problems, but an error occurs when I deploy the model with mmdeploy.
Config:
_base_ = ['../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py']

img_scale = (512, 512)

# model settings

model = dict(
    type='YOLOX',
    input_size=img_scale,
    random_size_range=(15, 25),
    random_size_interval=10,
    backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
    neck=dict(
        type='YOLOXPAFPN',
        in_channels=[128, 256, 512],
        out_channels=128,
        num_csp_blocks=1),
    bbox_head=dict(
        type='YOLOXHead', num_classes=3, in_channels=128, feat_channels=128),
    semantic_head=dict(
        type='FusedSemanticHead_v1',
        num_ins=3,
        fusion_level=0,
        num_convs=2,
        in_channels=256,
        conv_out_channels=128),
    train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
    # In order to align the source code, the threshold of the val phase is
    # 0.01, and the threshold of the test phase is 0.001.
    test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))

# dataset settings

dataset_type = 'VOCDataset'
data_root = 'data/VOCdevkit/'

train_pipeline = [
    dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
    dict(
        type='RandomAffine',
        scaling_ratio_range=(0.1, 2),
        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
    dict(
        type='MixUp',
        img_scale=img_scale,
        ratio_range=(0.8, 1.6),
        pad_val=114.0),
    dict(type='YOLOXHSVRandomAug'),
    dict(type='RandomFlip', flip_ratio=0.5),
    # According to the official implementation, multi-scale
    # training is not considered here but in the
    # 'mmdet/models/detectors/yolox.py'.
    dict(type='Resize', img_scale=img_scale, keep_ratio=True),
    dict(
        type='Pad',
        pad_to_square=True,
        # If the image is three-channel, the pad value needs
        # to be set separately for each channel.
        pad_val=dict(img=(114.0, 114.0, 114.0))),
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]

train_dataset = dict(
    type='MultiImageMixDataset',
    dataset=dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/trainval.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        filter_empty_gt=False,
    ),
    pipeline=train_pipeline)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_scale,
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Pad',
                pad_to_square=True,
                pad_val=dict(img=(114.0, 114.0, 114.0))),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img'])
        ])
]

data = dict(
    samples_per_gpu=8,
    workers_per_gpu=4,
    persistent_workers=True,
    train=train_dataset,
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/trainval.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/val.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline))

# optimizer
# default 8 gpu

optimizer = dict(
    type='SGD',
    lr=0.0005,
    momentum=0.9,
    weight_decay=5e-4,
    nesterov=True,
    paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
optimizer_config = dict(grad_clip=None)

max_epochs = 100
num_last_epochs = 30
resume_from = None
interval = 1

# learning policy

lr_config = dict(
    _delete_=True,
    policy='YOLOX',
    warmup='exp',
    by_epoch=False,
    warmup_by_epoch=True,
    warmup_ratio=1,
    warmup_iters=5,  # 5 epoch
    num_last_epochs=num_last_epochs,
    min_lr_ratio=0.05)

runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)

custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        num_last_epochs=num_last_epochs,
        priority=48),
    dict(
        type='SyncNormHook',
        num_last_epochs=num_last_epochs,
        interval=interval,
        priority=48),
    dict(
        type='ExpMomentumEMAHook',
        resume_from=resume_from,
        momentum=0.0001,
        priority=49)
]
checkpoint_config = dict(interval=interval)
evaluation = dict(
    save_best='auto',
    # The evaluation interval is 'interval' when running epoch is
    # less than 'max_epochs - num_last_epochs'.
    # The evaluation interval is 1 when running epoch is greater than
    # or equal to 'max_epochs - num_last_epochs'.
    interval=interval,
    dynamic_intervals=[(max_epochs - num_last_epochs, 1)],
    metric='mAP')
log_config = dict(interval=50)
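Both pipelines above resize with keep_ratio=True to fit img_scale = (512, 512) and then pad to a square. The rough sketch below reproduces that shape arithmetic in plain Python, which helps when checking the exported model's input size. It is modeled on how keep-ratio rescaling is usually computed (the tighter of long-edge and short-edge scale factors) and is not copied from mmcv. within_dynamic_range is a hypothetical helper: its 320 to 1344 range is inferred only from the deploy config's file name (dynamic-320x320-1344x1344), not from the config's contents.

```python
"""Approximate output shape of Resize(keep_ratio=True) + Pad(pad_to_square=True)."""


def keep_ratio_resize(h, w, img_scale=(512, 512)):
    # Scale so the long edge fits max(img_scale) and the short edge fits
    # min(img_scale), whichever constraint is tighter; round to nearest int.
    long_edge, short_edge = max(img_scale), min(img_scale)
    factor = min(long_edge / max(h, w), short_edge / min(h, w))
    return int(h * factor + 0.5), int(w * factor + 0.5)


def pad_to_square(h, w):
    # Padding only grows the image: both sides become the longer side.
    side = max(h, w)
    return side, side


def within_dynamic_range(h, w, lo=320, hi=1344):
    # Hypothetical check; the range is assumed from the deploy config file name.
    return lo <= h <= hi and lo <= w <= hi


if __name__ == "__main__":
    resized = keep_ratio_resize(1000, 2000)
    padded = pad_to_square(*resized)
    print(resized, padded, within_dynamic_range(*padded))
```

For a 1000x2000 image, this gives a 256x512 resize and a 512x512 padded input, which falls inside the assumed range. Without the padding step, the 256-pixel side would fall below it.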

The error is as follows:

/project/train/src_repo/mmdeploy/mmdeploy/core/optimizers/function_marker.py:158: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  ys_shape = tuple(int(s) for s in ys.shape)
/project/train/src_repo/mmdeploy/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py:260: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  dets, labels = TRTBatchedNMSop.apply(boxes, scores, int(scores.shape[-1]),
/project/train/src_repo/mmdeploy/mmdeploy/mmcv/ops/nms.py:177: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  out_boxes = min(num_boxes, after_topk)
/project/train/src_repo/mmdeploy/mmdeploy/mmcv/ops/nms.py:181: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  (batch_size, out_boxes)).to(scores.device)
/usr/local/lib/python3.6/dist-packages/torch/tensor.py:590: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  'incorrect results).', category=RuntimeWarning)
2022-03-04:16:46:15,root ERROR [utils.py:41] not enough values to unpack (expected 2, got 1)
Traceback (most recent call last):
  File "/project/train/src_repo/mmdeploy/mmdeploy/utils/utils.py", line 36, in target_wrapper
    result = target(*args, **kwargs)
  File "/project/train/src_repo/mmdeploy/mmdeploy/apis/pytorch2onnx.py", line 96, in torch2onnx
    output_file=output_file)
  File "/project/train/src_repo/mmdeploy/mmdeploy/apis/pytorch2onnx.py", line 53, in torch2onnx_impl
    strip_doc_string=onnx_cfg.get('strip_doc_string', True))
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 276, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 94, in export
    use_external_data_format=use_external_data_format)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 701, in _export
    dynamic_axes=dynamic_axes)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 459, in _model_to_graph
    use_new_jit_passes)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 420, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 380, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 130, in forward
    self._force_outplace,
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 116, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 887, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 860, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/project/train/src_repo/mmdeploy/mmdeploy/core/rewriters/rewriter_utils.py", line 177, in wrapper
    return self.func(self, *args, **kwargs)
  File "/project/train/src_repo/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/base.py", line 69, in base_detector__forward
    return __forward_impl(ctx, self, img, img_metas=img_metas, **kwargs)
  File "/project/train/src_repo/mmdeploy/mmdeploy/core/optimizers/function_marker.py", line 247, in g
    rets = f(*args, **kwargs)
  File "/project/train/src_repo/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/base.py", line 28, in __forward_impl
    return self.simple_test(img, img_metas, **kwargs)
  File "/project/train/src_repo/mmdetection/mmdet/models/detectors/yolox.py", line 152, in simple_test
    for det_bboxes, det_labels in results_list
  File "/project/train/src_repo/mmdetection/mmdet/models/detectors/yolox.py", line 152, in
    for det_bboxes, det_labels in results_list
ValueError: not enough values to unpack (expected 2, got 1)
2022-03-04 16:46:16,123 - mmdeploy - ERROR - torch2onnx failed.
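The last frame fails at "for det_bboxes, det_labels in results_list" in the modified yolox.py. mmdeploy's rewritten base_detector__forward reaches that line by calling simple_test directly. One possible reading, which is an assumption and is not confirmed in this report, is that during export the bbox head returns a single batched (dets, labels) pair instead of a per-image list. In that case the loop would iterate over the batch dimension, and the earlier RuntimeWarning about iterating over a tensor would fit. The plain-Python sketch below uses nested lists in place of tensors and reproduces the same message under that assumption.

```python
"""Plain-Python reproduction of the ValueError at the end of the traceback."""


def per_image_results():
    # One (det_bboxes, det_labels) pair per image: the structure that
    # `for det_bboxes, det_labels in results_list` expects.
    return [([[10.0, 10.0, 50.0, 50.0, 0.9]], [0])]


def batched_results():
    # Hypothetical batched output: a single (dets, labels) pair for the whole
    # batch, each with a leading batch dimension of size 1.
    dets = [[[10.0, 10.0, 50.0, 50.0, 0.9]]]  # [batch=1, num_dets=1, 5]
    labels = [[0]]                            # [batch=1, num_dets=1]
    return (dets, labels)


def unpack(results_list):
    # Same loop shape as the failing line in yolox.py.
    return [(det_bboxes, det_labels) for det_bboxes, det_labels in results_list]


if __name__ == "__main__":
    print(len(unpack(per_image_results())))
    try:
        unpack(batched_results())
    except ValueError as err:
        # The first item is `dets`, a length-1 list, so unpacking into two
        # names fails with "not enough values to unpack (expected 2, got 1)".
        print(err)
```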

Reproduction

  1. What command or script did you run?

export TENSORRT_DIR=/project/train/src_repo/TensorRT-8.2.3.0
export LD_LIBRARY_PATH=$TENSORRT_DIR/lib:$TENSORRT_DIR
export CUDNN_DIR=/project/train/src_repo/cudnn-linux-x86_64-8.3.2.44_cuda10.2-archive
export LD_LIBRARY_PATH=${CUDNN_DIR}/lib:${LD_LIBRARY_PATH}
# yolox
PATH_TO_MMDET=/project/train/src_repo/mmdetection/
python tools/deploy.py \
    configs/mmdet/detection/detection_tensorrt-fp16_dynamic-320x320-1344x1344.py \
    $PATH_TO_MMDET/configs/_my_test/yolox.py \
    /project/train/models/yolox/epoch_5.pth \
    $PATH_TO_MMDET/data/VOCdevkit/VOC2007/JPEGImages/white_glove_identification_kitchen_behind_none_train_p_day_20210818_4003.jpg \
    --work-dir /project/train/models/work_dir_yolox \
    --show --device cuda:0 --dump-info

  2. Did you make any modifications on the code or config? Did you understand what you have modified?
    I added a custom head to the network (semantic_head, type FusedSemanticHead_v1 in the config above). The output of this head is combined with the predicted feature maps through attention weighting.
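The report does not include the code of FusedSemanticHead_v1 or of the fusion step. The sketch below is for illustration only. It assumes the fusion is an element-wise sigmoid gate, which is a common pattern but not necessarily what this head does, and it uses nested lists in place of tensors to show how a head's output can weight a feature map.

```python
"""Hypothetical attention-style fusion of a head output with a feature map."""
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def attention_weight(feature_map, head_logits):
    # Multiply every feature value by sigmoid(logit) at the same position.
    # feature_map and head_logits are 2-D lists of equal shape (H x W).
    return [
        [f * sigmoid(a) for f, a in zip(f_row, a_row)]
        for f_row, a_row in zip(feature_map, head_logits)
    ]


if __name__ == "__main__":
    fused = attention_weight([[2.0, 4.0]], [[0.0, 0.0]])
    print(fused)  # [[1.0, 2.0]]
```

A gate like this changes feature values only. On its own, it does not change the (bboxes, labels) structure that the detector's test path returns.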

Environment

2022-03-04 16:44:18,418 - mmdeploy - INFO -

2022-03-04 16:44:18,418 - mmdeploy - INFO - Environmental information
2022-03-04 16:44:20,573 - mmdeploy - INFO - sys.platform: linux
2022-03-04 16:44:20,573 - mmdeploy - INFO - Python: 3.6.9 (default, Dec 8 2021, 21:08:43) [GCC 8.4.0]
2022-03-04 16:44:20,574 - mmdeploy - INFO - CUDA available: True
2022-03-04 16:44:20,574 - mmdeploy - INFO - GPU 0: GeForce RTX 2080 Ti
2022-03-04 16:44:20,574 - mmdeploy - INFO - CUDA_HOME: /usr/local/cuda
2022-03-04 16:44:20,574 - mmdeploy - INFO - NVCC: Cuda compilation tools, release 10.1, V10.1.243
2022-03-04 16:44:20,574 - mmdeploy - INFO - GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
2022-03-04 16:44:20,574 - mmdeploy - INFO - PyTorch: 1.8.1+cu102
2022-03-04 16:44:20,574 - mmdeploy - INFO - PyTorch compiling details: PyTorch built with:

  • GCC 7.3
  • C++ Version: 201402
  • Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v1.7.0 (Git Hash 7aed236906b1f7a05c0917e5257a1af05e9ff683)
  • OpenMP 201511 (a.k.a. OpenMP 4.5)
  • NNPACK is enabled
  • CPU capability usage: AVX2
  • CUDA Runtime 10.2
  • NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70
  • CuDNN 7.6.5
  • Magma 2.5.2
  • Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=10.2, CUDNN_VERSION=7.6.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.8.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON,

2022-03-04 16:44:20,574 - mmdeploy - INFO - TorchVision: 0.9.1+cu102
2022-03-04 16:44:20,574 - mmdeploy - INFO - OpenCV: 4.5.2
2022-03-04 16:44:20,574 - mmdeploy - INFO - MMCV: 1.4.5
2022-03-04 16:44:20,574 - mmdeploy - INFO - MMCV Compiler: GCC 7.3
2022-03-04 16:44:20,574 - mmdeploy - INFO - MMCV CUDA Compiler: 10.2
2022-03-04 16:44:20,574 - mmdeploy - INFO - MMDeployment: 0.2.0+925bd02
2022-03-04 16:44:20,574 - mmdeploy - INFO -

2022-03-04 16:44:20,574 - mmdeploy - INFO - Backend information

