Describe the bug
During training of a Hugging Face Transformers (version 4.18.0) GPT-2 model with ONNX Runtime Training (the model is wrapped in ORTModule) and fp16 mixed-precision training enabled, I encounter a type error:
```
RuntimeError: /onnxruntime_src/orttraining/orttraining/python/orttraining_pybind_state.cc:752
onnxruntime::python::addObjectMethodsForTraining(pybind11::module&,
onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*,
const pybind11::bytes&, const onnxruntime::training::OrtModuleGraphBuilderConfiguration&)>
[ONNXRuntimeError] : 1 : FAIL : Type Error: Type parameter (T) of Optype (Where) bound to different types
(tensor(float) and tensor(float16) in node (Where_199).
```
I do not encounter this issue when using transformers version 4.16.0.
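For reference, a minimal sketch of the kind of setup that triggers the error (this is not the author's actual repro code; the tiny GPT-2 config is hypothetical, chosen so no weights are downloaded). With torch-ort installed and fp16 mixed precision enabled (e.g. `fp16=True` in the Trainer, or `torch.cuda.amp.autocast` on GPU), the `Where_199` type error surfaces during ORT graph building under transformers 4.18.0:

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny randomly initialized GPT-2; no checkpoint download needed.
config = GPT2Config(n_layer=2, n_head=2, n_embd=64)
model = GPT2LMHeadModel(config)

try:
    # Wrapping in ORTModule hands forward/backward to onnxruntime-training.
    from torch_ort import ORTModule
    model = ORTModule(model)
except ImportError:
    # torch-ort not installed: the step below runs as plain PyTorch.
    pass

# One causal-LM training step; with ORTModule + fp16 autocast on GPU,
# graph building fails here with the Where op type error.
input_ids = torch.randint(0, config.vocab_size, (1, 8))
out = model(input_ids=input_ids, labels=input_ids)
out.loss.backward()
```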
This bug is similar to the one reported in pytorch/ort#104. That issue has been closed, but the mixed-precision training problem does not appear to have actually been resolved: the conversation there shifted to other errors encountered while moving to onnxruntime-training 1.11.0, and the issue was closed once those were fixed. I still encounter the same mixed-precision training failure with the recommended package versions.
Urgency
We are considering using Hugging Face Optimum's ORTTrainer to integrate ORT training into Hugging Face Transformers.
Migration to Optimum is blocked by this issue.
System information
- torch 1.11.0
- transformers 4.18.0
- deepspeed 0.6.1
- torch-ort 1.11.0
- onnx 1.10.2
- onnxruntime-training 1.11.0+cu113
- GPU: v100 (16GiB)
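A sketch of how an equivalent environment could be pinned (assumed install commands; the `torch_ort.configure` step is the standard torch-ort post-install step, and the CUDA 11.3 index URL is an assumption matching the `+cu113` build above):

```shell
# Pin the versions listed above (assumed environment-setup sketch)
pip install torch==1.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install transformers==4.18.0 deepspeed==0.6.1 onnx==1.10.2
pip install onnxruntime-training==1.11.0 torch-ort==1.11.0
# torch-ort requires a one-time configuration step after install
python -m torch_ort.configure
```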
To Reproduce
The code and instructions for reproduction can be found in this Git repo directory.
Please contact me (jambaykinley@microsoft.com) for access to the repo.