What would you like to happen?
The current implementation of RunInference provides model handlers for PyTorch and Sklearn models. These handlers assume that the method to call for inference is fixed:
- PyTorch: do a forward pass by calling the model's
__call__ method -> output = torch_model(input)
- Sklearn: call the model's
predict method -> output = sklearn_model.predict(input)
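The fixed dispatch can be illustrated with a minimal sketch (illustrative only, not the actual Beam source; the class and method names here are hypothetical):

```python
class TorchLikeHandler:
    """Mimics the PyTorch handler: inference is always a forward pass via __call__."""

    def __init__(self, model):
        self._model = model

    def run_inference(self, batch):
        # The call to the model is hard-coded; there is no way to swap it out.
        return [self._model(x) for x in batch]


class SklearnLikeHandler:
    """Mimics the Sklearn handler: inference always calls predict()."""

    def __init__(self, model):
        self._model = model

    def run_inference(self, batch):
        # predict() is hard-coded in the same way.
        return self._model.predict(batch)
```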
However, in some cases we want to provide a custom method for RunInference to call.
Two examples:
- A number of pretrained models loaded with the Huggingface transformers library recommend using the generate() method. From the Huggingface docs on the T5 model:
"At inference time, it is recommended to use generate(). This method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder and auto-regressively generates the decoder output."
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Das Haus ist wunderbar.
- Using OpenAI's CLIP model, which is implemented as a torch model, we might not want to execute the normal forward pass that encodes both images and text (image_embedding, text_embedding = clip_model(image, text)) but instead only compute the image embeddings: image_embedding = clip_model.encode_image(image).
Solution: Allowing the user to specify the inference_fn when creating a ModelHandler would enable this usage.
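As a sketch of what this could look like (hypothetical API; the handler class and the `inference_fn` signature below are assumptions, not the Beam implementation):

```python
def default_inference_fn(model, batch):
    # Default behavior matches today's PyTorch handler: a plain forward pass.
    return [model(x) for x in batch]


class CustomizableModelHandler:
    """A ModelHandler that delegates inference to a user-supplied callable."""

    def __init__(self, model, inference_fn=default_inference_fn):
        self._model = model
        self._inference_fn = inference_fn

    def run_inference(self, batch):
        # Call whatever the user configured instead of a hard-coded method.
        return self._inference_fn(self._model, batch)
```

With this shape, the CLIP case becomes CustomizableModelHandler(clip_model, inference_fn=lambda m, batch: [m.encode_image(x) for x in batch]), and the T5 case can point inference_fn at a function that calls model.generate().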
Issue Priority
Priority: 2
Issue Component
Component: sdk-py-core