
Commit bfd5354

LysandreJik authored and sgugger committed
Add to ONNX docs (#13048)
* Add to ONNX docs

* Add MBART example

* Update docs/source/serialization.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
1 parent 226763a commit bfd5354

1 file changed

Lines changed: 30 additions & 0 deletions

File tree

docs/source/serialization.rst

@@ -99,6 +99,30 @@ It will be exported under ``onnx/bert-base-cased``. You should see similar logs:

     -[✓] all values close (atol: 0.0001)
 All good, model saved at: onnx/bert-base-cased/model.onnx
 
+This export can now be used with the ONNX Runtime:
+
+.. code-block::
+
+    import onnxruntime as ort
+
+    from transformers import BertTokenizerFast
+
+    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
+    ort_session = ort.InferenceSession("onnx/bert-base-cased/model.onnx")
+
+    inputs = tokenizer("Using BERT in ONNX!", return_tensors="np")
+    outputs = ort_session.run(["last_hidden_state", "pooler_output"], dict(inputs))
+
+The outputs used (:obj:`["last_hidden_state", "pooler_output"]`) can be obtained by taking a look at the ONNX
+configuration of each model. For example, for BERT:
+
+.. code-block::
+
+    from transformers.models.bert import BertOnnxConfig, BertConfig
+
+    config = BertConfig()
+    onnx_config = BertOnnxConfig(config)
+    output_keys = list(onnx_config.outputs.keys())
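To make the link between the configuration and the ``ort_session.run`` call concrete, here is a minimal stdlib-only sketch of the shape such an ``outputs`` mapping takes. The axis names and entries below are illustrative and hand-written, not read from the ``transformers`` library:

```python
from collections import OrderedDict

# Illustrative sketch only: an ONNX config's ``outputs`` property maps each
# output name to its dynamic axes. The entries below mimic what a BERT-style
# model would declare; they are assumptions for illustration, not library code.
bert_like_outputs = OrderedDict(
    [
        ("last_hidden_state", {0: "batch", 1: "sequence"}),
        ("pooler_output", {0: "batch"}),
    ]
)

# The keys, in order, are the names you would pass to ``ort_session.run``.
output_keys = list(bert_like_outputs.keys())
print(output_keys)  # → ['last_hidden_state', 'pooler_output']
```

Because the mapping is ordered, iterating over its keys always yields the output names in the position the exported graph expects.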

Implementing a custom configuration for an unsupported architecture
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -142,6 +166,12 @@ An important fact to notice is the use of `OrderedDict` in both inputs and outputs,

 as inputs are matched against their relative position within the `PreTrainedModel.forward()` prototype and outputs are
 matched against their position in the returned `BaseModelOutputX` instance.
 
+An example of such an addition is visible here, for the MBart model: `Making MBART ONNX-convertible
+<https://github.com/huggingface/transformers/pull/13049/commits/d097adcebd89a520f04352eb215a85916934204f>`__
+
+If you would like to contribute your addition to the library, we recommend you implement tests. An example of such
+tests is visible here: `Adding tests to the MBART ONNX conversion
+<https://github.com/huggingface/transformers/pull/13049/commits/5d642f65abf45ceeb72bd855ca7bfe2506a58e6a>`__
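As a rough, self-contained sketch of the pattern such an addition follows — the class and field names here are hypothetical stand-ins, not the library's actual API — a custom configuration declares its inputs and outputs as ordered mappings so that positional matching works:

```python
from collections import OrderedDict


class SketchOnnxConfig:
    """Hypothetical stand-in for a custom ONNX configuration class."""

    @property
    def inputs(self):
        # Order must mirror the ``PreTrainedModel.forward()`` signature,
        # since inputs are matched by relative position.
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self):
        # Order must mirror the fields of the returned ``BaseModelOutputX``.
        return OrderedDict(
            [
                ("last_hidden_state", {0: "batch", 1: "sequence"}),
            ]
        )


config = SketchOnnxConfig()
input_names = list(config.inputs.keys())
output_names = list(config.outputs.keys())
print(input_names, output_names)
```

The integer keys in each mapping name the dynamic axes (here batch size and sequence length), which is what allows the exported graph to accept variable-shaped tensors.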

Graph conversion
-----------------------------------------------------------------------------------------------------------------------
