
Add ONNX export for ViT#15658

Merged
lewtun merged 38 commits into master from vision-onnx-export
Mar 9, 2022

Conversation

@lewtun
Member

@lewtun lewtun commented Feb 15, 2022

What does this PR do?

This PR enables the export of Vision Transformers (ViT) to ONNX with the following features:

  • default
  • image-classification

To enable this new modality, I had to significantly refactor the internals of the ONNX exporter because we need a way to pass the feature extractor instead of the tokenizer.

Thanks to a tip from @LysandreJik I replaced the positional tokenizer argument in various functions with a new preprocessor argument that can be a tokenizer or feature extractor (and possibly a processor in future). This should guarantee backwards compatibility for users who chose to use the Python API instead of the transformers.onnx CLI.
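To make the backwards-compatibility story concrete, here is a minimal sketch of the dispatch-plus-deprecation pattern described above. The names `DummyTokenizer`, `DummyFeatureExtractor`, and `resolve_preprocessor` are stand-ins for illustration, not the merged API:

```python
import warnings

class DummyTokenizer:
    """Stand-in for PreTrainedTokenizerBase (illustrative only)."""

class DummyFeatureExtractor:
    """Stand-in for FeatureExtractionMixin (illustrative only)."""

def resolve_preprocessor(preprocessor=None, tokenizer=None):
    # Keep the old `tokenizer` kwarg working for backwards compatibility,
    # but emit a deprecation warning whenever it is actually used.
    if tokenizer is not None:
        if preprocessor is not None:
            raise ValueError("Pass either `preprocessor` or `tokenizer`, not both.")
        warnings.warn(
            "The `tokenizer` argument is deprecated; use `preprocessor` instead.",
            FutureWarning,
        )
        preprocessor = tokenizer
    # Dispatch on the preprocessor type to pick the modality.
    if isinstance(preprocessor, DummyTokenizer):
        return "text"
    if isinstance(preprocessor, DummyFeatureExtractor):
        return "vision"
    raise ValueError("Expected a tokenizer or a feature extractor.")
```

Old call sites that pass `tokenizer=...` keep working but see a `FutureWarning`, which matches the deprecation item in the Todo list below.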

Usage

import requests
import numpy as np
from PIL import Image
from onnxruntime import InferenceSession
from transformers import AutoConfig, AutoFeatureExtractor, AutoModelForImageClassification

# Export ViT checkpoint with image classification head
model_ckpt = "google/vit-base-patch16-224"
!python -m transformers.onnx --model={model_ckpt} --feature=image-classification onnx/

# Download an image of two cute cats - naturally ;-)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Instantiate config and feature extractor
config = AutoConfig.from_pretrained(model_ckpt)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
inputs = feature_extractor(image, return_tensors="np")

# Create ONNX Runtime session
session = InferenceSession("onnx/model.onnx", providers=["CPUExecutionProvider"])
outputs = session.run(["logits"], dict(inputs))
predicted_class_idx = np.argmax(outputs[0])
# Returns Predicted class: Egyptian cat
print("Predicted class:", config.id2label[predicted_class_idx])

Here are two Colab notebooks comparing the inference gains with ORT vs. vanilla PyTorch (~20-30% faster on CPU, ~5% faster on GPU):
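A tiny timing harness of the kind typically used for such comparisons (illustrative only; this is not the notebook code, and the commented-out call sites assume the `session`/`inputs` names from the usage snippet above):

```python
import time
import statistics

def mean_latency(fn, n_warmup=3, n_runs=10):
    """Return the mean wall-clock latency of fn over n_runs timed calls."""
    for _ in range(n_warmup):  # warm up caches before timing
        fn()
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.mean(times)

# Usage sketch:
# pt_latency  = mean_latency(lambda: model(**pt_inputs))
# ort_latency = mean_latency(lambda: session.run(["logits"], dict(inputs)))
```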

Todo

  • Add deprecation warning if user passes tokenizer as keyword argument
  • Run an inference test to see if we get any speed-up over vanilla PyTorch (maybe)

@HuggingFaceDocBuilder

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@lewtun changed the title from "Add ONNX export for vision models" to "Add ONNX export for ViT" on Feb 15, 2022
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
@slow
@require_torch
@require_vision
Member Author

I added the vision requirement here to test the ViT checkpoint. Please let me know if this isn't a "good practice" because it mixes multiple modalities together

Collaborator

I don't think the vision modality is installed for ONNX tests, so you'd have to double check this actually ends up being tested.

model = model_class.from_config(config)
onnx_config = onnx_config_class_constructor(model.config)

# Check the modality of the inputs and instantiate the appropriate preprocessor
Member

If this becomes a piece of code we use often, maybe we can refactor this into a function?
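One way that refactor could look, as a sketch: a small registry-based helper so the "check the modality, instantiate the matching preprocessor" snippet lives in one place. The function and registry names here are hypothetical:

```python
def instantiate_preprocessor(modality, model_name, registry=None):
    """Hypothetical helper: look up a loader callable for the given modality
    and use it to instantiate the matching preprocessor."""
    registry = registry or {}
    try:
        loader = registry[modality]
    except KeyError:
        raise ValueError(f"Unsupported modality: {modality!r}") from None
    return loader(model_name)

# Usage sketch (real loaders would be e.g. AutoTokenizer.from_pretrained
# and AutoFeatureExtractor.from_pretrained):
# preprocessor = instantiate_preprocessor("vision", model_ckpt, registry)
```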

images.append(Image.fromarray(data.astype("uint8")).convert("RGB"))
return images

def generate_dummy_inputs(
Member Author

This base method now has a mix of arguments for text and image modalities. I'm not 100% sure if we should split the modalities apart ...

Member

You split it now right? Just checking to make sure.
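A sketch of what splitting the dummy-input generation by modality could look like. The class and method names below are illustrative, not the merged API; the dispatch keys off the preprocessor type so text and image arguments never share one signature:

```python
class DummyInputSketch:
    """Illustrative: one public entry point, private per-modality helpers."""

    def generate_dummy_inputs(self, preprocessor):
        # Tokenizers expose `tokenize`; feature extractors do not.
        if hasattr(preprocessor, "tokenize"):
            return self._dummy_text_inputs()
        return self._dummy_image_inputs()

    def _dummy_text_inputs(self, batch_size=2, seq_length=8):
        # Zero-filled token ids of shape (batch_size, seq_length).
        return {"input_ids": [[0] * seq_length for _ in range(batch_size)]}

    def _dummy_image_inputs(self, batch_size=2, num_channels=3, height=40, width=40):
        # Zero-filled pixel values of shape (batch, channels, height, width).
        return {
            "pixel_values": [
                [[[0.0] * width for _ in range(height)] for _ in range(num_channels)]
                for _ in range(batch_size)
            ]
        }
```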

Collaborator

@sgugger left a comment

Thanks for adding this!

Regarding the optional tokenizer kwarg, it's very good to keep it like this, but there should be a deprecation warning when it's actually used, and it shouldn't be documented.


@lewtun
Member Author

lewtun commented Feb 17, 2022

While testing this branch on Colab, I discovered a weird bug when trying to run inference in ONNX Runtime with torch v1.10.2:

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'Reshape_42' Status Message: /Users/runner/work/1/s/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:42 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape &, std::vector<int64_t> &, bool) gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{1,197,768}, requested shape:{2,197,12,64}

Curiously, there is no problem running inference with torch v1.9, so something seems to have changed in the torch ONNX exporter in the latest version. I'm currently investigating what the source of the problem is ...
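The usual way to catch this class of regression is to compare the exported model's outputs against the framework's outputs elementwise after export. Here is a minimal numpy-only sketch of that idea (the real export pipeline has its own validation step; this helper and its name are illustrative):

```python
import numpy as np

def assert_outputs_close(reference, candidate, atol=1e-4):
    """Compare two output arrays (e.g. PyTorch logits vs. ORT logits)
    and fail loudly on shape or value divergence."""
    reference = np.asarray(reference)
    candidate = np.asarray(candidate)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    max_diff = float(np.abs(reference - candidate).max())
    if max_diff > atol:
        raise ValueError(f"max abs diff {max_diff:.2e} exceeds atol {atol:.2e}")
    return max_diff
```

A shape mismatch check like the one above would also surface the `{1,197,768}` vs `{2,197,12,64}` reshape inconsistency at validation time rather than deep inside an ORT kernel.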

from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size


if TYPE_CHECKING:
Member Author

Since I was already sorting out the relative imports, I also went ahead and fixed the imports that are only used for type checking

Collaborator

❤️❤️ ❤️ ❤️

DEFAULT_FIXED_SEQUENCE = 8

_TASKS_TO_COMMON_OUTPUTS = {
default_fixed_batch = 2
Member Author

These class variables are now snake_case to prevent confusion / disaster with global constants
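The distinction in a toy example (class and constant names here are illustrative, not the merged code):

```python
# Module-level constants stay UPPER_SNAKE_CASE ...
DEFAULT_FIXED_BATCH = 2
DEFAULT_FIXED_SEQUENCE = 8

class OnnxConfigSketch:
    # ... while per-class overridable defaults are lower snake_case, so an
    # override in a subclass cannot be mistaken for shadowing a global.
    default_fixed_batch = DEFAULT_FIXED_BATCH
    default_fixed_sequence = DEFAULT_FIXED_SEQUENCE

class LargeBatchConfigSketch(OnnxConfigSketch):
    default_fixed_batch = 4  # override without touching the module constants
```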

Collaborator

@sgugger left a comment

Thank you so much for making this file more resilient and less prone to cyclical import errors :-)

Comment on lines +29 to +33
if is_torch_available():
from ..modeling_utils import PreTrainedModel

if is_tf_available():
from ..modeling_tf_utils import TFPreTrainedModel
Collaborator

Thank you for this 😍!


from ..feature_extraction_utils import FeatureExtractionMixin
from ..file_utils import TensorType, is_torch_available, is_vision_available
from ..tokenization_utils_base import PreTrainedTokenizerBase
Collaborator

Last step: since this file is imported at a very low level, it would be great to import those (PreTrainedTokenizerBase and FeatureExtractionMixin) under TYPE_CHECKING for type checks, and otherwise only import them dynamically where we do the instance check

Member Author

Sounds good!
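The pattern being suggested, as a self-contained sketch (`Fraction` stands in here for `PreTrainedTokenizerBase` / `FeatureExtractionMixin`, since those would require the surrounding package):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright), never at
    # runtime, so the annotation comes for free with no import cost and
    # no risk of a circular import in a low-level module.
    from fractions import Fraction

def as_ratio(value: "Fraction") -> tuple:
    # Defer the real import to the point of the dynamic isinstance check.
    from fractions import Fraction
    if not isinstance(value, Fraction):
        raise TypeError("expected a Fraction")
    return (value.numerator, value.denominator)
```

Note the annotation is a string (`"Fraction"`), so it is never resolved at runtime; only the deferred import inside the function body executes.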

.gitignore Outdated
Comment on lines +168 to +169
# Lewis
scratch/ No newline at end of file
Contributor

Not sure if we want to add this in the general gitignore of Transformers?

Member Author

Oop! Will fix that!

Member Author

Fixed :)

Member

@michaelbenayoun left a comment

Awesome work @lewtun !


def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
tokenizer: "PreTrainedTokenizerBase",
Member

Why?

Member Author

Are you asking about the change to PreTrainedTokenizerBase or the use of strings for the typing? Here are the reasons in both cases:

  • I chose PreTrainedTokenizerBase because it covers both slow and fast tokenizers. The alternative would have been something like Union[PreTrainedTokenizer, PreTrainedTokenizerFast], but that felt clunky
  • I used strings for the typing following @sgugger's suggestion to use the TYPE_CHECKING constant to fix the circular imports

Member

I was asking about the change of class, and it makes sense to me now, thanks for the explanation!

Member

@LysandreJik left a comment

Looks good! Thanks @lewtun for iterating and @sgugger for the great reviews!

return 1e-5

@property
def is_torch_support_available(self) -> bool:
Member

For torch.fx we have a requirement on a specific torch version. If you have validated that it doesn't work with a specific torch version, I would see no problem in printing a warning mentioning exactly that. If it's going to fail, then raising an error is also fine.
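A sketch of such a version gate, warning when the installed torch predates the minimum a given export needs. The minimum shown is made up for illustration; in practice it would live on the relevant ONNX config:

```python
import warnings

def is_torch_support_available(installed: str, minimum: str = "1.8.0") -> bool:
    """Return whether `installed` meets `minimum`, warning if it does not."""
    def parse(v):
        # Strip local build tags like "+cu113", compare numeric components.
        return tuple(int(p) for p in v.split("+")[0].split(".")[:3])
    supported = parse(installed) >= parse(minimum)
    if not supported:
        warnings.warn(
            f"torch {installed} is below the minimum {minimum} required "
            "for this ONNX export; the export may fail."
        )
    return supported
```

Raising instead of warning, as suggested above, would just turn the `warnings.warn` call into a `RuntimeError`.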

@lewtun lewtun merged commit 50dd314 into master Mar 9, 2022
@lewtun lewtun deleted the vision-onnx-export branch March 9, 2022 16:37
@davanstrien
Member

Super happy to see this merged! 🤗
