TensorRT deploy error after upgrading version #460

@austinmw

Hi, I tried to upgrade the GPU Dockerfile to use the TensorRT 21.08 container (nvcr.io/nvidia/tensorrt:21.08-py3) so that it matches my Triton Inference Server container version. Before the upgrade I could export my YOLOX model to an end2end.engine file successfully, but after the upgrade I get the following error:

[2022-05-10 18:38:06.082] [mmdeploy] [info] [model.cpp:95] Register 'DirectoryModel'
2022-05-10 18:38:06,147 - mmdeploy - INFO - torch2onnx start.
[2022-05-10 18:38:07.800] [mmdeploy] [info] [model.cpp:95] Register 'DirectoryModel'
2022-05-10:18:38:08,matplotlib.font_manager INFO [font_manager.py:1443] generated new fontManager
load checkpoint from local path: /volume_share/checkpoint.pth
The model and loaded state dict do not match exactly

unexpected key in source state_dict: ema_backbone_stem_conv_conv_weight, ema_backbone_stem_conv_bn_weight,
... (shortened by me to reduce log length)

2022-05-10 18:38:13,432 - mmdeploy - WARNING - DeprecationWarning: get_onnx_config will be deprecated in the future.
2022-05-10:18:38:13,mmdeploy WARNING [utils.py:91] DeprecationWarning: get_onnx_config will be deprecated in the future.
/root/workspace/mmdeploy/mmdeploy/core/optimizers/function_marker.py:158: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
ys_shape = tuple(int(s) for s in ys.shape)
/opt/conda/lib/python3.8/site-packages/torch/functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /opt/conda/conda-bld/pytorch_1646755903507/work/aten/src/ATen/native/TensorShape.cpp:2228.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py:260: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
dets, labels = TRTBatchedNMSop.apply(boxes, scores, int(scores.shape[-1]),
/root/workspace/mmdeploy/mmdeploy/mmcv/ops/nms.py:177: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
out_boxes = min(num_boxes, after_topk)
WARNING: The shape inference of mmdeploy::TRTBatchedNMS type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
WARNING: The shape inference of mmdeploy::TRTBatchedNMS type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
WARNING: The shape inference of mmdeploy::TRTBatchedNMS type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
WARNING: The shape inference of mmdeploy::TRTBatchedNMS type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
WARNING: The shape inference of mmdeploy::TRTBatchedNMS type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
WARNING: The shape inference of mmdeploy::TRTBatchedNMS type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
2022-05-10 18:38:24,320 - mmdeploy - INFO - torch2onnx success.
2022-05-10 18:38:24,323 - mmdeploy - INFO - onnx2tensorrt of /volume_share/end2end.onnx start.
[2022-05-10 18:38:25.940] [mmdeploy] [info] [model.cpp:95] Register 'DirectoryModel'
2022-05-10 18:38:25,978 - mmdeploy - INFO - Successfully loaded tensorrt plugins from /root/workspace/mmdeploy/build/lib/libmmdeploy_tensorrt_ops.so
2022-05-10:18:38:25,mmdeploy INFO [init_plugins.py:32] Successfully loaded tensorrt plugins from /root/workspace/mmdeploy/build/lib/libmmdeploy_tensorrt_ops.so
[TensorRT] INFO: [MemUsageChange] Init CUDA: CPU +249, GPU +0, now: CPU 318, GPU 481 (MiB)
[TensorRT] WARNING: onnx2trt_utils.cpp:362: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[TensorRT] WARNING: onnx2trt_utils.cpp:390: One or more weights outside the range of INT32 was clamped
[TensorRT] WARNING: onnx2trt_utils.cpp:390: One or more weights outside the range of INT32 was clamped
[TensorRT] WARNING: onnx2trt_utils.cpp:390: One or more weights outside the range of INT32 was clamped
[TensorRT] WARNING: onnx2trt_utils.cpp:390: One or more weights outside the range of INT32 was clamped
[TensorRT] WARNING: onnx2trt_utils.cpp:390: One or more weights outside the range of INT32 was clamped
[TensorRT] WARNING: onnx2trt_utils.cpp:390: One or more weights outside the range of INT32 was clamped
[TensorRT] WARNING: onnx2trt_utils.cpp:390: One or more weights outside the range of INT32 was clamped
[TensorRT] WARNING: onnx2trt_utils.cpp:390: One or more weights outside the range of INT32 was clamped
[TensorRT] WARNING: onnx2trt_utils.cpp:390: One or more weights outside the range of INT32 was clamped
[TensorRT] WARNING: onnx2trt_utils.cpp:390: One or more weights outside the range of INT32 was clamped
[TensorRT] WARNING: onnx2trt_utils.cpp:390: One or more weights outside the range of INT32 was clamped
[TensorRT] INFO: No importer registered for op: TRTBatchedNMS. Attempting to import as plugin.
[TensorRT] INFO: Searching for plugin: TRTBatchedNMS, plugin_version: 1, plugin_namespace:
[TensorRT] INFO: Successfully created plugin: TRTBatchedNMS
[TensorRT] INFO: [MemUsageSnapshot] Builder begin: CPU 2119 MiB, GPU 1215 MiB
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +347, GPU +160, now: CPU 2474, GPU 1375 (MiB)
[TensorRT] INFO: [MemUsageChange] Init cuDNN: CPU +273, GPU +168, now: CPU 2747, GPU 1543 (MiB)
[TensorRT] WARNING: Detected invalid timing cache, setup a local cache instead
[TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 2747, GPU 1525 (MiB)
[TensorRT] ERROR: 4: [shapeCompiler.cpp::evaluateShapeChecks::822] Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: condition '==' violated. Concat_351: dimensions not compatible for concatenation)
2022-05-10:18:38:31,root ERROR [utils.py:43] Failed to create TensorRT engine
Traceback (most recent call last):
File "/root/workspace/mmdeploy/mmdeploy/utils/utils.py", line 38, in target_wrapper
result = target(*args, **kwargs)
File "/root/workspace/mmdeploy/mmdeploy/backend/tensorrt/onnx2tensorrt.py", line 75, in onnx2tensorrt
engine = create_trt_engine(
File "/root/workspace/mmdeploy/mmdeploy/backend/tensorrt/utils.py", line 116, in create_trt_engine
assert engine is not None, 'Failed to create TensorRT engine'
AssertionError: Failed to create TensorRT engine
2022-05-10 18:38:32,391 - mmdeploy - ERROR - onnx2tensorrt of /volume_share/end2end.onnx failed.
checkpoint.pth end2end.onnx
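Because the error only appeared after switching base images, it may help to post the exact package versions inside the container together with the log. The following is a minimal sketch; the package list is my own choice (an assumption, not something mmdeploy requires), and any package that cannot be imported is reported as missing:

```python
import importlib

# Packages on the torch -> ONNX -> TensorRT path. This list is an assumption;
# adjust it to whatever is installed in your image.
PACKAGES = ["torch", "onnx", "onnxruntime", "tensorrt", "mmcv", "mmdet", "mmdeploy"]


def collect_versions(packages):
    """Return {package: version}; None if not importable, "unknown" if no __version__."""
    versions = {}
    for name in packages:
        try:
            module = importlib.import_module(name)
        except ImportError:
            versions[name] = None
            continue
        versions[name] = getattr(module, "__version__", "unknown")
    return versions


if __name__ == "__main__":
    for name, version in collect_versions(PACKAGES).items():
        print(f"{name:12s} {version if version is not None else 'not installed'}")
```

Run it with the conda Python (/opt/conda/bin/python) so that it sees the same TensorRT bindings that were copied into conda's site-packages.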

Any help resolving this error would be greatly appreciated:
Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: condition '==' violated. Concat_351: dimensions not compatible for concatenation)
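The message says that the kOPT (opt_shape) values of optimization profile 0 fail an equality check at a Concat node. One plausible cause, which is an assumption and not confirmed for this model, is an opt height or width that makes a 2x-upsampled feature map differ in size from the finer map it is concatenated with in an FPN/PAFPN-style neck. The sketch below checks a profile for that condition. It uses a simplified model (five stride-2 stages, 3x3 convs with padding 1, the last three levels fused) rather than YOLOX's exact architecture, and the dict mirrors the min_shape/opt_shape/max_shape keys used in mmdeploy's TensorRT deploy configs:

```python
# Simplified neck model (an assumption, not YOLOX's exact architecture):
# each backbone stage halves H/W with a 3x3, stride-2, padding-1 conv, and the
# neck upsamples each coarser map by 2 and concatenates it with the finer one.


def downsample(size: int) -> int:
    """Output size of a 3x3 conv with stride 2 and padding 1."""
    return (size + 2 * 1 - 3) // 2 + 1


def concat_compatible(size: int, num_stages: int = 5, fused_levels: int = 3) -> bool:
    """True if every 2x-upsampled map matches the finer map it is concatenated with."""
    sizes = [size]
    for _ in range(num_stages):
        sizes.append(downsample(sizes[-1]))
    # The neck fuses the last `fused_levels` maps (e.g. strides 8, 16 and 32).
    fused = sizes[-fused_levels:]
    return all(2 * coarse == fine for fine, coarse in zip(fused, fused[1:]))


def check_profile(profile: dict) -> list:
    """Return (shape_name, dim_index, value) for each H/W that would break a concat."""
    problems = []
    for name in ("min_shape", "opt_shape", "max_shape"):
        shape = profile[name]
        for dim in (2, 3):  # NCHW layout: index 2 is height, index 3 is width
            if not concat_compatible(shape[dim]):
                problems.append((name, dim, shape[dim]))
    return problems


if __name__ == "__main__":
    # Hypothetical profile: an opt width of 1333 is not a multiple of 32.
    profile = dict(min_shape=[1, 3, 320, 320],
                   opt_shape=[1, 3, 800, 1333],
                   max_shape=[1, 3, 1344, 1344])
    print(check_profile(profile))  # [('opt_shape', 3, 1333)]
```

In this simplified model, sizes that are multiples of 32 (such as 320, 800 and 1344) always pass. If the deploy config in use was customised, comparing its opt_shape against this check is a cheap first step before digging into the TensorRT version.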


Here's a copy of my updated Dockerfile for reference:

# Use this version to match the Triton Inference Server container
FROM nvcr.io/nvidia/tensorrt:21.08-py3
    
ARG CUDA=11.3
ARG PYTHON_VERSION=3.8
ARG TORCH_VERSION=1.11.0
ARG TORCHVISION_VERSION=0.12.0
ARG ONNXRUNTIME_VERSION=1.11.1
ARG MMCV_VERSION=1.5.0
ARG PPLCV_VERSION=0.6.3
ENV FORCE_CUDA="1"

ENV DEBIAN_FRONTEND=noninteractive

### change the system source for installing libs
ARG USE_SRC_INSIDE=false
RUN if [ "${USE_SRC_INSIDE}" = true ] ; \
    then \
        sed -i s/archive.ubuntu.com/mirrors.aliyun.com/g /etc/apt/sources.list ; \
        sed -i s/security.ubuntu.com/mirrors.aliyun.com/g /etc/apt/sources.list ; \
        echo "Use aliyun source for installing libs" ; \
    else \
        echo "Keep the download source unchanged" ; \
    fi

### update apt and install libs
RUN apt-get update &&\
    apt-get install -y vim libsm6 libxext6 libxrender-dev libgl1-mesa-glx git wget libssl-dev libopencv-dev libspdlog-dev --no-install-recommends &&\
    rm -rf /var/lib/apt/lists/*

RUN curl -fsSL -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    chmod +x ~/miniconda.sh && \
    ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda install -y python=${PYTHON_VERSION} conda-build pyyaml numpy ipython cython typing typing_extensions mkl mkl-include ninja && \
    /opt/conda/bin/conda clean -ya

### pytorch
RUN /opt/conda/bin/conda install pytorch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} cudatoolkit=${CUDA} -c pytorch
ENV PATH /opt/conda/bin:$PATH

### install mmcv-full
RUN /opt/conda/bin/pip install mmcv-full==${MMCV_VERSION} -f https://download.openmmlab.com/mmcv/dist/cu${CUDA//./}/torch${TORCH_VERSION}/index.html

WORKDIR /root/workspace
### get onnxruntime
RUN wget https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}.tgz \
    && tar -zxvf onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}.tgz &&\
    pip install onnxruntime-gpu==${ONNXRUNTIME_VERSION}

### cp trt from pip to conda
RUN cp -r /usr/local/lib/python${PYTHON_VERSION}/dist-packages/tensorrt* /opt/conda/lib/python${PYTHON_VERSION}/site-packages/

### install mmdeploy
ENV ONNXRUNTIME_DIR=/root/workspace/onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}
ENV TENSORRT_DIR=/workspace/tensorrt
ARG VERSION
RUN git clone https://github.com/open-mmlab/mmdeploy &&\
    cd mmdeploy &&\
    if [ -z ${VERSION} ] ; then echo "No MMDeploy version passed in, building on master" ; else git checkout tags/v${VERSION} -b tag_v${VERSION} ; fi &&\
    git submodule update --init --recursive &&\
    mkdir -p build &&\
    cd build &&\
    cmake -DMMDEPLOY_TARGET_BACKENDS="ort;trt" .. &&\
    make -j$(nproc) &&\
    cd .. &&\
    pip install -e .

### build sdk
RUN git clone https://github.com/openppl-public/ppl.cv.git &&\
    cd ppl.cv &&\
    git checkout tags/v${PPLCV_VERSION} -b v${PPLCV_VERSION} &&\
    ./build.sh cuda

ENV BACKUP_LD_LIBRARY_PATH=$LD_LIBRARY_PATH
ENV LD_LIBRARY_PATH=/usr/local/cuda-11.4/compat/lib.real/:$LD_LIBRARY_PATH

RUN cd /root/workspace/mmdeploy &&\
    rm -rf build/CM* build/cmake-install.cmake build/Makefile build/csrc &&\
    mkdir -p build && cd build &&\
    cmake .. \
        -DMMDEPLOY_BUILD_SDK=ON \
        -DCMAKE_CXX_COMPILER=g++ \
        -Dpplcv_DIR=/root/workspace/ppl.cv/cuda-build/install/lib/cmake/ppl \
        -DTENSORRT_DIR=${TENSORRT_DIR} \
        -DONNXRUNTIME_DIR=${ONNXRUNTIME_DIR} \
        -DMMDEPLOY_BUILD_SDK_PYTHON_API=ON \
        -DMMDEPLOY_TARGET_DEVICES="cuda;cpu" \
        -DMMDEPLOY_TARGET_BACKENDS="ort;trt" \
        -DMMDEPLOY_CODEBASES=all &&\
    make -j$(nproc) && make install &&\
    cd install/example  && mkdir -p build && cd build &&\
    cmake -DMMDeploy_DIR=/root/workspace/mmdeploy/build/install/lib/cmake/MMDeploy .. &&\
    make -j$(nproc) && export SPDLOG_LEVEL=warn &&\
    if [ -z ${VERSION} ] ; then echo "Built MMDeploy master for GPU devices successfully!" ; else echo "Built MMDeploy version v${VERSION} for GPU devices successfully!" ; fi

ENV LD_LIBRARY_PATH="/root/workspace/mmdeploy/build/lib:${BACKUP_LD_LIBRARY_PATH}"
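To tell whether the failure comes from mmdeploy's conversion wrapper or from TensorRT itself, the onnx2tensorrt step can be repeated directly with the TensorRT Python API inside this image. This is a minimal sketch under several assumptions: the TensorRT 8.0 Python bindings shipped in the 21.08 container, an ONNX input named "input" (verify the real name with onnx or Netron), and that loading libmmdeploy_tensorrt_ops.so with ctypes is enough to register mmdeploy's custom plugins such as TRTBatchedNMS. The function names are hypothetical:

```python
import ctypes


def validate_profile(min_shape, opt_shape, max_shape):
    """TensorRT requires min <= opt <= max in every dimension of a profile."""
    if not len(min_shape) == len(opt_shape) == len(max_shape):
        raise ValueError("profile shapes must have the same rank")
    for lo, opt, hi in zip(min_shape, opt_shape, max_shape):
        if not lo <= opt <= hi:
            raise ValueError(f"expected min <= opt <= max, got {lo}, {opt}, {hi}")


def build_engine(onnx_path, plugin_lib, input_name, min_shape, opt_shape, max_shape):
    """Parse an ONNX file and build an engine with a single optimization profile."""
    import tensorrt as trt  # lazy import so validate_profile works without TensorRT

    validate_profile(min_shape, opt_shape, max_shape)
    ctypes.CDLL(plugin_lib)  # assumption: loading the .so registers mmdeploy plugins
    logger = trt.Logger(trt.Logger.VERBOSE)
    trt.init_libnvinfer_plugins(logger, "")

    builder = trt.Builder(logger)
    flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flags)
    parser = trt.OnnxParser(network, logger)
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
            raise RuntimeError("ONNX parse failed:\n" + "\n".join(errors))

    config = builder.create_builder_config()
    config.max_workspace_size = 1 << 30  # 1 GiB; attribute from the TensorRT 8.0 API
    profile = builder.create_optimization_profile()
    # input_name is assumed to be "input" for this export; verify before use.
    profile.set_shape(input_name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)
    return builder.build_engine(network, config)  # returns None if the build fails
```

For example, call build_engine("/volume_share/end2end.onnx", "/root/workspace/mmdeploy/build/lib/libmmdeploy_tensorrt_ops.so", "input", ...) with the same min/opt/max shapes as the deploy config. If the VERBOSE log shows the same Concat_351 error, the problem most likely lies in the ONNX graph or in TensorRT 8.0 rather than in mmdeploy's wrapper.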
