@justinchuby
Collaborator

@justinchuby justinchuby commented Jan 18, 2025

Use `MatMul` when the input is not rank 2, to avoid decomposing `linear` to `addmm`.
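As a rough illustration of the rank-dependent behaviour described here, the following NumPy sketch mirrors the two graphs shown in the export output later in this thread: a single `Gemm` for a rank-2 input, and `Transpose`, `MatMul`, `Add` otherwise. The function `linear_lowering` is a hypothetical name used only for this example; it is not the torchlib implementation, and it assumes a bias is always present.

```python
import numpy as np


def linear_lowering(x, weight, bias):
    """Mimic the rank-dependent lowering of linear(x, weight, bias).

    Returns the ONNX-style op names that would appear in the graph,
    together with the numerical result computed with NumPy.
    """
    if x.ndim == 2:
        # Rank 2: one Gemm with transB=1 computes x @ weight.T + bias.
        return ["Gemm"], x @ weight.T + bias

    # Any other rank: transpose the weight, use a broadcasting MatMul,
    # then add the bias.
    return ["Transpose", "MatMul", "Add"], np.matmul(x, weight.T) + bias


rng = np.random.default_rng(0)
weight = rng.standard_normal((10, 10)).astype(np.float32)
bias = rng.standard_normal(10).astype(np.float32)

x_2d = rng.standard_normal((1, 10)).astype(np.float32)
x_4d = rng.standard_normal((1, 12, 15, 10)).astype(np.float32)

ops_2d, y_2d = linear_lowering(x_2d, weight, bias)
ops_4d, y_4d = linear_lowering(x_4d, weight, bias)
print(ops_2d, y_2d.shape)
print(ops_4d, y_4d.shape)
```

Both branches compute the same mathematical function; only the choice of operators differs, since `Gemm` accepts rank-2 inputs only while `MatMul` broadcasts over leading dimensions.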

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Jan 18, 2025
@justinchuby justinchuby enabled auto-merge (squash) January 18, 2025 01:02
@justinchuby
Collaborator Author

justinchuby commented Jan 18, 2025

import torch


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)


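model = TestModule()
# Rank-2 input: the exported graph is a single Gemm node.
ep = torch.onnx.export(model, (torch.randn(1, 10),), dynamo=True, verify=True)
print(ep)
# Rank-4 input: the exported graph is Transpose -> MatMul -> Add.
ep = torch.onnx.export(model, (torch.randn(1, 12, 15, 10),), dynamo=True, verify=True)
print(ep)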
[torch.onnx] Obtain model graph for `TestModule([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `TestModule([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Check the ONNX model...
[torch.onnx] Check the ONNX model... ✅
[torch.onnx] Execute the model with ONNX Runtime...
[torch.onnx] Execute the model with ONNX Runtime... ✅
[torch.onnx] Verify output accuracy...
[torch.onnx] Verify output accuracy... ✅
ONNXProgram(
    model=
        <
            ir_version=10,
            opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18},
            producer_name='pytorch',
            producer_version='2.7.0.dev20250115+cu124',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"x"<FLOAT,[1,10]>
            ),
            outputs=(
                %"linear"<FLOAT,[1,10]>
            ),
            initializers=(
                %"linear.weight"<FLOAT,[10,10]>,
                %"linear.bias"<FLOAT,[10]>
            ),
        ) {
            0 |  # node_Gemm_0
                 %"linear"<FLOAT,[1,10]> ⬅️ ::Gemm(%"x", %"linear.weight", %"linear.bias") {beta=1.0, transB=True, alpha=1.0, transA=0}
            return %"linear"<FLOAT,[1,10]>
        }

        <
            opset_imports={'': 18},
        >
        def pkg.onnxscript.torch_lib.common::Rank(
            inputs=(
                %"input"<?,?>
            ),
            outputs=(
                %"return_val"<?,?>
            ),
        ) {
            0 |  # n0
                 %"tmp"<?,?> ⬅️ ::Shape(%"input")
            1 |  # n1
                 %"return_val"<?,?> ⬅️ ::Size(%"tmp")
            return %"return_val"<?,?>
        }

        <
            opset_imports={'': 18},
        >
        def pkg.onnxscript.torch_lib.common::IsScalar(
            inputs=(
                %"input"<?,?>
            ),
            outputs=(
                %"return_val"<?,?>
            ),
        ) {
            0 |  # n0
                 %"tmp"<?,?> ⬅️ ::Shape(%"input")
            1 |  # n1
                 %"tmp_0"<?,?> ⬅️ ::Size(%"tmp")
            2 |  # n2
                 %"tmp_1"<?,?> ⬅️ ::Constant() {value_int=0}
            3 |  # n3
                 %"return_val"<?,?> ⬅️ ::Equal(%"tmp_0", %"tmp_1")
            return %"return_val"<?,?>
        }
    ,
    exported_program=
        ExportedProgram:
            class GraphModule(torch.nn.Module):
                def forward(self, p_linear_weight: "f32[10, 10]", p_linear_bias: "f32[10]", x: "f32[1, 10]"):
                     # File: /home/justinchu/anaconda3/envs/onnx/lib/python3.13/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
                    linear: "f32[1, 10]" = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias);  x = p_linear_weight = p_linear_bias = None
                    return (linear,)
            
        Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_weight'), target='linear.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_bias'), target='linear.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear'), target=None)])
        Range constraints: {}

)

[torch.onnx] Obtain model graph for `TestModule([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `TestModule([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Check the ONNX model...
[torch.onnx] Check the ONNX model... ✅
[torch.onnx] Execute the model with ONNX Runtime...
[torch.onnx] Execute the model with ONNX Runtime... ✅
[torch.onnx] Verify output accuracy...
[torch.onnx] Verify output accuracy... ✅
ONNXProgram(
    model=
        <
            ir_version=10,
            opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18},
            producer_name='pytorch',
            producer_version='2.7.0.dev20250115+cu124',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"x"<FLOAT,[1,12,15,10]>
            ),
            outputs=(
                %"linear"<FLOAT,[1,12,15,10]>
            ),
            initializers=(
                %"linear.weight"<FLOAT,[10,10]>,
                %"linear.bias"<FLOAT,[10]>
            ),
        ) {
            0 |  # node_Transpose_0
                 %"val_0"<?,?> ⬅️ ::Transpose(%"linear.weight") {perm=[1, 0]}
            1 |  # node_MatMul_1
                 %"val_1"<?,?> ⬅️ ::MatMul(%"x", %"val_0")
            2 |  # node_Add_2
                 %"linear"<FLOAT,[1,12,15,10]> ⬅️ ::Add(%"val_1", %"linear.bias")
            return %"linear"<FLOAT,[1,12,15,10]>
        }

        <
            opset_imports={'': 18},
        >
        def pkg.onnxscript.torch_lib.common::Rank(
            inputs=(
                %"input"<?,?>
            ),
            outputs=(
                %"return_val"<?,?>
            ),
        ) {
            0 |  # n0
                 %"tmp"<?,?> ⬅️ ::Shape(%"input")
            1 |  # n1
                 %"return_val"<?,?> ⬅️ ::Size(%"tmp")
            return %"return_val"<?,?>
        }

        <
            opset_imports={'': 18},
        >
        def pkg.onnxscript.torch_lib.common::IsScalar(
            inputs=(
                %"input"<?,?>
            ),
            outputs=(
                %"return_val"<?,?>
            ),
        ) {
            0 |  # n0
                 %"tmp"<?,?> ⬅️ ::Shape(%"input")
            1 |  # n1
                 %"tmp_0"<?,?> ⬅️ ::Size(%"tmp")
            2 |  # n2
                 %"tmp_1"<?,?> ⬅️ ::Constant() {value_int=0}
            3 |  # n3
                 %"return_val"<?,?> ⬅️ ::Equal(%"tmp_0", %"tmp_1")
            return %"return_val"<?,?>
        }
    ,
    exported_program=
        ExportedProgram:
            class GraphModule(torch.nn.Module):
                def forward(self, p_linear_weight: "f32[10, 10]", p_linear_bias: "f32[10]", x: "f32[1, 12, 15, 10]"):
                     # File: /home/justinchu/anaconda3/envs/onnx/lib/python3.13/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
                    linear: "f32[1, 12, 15, 10]" = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias);  x = p_linear_weight = p_linear_bias = None
                    return (linear,)
            
        Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_weight'), target='linear.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_bias'), target='linear.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear'), target=None)])
        Range constraints: {}

)

@codecov

codecov bot commented Jan 18, 2025

❌ 51 Tests Failed:

| Tests completed | Failed | Passed | Skipped |
|---|---|---|---|
| 11814 | 51 | 11763 | 2454 |
View the top 1 failed tests by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0212_test_cast_STRING_to_FLOAT
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_cast_STRING_to_FLOAT'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_cast_STRING_to_FLOAT' (e=No module named 'tests.onnx_backend_test_code.test_cast_STRING_to_FLOAT') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_cast_STRING_to_FLOAT.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_cast_STRING_to_FLOAT.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, STRING
E   from onnxscript.onnx_opset import opset21
E   
E   @script()
E   def bck_test_cast_STRING_to_FLOAT(input: STRING[3,4]) -> (FLOAT[3,4]):
E       output = opset21.Cast(input, to=1)
E       return output
View the full list of 2 ❄️ flaky tests
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0634_test_max_int32

Flake rate in main: 5.00% (Passed 38 times, Failed 2 times)

Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_max_int32'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_max_int32' (e=No module named 'tests.onnx_backend_test_code.test_max_int32') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_max_int32.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_max_int32.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import INT32
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_max_int32(data_0: INT32[3], data_1: INT32[3]) -> (INT32[3]):
E       result = opset13.Max(data_0, data_1)
E       return result
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1104_test_shape_start_1_end_2

Flake rate in main: 12.50% (Passed 14 times, Failed 2 times)

Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_shape_start_1_end_2'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_shape_start_1_end_2' (e=No module named 'tests.onnx_backend_test_code.test_shape_start_1_end_2') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_shape_start_1_end_2.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_shape_start_1_end_2.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset21
E   
E   @script()
E   def bck_test_shape_start_1_end_2(x: FLOAT[3,4,5]) -> (INT64[1]):
E       y = opset21.Shape(x, end=2, start=1)
E       return y


@titaiwangms
Contributor

https://github.com/microsoft/onnxscript/pull/1821/files also disabled linear_bias though, do we want that back as well?

@justinchuby
Collaborator Author

Dynamic axes work too:

model = TestModule()
# Rank-2 input with axis 0 marked as dynamic.
ep = torch.onnx.export(
    model, (torch.randn(2, 10),),
    dynamic_shapes=({0: torch.export.Dim.DYNAMIC},),
    dynamo=True, verify=True,
)
print(ep)
# Rank-4 input with axis 0 marked as dynamic.
ep = torch.onnx.export(
    model, (torch.randn(2, 12, 15, 10),),
    dynamic_shapes=({0: torch.export.Dim.DYNAMIC},),
    dynamo=True, verify=True,
)
print(ep)

@justinchuby
Collaborator Author

> #1821 (files) also disabled linear_bias though, do we want that back as well?

I combined the implementation.

@justinchuby justinchuby merged commit e7d199e into main Jan 18, 2025
22 of 27 checks passed
@justinchuby justinchuby deleted the justinchu/better-linear branch January 18, 2025 01:19
kunal-vaishnavi added a commit to microsoft/onnxruntime that referenced this pull request Jan 31, 2025
### Description
This PR adds fusions for [Google's SigLIP
model](https://huggingface.co/google/siglip-base-patch16-224/) and
Microsoft's internal conformer-encoder model.

Here is an example of how to run the ORT transformer optimizer for the
SigLIP model.
```
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers
$ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type clip --num_heads 16 --hidden_size 1152 --use_external_data_format --opt_level 0 --disable_shape_inference
```

Here is an example of how to run the ORT transformer optimizer for the
conformer-encoder model.
```
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers
$ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type conformer --num_heads 16 --hidden_size 1024 --use_external_data_format --opt_level 0 --disable_shape_inference --convert_attribute
```
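For scripted use, the two command lines above could be assembled programmatically. The sketch below only builds the argument list; `build_optimizer_command` is a hypothetical helper introduced for illustration, and every flag in it is copied from the commands quoted in this commit message rather than from the optimizer's documentation.

```python
import shlex


def build_optimizer_command(input_path, output_path, model_type,
                            num_heads, hidden_size, convert_attribute=False):
    """Assemble the optimizer.py invocation quoted in the commit message.

    Hypothetical helper: only the flag names and values are taken from
    the commit message.
    """
    cmd = [
        "python3", "optimizer.py",
        "--input", str(input_path),
        "--output", str(output_path),
        "--model_type", model_type,
        "--num_heads", str(num_heads),
        "--hidden_size", str(hidden_size),
        "--use_external_data_format",
        "--opt_level", "0",
        "--disable_shape_inference",
    ]
    if convert_attribute:
        # Only the conformer-encoder example passes this flag.
        cmd.append("--convert_attribute")
    return cmd


siglip_cmd = build_optimizer_command(
    "/path/to/model.onnx", "/path/to/model_opt.onnx", "clip", 16, 1152)
conformer_cmd = build_optimizer_command(
    "/path/to/model.onnx", "/path/to/model_opt.onnx", "conformer", 16, 1024,
    convert_attribute=True)

print(shlex.join(siglip_cmd))
print(shlex.join(conformer_cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True, cwd=...)` with `cwd` pointing at the `onnxruntime/python/tools/transformers` directory of a checkout.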

### Motivation and Context
This PR helps optimize multi-modal models that use SigLIP for the vision
encoder and conformer-encoder for the speech encoder.

This PR uses changes from the following PRs:
- pytorch/pytorch#144801
- microsoft/onnxscript#2018
- microsoft/onnxscript#2019
- microsoft/onnxscript#2020
- microsoft/onnxscript#2021
- microsoft/onnxscript#2022
- microsoft/onnxscript#2024
- microsoft/onnxscript#2025
- microsoft/onnxscript#2029
- microsoft/onnxscript#2033

### Introduction of ONNX Script

This PR introduces [ONNX
Script](https://github.com/microsoft/onnxscript) into the ORT
transformer optimizer as an optional step via the
`fold_transpose_initializers()` method of the `DynamoOnnxHelper` class.
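
The commit message does not describe what `fold_transpose_initializers()` does internally. The toy sketch below shows only the general idea the name suggests: a `Transpose` applied to a constant initializer is computed ahead of time, so the node disappears from the graph. The dict-based graph and the function `fold_transposed_initializers` are illustrative assumptions, not ONNX Runtime or ONNX Script types.

```python
import numpy as np


def fold_transposed_initializers(nodes, initializers):
    """Toy constant folding for Transpose nodes that read an initializer.

    `nodes` is a list of dicts with "op", "inputs", "output" and, for
    Transpose, "perm". `initializers` maps names to NumPy arrays.
    """
    folded = dict(initializers)
    kept = []
    for node in nodes:
        if node["op"] == "Transpose" and node["inputs"][0] in initializers:
            # Precompute the transpose and store it under the node's output
            # name, so downstream consumers need no rewiring.
            source = initializers[node["inputs"][0]]
            folded[node["output"]] = np.transpose(source, node["perm"])
        else:
            kept.append(node)
    # Drop initializers that no remaining node reads.
    used = {name for node in kept for name in node["inputs"]}
    folded = {name: value for name, value in folded.items() if name in used}
    return kept, folded


# Same node pattern as the rank-4 linear export earlier in this thread.
nodes = [
    {"op": "Transpose", "inputs": ["linear.weight"], "output": "val_0", "perm": [1, 0]},
    {"op": "MatMul", "inputs": ["x", "val_0"], "output": "val_1"},
    {"op": "Add", "inputs": ["val_1", "linear.bias"], "output": "linear"},
]
weight = np.arange(6, dtype=np.float32).reshape(2, 3)
initializers = {"linear.weight": weight, "linear.bias": np.zeros(2, dtype=np.float32)}

new_nodes, new_inits = fold_transposed_initializers(nodes, initializers)
print([node["op"] for node in new_nodes])
print(sorted(new_inits))
```

After folding, the graph keeps only `MatMul` and `Add`, and the pre-transposed weight is stored as an initializer in place of the original one.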
sfatimar pushed a commit to intel/onnxruntime that referenced this pull request Feb 5, 2025
sfatimar pushed a commit to intel/onnxruntime that referenced this pull request Feb 5, 2025
ashrit-ms pushed a commit to microsoft/onnxruntime that referenced this pull request Feb 11, 2025
guschmue pushed a commit to microsoft/onnxruntime that referenced this pull request Mar 6, 2025
ashrit-ms pushed a commit to microsoft/onnxruntime that referenced this pull request Mar 17, 2025