Hi folks,
I tested fp16 mixed precision training with ORTModule wrapped GPT2 model on a fine-tuning task. However, after enabling fp16, I encountered the following error:
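For reference, a minimal sketch of the setup described above (model name, batch shape, and optimizer settings are placeholders, not taken from the original script):

```python
# Hypothetical repro sketch: GPT2 wrapped with ORTModule, trained with
# fp16 mixed precision via torch.cuda.amp (all hyperparameters are assumptions).
import torch
from transformers import GPT2LMHeadModel
from onnxruntime.training.ortmodule import ORTModule

model = GPT2LMHeadModel.from_pretrained("gpt2").cuda()
model = ORTModule(model)  # wrap with ORTModule for ORT-backed training

scaler = torch.cuda.amp.GradScaler()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

input_ids = torch.randint(0, 50257, (2, 128), device="cuda")
with torch.cuda.amp.autocast():  # fp16 mixed precision forward pass
    loss = model(input_ids=input_ids, labels=input_ids).loss
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```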
Error Message
Traceback (most recent call last):
File "test_onnxruntime_train.py", line 115, in test_ort_trainer
train_result = trainer.train()
File "/workspace/optimum/onnxruntime/trainer.py", line 498, in train
tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 1984, in training_step
loss = self.compute_loss(model, inputs)
File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 2016, in compute_loss
outputs = model(**inputs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/ortmodule.py", line 81, in _forward
return self._torch_module.forward(*inputs, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_torch_module_ort.py", line 32, in _forward
return self._execution_manager(self.is_training()).forward(*inputs, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_training_manager.py", line 265, in forward
override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD)
File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_fallback.py", line 194, in handle_exception
raise exception
File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_training_manager.py", line 85, in forward
self._initialize_graph_builder(training=True)
File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 420, in _initialize_graph_builder
self._onnx_models.exported_model.SerializeToString(), grad_builder_config)
RuntimeError: /onnxruntime_src/orttraining/orttraining/python/orttraining_pybind_state.cc:707 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const pybind11::bytes&, const onnxruntime::training::OrtModuleGraphBuilderConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Type Error: Type parameter (T) of Optype (Where) bound to different types (tensor(float) and tensor(float16) in node (Where_183).
It seems that the exported ONNX graph is broken due to incompatible input types. I am wondering where the problem comes from. Do you have any insight on that?
System information
Docker image built with the Dockerfile-cu11 in onnxruntime-training-examples.
- OS: Ubuntu 18.04
- CUDA/cuDNN version: 11/8
- onnxruntime-training: 1.9.0+cu111
- torch: 1.9.0+cu111
- torch-ort: 1.9.0
- Python version: 3.6
- GPU: A100
Additional Information
- I actually have a working version under the environment torch 1.8.1 + torch-ort 1.9.0 + onnxruntime-training 1.11.0.dev20220113001+cu102, so I wonder if the error comes from the packages in the Dockerfile being outdated. However, I can no longer find how to install onnxruntime-training 1.11.0.dev20220113001+cu102.
- Here is the ONNX graph exported with DebugOptions, not sure if that could help
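For anyone trying to reproduce the export, this is roughly how the graphs were dumped (the prefix and log level here are my own choices, not from the original run):

```python
# Hedged sketch: dumping the exported/optimized ONNX graphs via ORTModule's
# DebugOptions so the broken graph can be inspected offline.
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel

debug_options = DebugOptions(
    save_onnx=True,           # save the ONNX graphs ORTModule produces
    onnx_prefix="gpt2_fp16",  # assumed prefix; files are saved as <prefix>_*.onnx
    log_level=LogLevel.INFO,
)
model = ORTModule(model, debug_options)
```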
