Describe the bug
The rewriting of the PyTorch function `triu` for ONNX is incorrect when a model calls this function with a `diagonal` value other than 0: no error is thrown, but the output is wrong.
Reproduction
Convert a model whose decoder (e.g. NRTRDecoder) uses `triu`, and compare the model's output before and after `with RewriterContext(**context_info)` (which applies the function rewrites) in the file mmdeploy/apis/onnx/export.py.
Bug fix
We just have to change the rewriting of the `triu` function in mmdeploy/pytorch/functions/triu.py to support `diagonal`:
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(func_name='torch.triu')
def triu(ctx,
         input: torch.Tensor,
         diagonal: int = 0,
         *args,
         **kwargs) -> torch.Tensor:
    """Rewrite `triu` for exporting model to ONNX."""
    assert len(input.shape) >= 2
    height, width = input.shape[-2:]
    # Keep elements where column index >= row index + diagonal,
    # matching the semantics of torch.triu(input, diagonal).
    arange0 = torch.arange(width, device=input.device).unsqueeze(0)
    arange1 = torch.arange(height, device=input.device).unsqueeze(-1)
    mask = arange0 >= torch.add(arange1, diagonal)
    return input * mask
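To sanity-check the mask-based rewrite above, one can compare it against `torch.triu` directly for several `diagonal` values. This is a minimal standalone sketch (it does not use mmdeploy, and `triu_mask` is a hypothetical helper name mirroring the rewriter body):

```python
import torch


def triu_mask(input: torch.Tensor, diagonal: int = 0) -> torch.Tensor:
    # Same logic as the proposed rewriter: build an upper-triangular
    # mask from two aranges, keeping elements where col >= row + diagonal.
    assert input.dim() >= 2
    height, width = input.shape[-2:]
    arange0 = torch.arange(width, device=input.device).unsqueeze(0)
    arange1 = torch.arange(height, device=input.device).unsqueeze(-1)
    mask = arange0 >= arange1 + diagonal
    return input * mask


x = torch.randn(4, 5)
for d in (-2, -1, 0, 1, 2):
    # The rewrite should agree with torch.triu for every diagonal offset.
    assert torch.equal(triu_mask(x, d), torch.triu(x, d))
print('ok')
```

With the original rewrite (which ignored `diagonal`), this check fails for every `d != 0`, which is exactly the silent wrong-output behavior described above.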