
Pytorch function triu rewriting for ONNX is not correct when diagonal parameter is used #791

@Antoine-Prieur

Description

Describe the bug

The rewrite of the PyTorch function `triu` for ONNX is incorrect when a model calls this function with a `diagonal` value other than 0: no error is raised, but the output is wrong.
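To illustrate the semantics at stake (plain Python, independent of mmdeploy): `triu` keeps entry $(i, j)$ iff $j \ge i + \text{diagonal}$, so with `diagonal=1` the main diagonal is zeroed as well, and a rewrite that ignores `diagonal` silently returns a different matrix.

```python
# Pure-Python reference for torch.triu semantics:
# entry (i, j) is kept iff j >= i + diagonal.
def triu_ref(matrix, diagonal=0):
    return [[v if j >= i + diagonal else 0
             for j, v in enumerate(row)]
            for i, row in enumerate(matrix)]

x = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]

# diagonal=0 keeps the main diagonal; diagonal=1 also zeroes it.
print(triu_ref(x, 0))  # [[1, 2, 3], [0, 5, 6], [0, 0, 9]]
print(triu_ref(x, 1))  # [[0, 2, 3], [0, 0, 6], [0, 0, 0]]
```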

Reproduction

To reproduce, convert a model that uses `triu` with $\text{diagonal} \ne 0$. I used the SATRN model, whose decoder `NRTRDecoder` calls `triu`, and compared the model's output before and after `with RewriterContext(**context_info)` (which rewrites the functions) in the file `mmdeploy/apis/onnx/export.py`.

Bug fix

We just have to change the rewrite of the `triu` function in `mmdeploy/pytorch/functions/triu.py` to support `diagonal`:

```python
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(func_name='torch.triu')
def triu(ctx,
         input: torch.Tensor,
         diagonal: int = 0,
         *args,
         **kwargs) -> torch.Tensor:
    """Rewrite `triu` for exporting model to ONNX."""
    assert len(input.shape) >= 2
    height, width = input.shape[-2:]
    arange0 = torch.arange(width, device=input.device).unsqueeze(0)
    arange1 = torch.arange(height, device=input.device).unsqueeze(-1)
    mask = arange0 >= torch.add(arange1, diagonal)
    return input * mask
```
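The mask construction can be sanity-checked without running an export: the broadcast `arange0 >= arange1 + diagonal` keeps exactly the entries with column index $\ge$ row index $+$ `diagonal`, which is the definition of `triu`. A minimal pure-Python mirror of the broadcasting logic (illustrative only, no torch dependency):

```python
def triu_mask(height, width, diagonal=0):
    # Mirrors the broadcast in the rewriter: arange0 is a row of column
    # indices, arange1 a column of row indices; the comparison keeps
    # entry (i, j) iff j >= i + diagonal.
    arange0 = list(range(width))
    arange1 = list(range(height))
    return [[arange0[j] >= arange1[i] + diagonal for j in range(width)]
            for i in range(height)]

# diagonal=1 zeroes the main diagonal as well.
assert triu_mask(3, 3, 1) == [[False, True, True],
                              [False, False, True],
                              [False, False, False]]
```

Multiplying the input by this boolean mask (implicitly cast to the input dtype) reproduces `torch.triu` for any `diagonal`, positive or negative, while staying within operators that export cleanly to ONNX.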
