
[blockChooser.cpp::getRegionBlockSize::690] Error Code 2: Internal Error (Assertion memSize >= 0 failed. ) #2045

@AllentDan

Description


I encountered the following error:

[blockChooser.cpp::getRegionBlockSize::690] Error Code 2: Internal Error (Assertion memSize >= 0 failed. )

Environment

TensorRT Version: 8+
NVIDIA GPU: 1660
NVIDIA Driver Version: 470
CUDA Version: 11.3
CUDNN Version: compatible with CUDA 11.3
Operating System: linux x86
Python Version (if applicable): 3.8
PyTorch Version (if applicable): 1.10

Steps To Reproduce

import torch
import onnx
import tensorrt as trt

onnx_model = 'model.onnx'


class NaiveModel(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        B, C, H, W = x.shape

        pad_w = W % 7
        pad_h = H % 7
        x_t = torch.zeros((B, C, H + pad_h, W + pad_w), device=x.device)
        x_t[:, :, :H, :W] = x
        return x_t


device = torch.device('cuda:0')

# generate ONNX model
torch.onnx.export(
    NaiveModel(),
    torch.randn(1, 3, 224, 224),
    onnx_model,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch', 2: 'height', 3: 'width'},
        'output': {0: 'batch'}
    },
    opset_version=11)
onnx_model = onnx.load(onnx_model)

# load_tensorrt_plugin()
# create builder and network
logger = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(logger)
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(EXPLICIT_BATCH)

# parse onnx
parser = trt.OnnxParser(network, logger)

if not parser.parse(onnx_model.SerializeToString()):
    error_msgs = ''
    for error in range(parser.num_errors):
        error_msgs += f'{parser.get_error(error)}\n'
    raise RuntimeError(f'Failed to parse onnx, {error_msgs}')

config = builder.create_builder_config()
config.max_workspace_size = 1 << 30
profile = builder.create_optimization_profile()

profile.set_shape('input', [1, 3, 112, 112], [1, 3, 224, 224],
                  [1, 3, 512, 512])
config.add_optimization_profile(profile)
# create engine (the memSize assertion above is raised during this build step)
with torch.cuda.device(device):
    engine = builder.build_engine(network, config)

with open('model.engine', mode='wb') as f:
    f.write(bytearray(engine.serialize()))
    print("generating file done!")
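For reference, the model itself runs fine in eager PyTorch; the failure only appears when building the TensorRT engine. A quick sanity check of the padding behavior (the 230x230 input is my own example, not from the original report):

```python
import torch


# NaiveModel from the repro above: pads H and W by (H % 7) and (W % 7).
class NaiveModel(torch.nn.Module):

    def forward(self, x):
        B, C, H, W = x.shape
        pad_w = W % 7
        pad_h = H % 7
        # Allocate a zero tensor of the padded shape, then copy x into it.
        x_t = torch.zeros((B, C, H + pad_h, W + pad_w), device=x.device)
        x_t[:, :, :H, :W] = x
        return x_t


model = NaiveModel()
# 224 is divisible by 7, so no padding is added.
print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 3, 224, 224])
# 230 % 7 == 6, so H and W each grow by 6.
print(model(torch.randn(1, 3, 230, 230)).shape)  # torch.Size([1, 3, 236, 236])
```

Note that `W % 7` pads by the remainder rather than up to the next multiple of 7; either way, the eager model produces valid shapes, so the `memSize >= 0` assertion is specific to TensorRT's handling of the dynamically shaped `zeros` allocation.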

Metadata

Labels: Feature Request (Request for new functionality), triaged (Issue has been triaged by maintainers)
