ONNXRuntimeError after enabled fp16 mixed precision training #104

@JingyaHuang

Description

Hi folks,

I tested fp16 mixed precision training with ORTModule wrapped GPT2 model on a fine-tuning task. However, after enabling fp16, I encountered the following error:

Error Message

Traceback (most recent call last):
  File "test_onnxruntime_train.py", line 115, in test_ort_trainer
    train_result = trainer.train()
  File "/workspace/optimum/onnxruntime/trainer.py", line 498, in train
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 1984, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 2016, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/ortmodule.py", line 81, in _forward
    return self._torch_module.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_torch_module_ort.py", line 32, in _forward
    return self._execution_manager(self.is_training()).forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_training_manager.py", line 265, in forward
    override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD)
  File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_fallback.py", line 194, in handle_exception
    raise exception
  File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_training_manager.py", line 85, in forward
    self._initialize_graph_builder(training=True)
  File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 420, in _initialize_graph_builder
    self._onnx_models.exported_model.SerializeToString(), grad_builder_config)
RuntimeError: /onnxruntime_src/orttraining/orttraining/python/orttraining_pybind_state.cc:707 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const pybind11::bytes&, const onnxruntime::training::OrtModuleGraphBuilderConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Type Error: Type parameter (T) of Optype (Where) bound to different types (tensor(float) and tensor(float16) in node (Where_183).

It seems that the exported ONNX graph is invalid due to incompatible input types. I am wondering where the problem comes from. Do you have any insight on this?


System information

Docker image built with the Dockerfile-cu11 in onnxruntime-training-examples.

  • OS: Ubuntu 18.04
  • CUDA/cuDNN version: 11/8
  • onnxruntime-training: 1.9.0+cu111
  • torch: 1.9.0+cu111
  • torch-ort: 1.9.0
  • Python version: 3.6
  • GPU: A100

Additional Information

  • I actually have a working version under the environment: torch 1.8.1 + torch-ort 1.9.0 + onnxruntime-training 1.11.0.dev20220113001+cu102, so I wonder if the error comes from the packages in the Dockerfile being outdated. However, I can no longer find how to install onnxruntime-training 1.11.0.dev20220113001+cu102.
  • Here is the ONNX graph exported with DebugOptions, in case it helps:
    [screenshot of the exported graph attached]
