
Type Error when training Hugging Face Transformers GPT2 with fp16 enabled #11279

@jambayk

Describe the bug
While training a Hugging Face Transformers (version 4.18.0) GPT2 model with ONNX Runtime training (the model is wrapped in ORTModule) and fp16 mixed precision enabled, I encounter the following type error:

RuntimeError: /onnxruntime_src/orttraining/orttraining/python/orttraining_pybind_state.cc:752 
onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, 
onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, 
const pybind11::bytes&, const onnxruntime::training::OrtModuleGraphBuilderConfiguration&)> 
[ONNXRuntimeError] : 1 : FAIL : Type Error: Type parameter (T) of Optype (Where) bound to different types 
(tensor(float) and tensor(float16) in node (Where_199). 

I do not encounter this issue when using transformers version 4.16.0.
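The error says ONNX Runtime could not assign a single element type to the `Where` node. In the ONNX `Where` operator, `X`, `Y` and the output all share one type parameter `T`, so a float32 branch and a float16 branch cannot meet in the same node. The sketch below is a simplified, hypothetical model of that binding check, not ONNX Runtime's code; the names `bind_type_params`, `WHERE_SIGNATURE` and `TypeBindingError` are illustrative only.

```python
class TypeBindingError(Exception):
    """Raised when a type parameter is bound to two different tensor types."""


# Simplified, hypothetical view of the ONNX `Where` signature: `condition`
# has a fixed type, while X and Y (and the output) share the parameter T.
WHERE_SIGNATURE = (
    ("condition", "tensor(bool)"),
    ("X", "T"),
    ("Y", "T"),
)
TYPE_PARAMS = {"T"}


def bind_type_params(op_type, node_name, signature, input_types):
    """Map each type parameter to one concrete type, or raise on a conflict."""
    bindings = {}
    for (input_name, constraint), actual in zip(signature, input_types):
        if constraint not in TYPE_PARAMS:
            # Fixed-type input: the actual type must match exactly.
            if actual != constraint:
                raise TypeBindingError(
                    f"Input ({input_name}) of Optype ({op_type}) expects "
                    f"{constraint} but got {actual} in node ({node_name})."
                )
            continue
        # The first occurrence binds the parameter; later ones must agree.
        bound = bindings.setdefault(constraint, actual)
        if bound != actual:
            raise TypeBindingError(
                f"Type parameter ({constraint}) of Optype ({op_type}) bound to "
                f"different types ({bound} and {actual}) in node ({node_name})."
            )
    return bindings


# One float32 branch and one float16 branch, as in the reported error.
try:
    bind_type_params("Where", "Where_199", WHERE_SIGNATURE,
                     ["tensor(bool)", "tensor(float)", "tensor(float16)"])
except TypeBindingError as err:
    print(err)

# Once both value inputs share a dtype, T binds consistently.
print(bind_type_params("Where", "Where_199", WHERE_SIGNATURE,
                       ["tensor(bool)", "tensor(float16)", "tensor(float16)"]))
```

In PyTorch terms, the corresponding workaround would be to give both value arguments of `torch.where` the same dtype before the graph is exported; whether that is the actual cause in transformers 4.18.0 is an assumption, not something this report establishes.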

This bug is similar to the one reported in pytorch/ort#104. That issue has been closed, but the mixed precision problem does not appear to have been resolved: the discussion moved on to other errors encountered while moving to onnxruntime-training 1.11.0, and the issue was closed once those were fixed. I still encounter the same mixed precision error with the recommended package versions.

Urgency
We are considering using the Hugging Face Optimum ORTTrainer to integrate ORT training into Hugging Face Transformers.
Migration to Optimum is blocked by this issue.

System information

  • torch 1.11.0
  • transformers 4.18.0
  • deepspeed 0.6.1
  • torch-ort 1.11.0
  • onnx 1.10.2
  • onnxruntime-training 1.11.0+cu113
  • GPU: V100 (16 GiB)
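When comparing this environment against the working transformers 4.16.0 setup, it can help to record the installed versions programmatically. The following is a minimal sketch using the standard-library `importlib.metadata` (Python 3.8+); `collect_versions` is a hypothetical helper, and the distribution names in `PACKAGES` are assumed to match the PyPI names of the packages listed above.

```python
from importlib import metadata

# Assumed PyPI distribution names for the packages listed above.
PACKAGES = ("torch", "transformers", "deepspeed", "torch-ort", "onnx",
            "onnxruntime-training")


def collect_versions(packages=PACKAGES):
    """Return {distribution name: installed version, or None if missing}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name} {version or 'not installed'}")
```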

To Reproduce
The code and instructions for reproduction can be found in this Git repo directory.

Please contact me (jambaykinley@microsoft.com) for access to the repo.
