
[Training] >2GB Model Offline Artifacts fail with MSE Loss #22411

@jkbeavers


Describe the issue

Error Description

ONNX Runtime fails to create training artifacts when the model uses external data and the loss (custom, MSE, BCE with logits, L1) is built from more than one Block.
The failure produces a confusing message saying that the first tensor stored in the proto cannot be found in temp.onnx.data, e.g.:
Data of TensorProto ( tensor name: ...) should be stored in temp.onnx.data, but it doesn't exist or is not accessible.

Bug Description

The recent patch adding support for >2GB models was only tested with CrossEntropyLoss, which happens not to run into this issue. None of the other loss blocks were tested.

The bug has three contributing parts:

  • ONNX Runtime's base training Block uses a global copy of the original ModelProto when creating new nodes for the training graph. This copy is shared by all subsequently created Blocks.
  • In __call__, Block saves the model proto after building the new node and adding it to the graph. The saved file is then used to check the model's validity.
  • onnx.save destructively modifies the external data information in a ModelProto when it calls set_external_data.
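The third point can be illustrated with a toy model that uses only the standard library. `Tensor`, `Model`, `save_with_external_data` and `save_to_fresh_dir` are hypothetical stand-ins for TensorProto, ModelProto and `onnx.save(..., save_as_external_data=True)`. They mimic only the behavior described above (bytes moved to a side file and cleared from the in-memory proto) and are not the onnx API. The assumption that each save goes to a fresh directory is also purely illustrative.

```python
import os
import tempfile
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Tensor:
    """Hypothetical stand-in for an onnx TensorProto."""
    name: str
    raw_data: bytes = b""
    external_location: Optional[str] = None  # set once the bytes live in a side file


@dataclass
class Model:
    """Hypothetical stand-in for an onnx ModelProto."""
    tensors: List[Tensor] = field(default_factory=list)


def save_with_external_data(model: Model, path: str) -> bool:
    """Moves every tensor's raw bytes into `<path>.data`, clearing them in place.

    Tensors that are already marked external but hold no bytes are skipped.
    Returns True if a data file was written.
    """
    data_path = path + ".data"
    chunks = []
    for tensor in model.tensors:
        if tensor.raw_data:
            chunks.append(tensor.raw_data)
            tensor.external_location = os.path.basename(data_path)
            tensor.raw_data = b""  # destructive: the caller's model loses its bytes
    if not chunks:
        return False  # nothing left to write, so no .data file is created
    with open(data_path, "wb") as f:
        f.write(b"".join(chunks))
    return True


def save_to_fresh_dir(model: Model) -> bool:
    """Saves into a new temporary directory, as a separate Block might (assumption)."""
    with tempfile.TemporaryDirectory() as tmp:
        return save_with_external_data(model, os.path.join(tmp, "temp.onnx"))


if __name__ == "__main__":
    shared = Model([Tensor("fc1.weight", b"\x00" * 16)])  # one proto shared by Blocks
    print(save_to_fresh_dir(shared))  # True: bytes written to temp.onnx.data
    print(save_to_fresh_dir(shared))  # False: the shared proto has no bytes left
    print(shared.tensors[0])          # still points at external data it no longer has
```

After the first save the shared model still claims its data is external, but neither the bytes nor (in a new directory) the side file exist any more, which is the state a later validity check would trip over.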

When a second Block is created as part of the loss, the tensors in the global ModelProto no longer hold their data, so no external data file is written to temp.onnx.data and the validity check cannot find any tensor data. For MSE this happens during the Sub block's validity check (if the "target" InputLike block doesn't already exist).
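The failure sequence above, and one possible way around it, can be sketched with a similar stdlib-only toy. Everything here is hypothetical: `GLOBAL_MODEL` stands in for the shared ModelProto, `destructive_save` for `onnx.save(..., save_as_external_data=True)`, `check` for the checker, and `Block` only mimics the save-then-check step in `__call__`. The `copy_before_save` flag illustrates one possible mitigation (saving a deep copy so the shared proto keeps its data); it is not the upstream fix.

```python
import copy
import os
import tempfile
from typing import Dict

# Hypothetical shared state standing in for the global ModelProto:
# tensor name -> {"raw": bytes, "external": side-file name or None}
GLOBAL_MODEL: Dict[str, dict] = {}


def destructive_save(model: dict, path: str) -> None:
    """Moves raw bytes into `<path>.data` and clears them from `model` in place."""
    data_path = path + ".data"
    chunks = []
    for tensor in model.values():
        if tensor["raw"]:
            chunks.append(tensor["raw"])
            tensor["external"] = os.path.basename(data_path)
            tensor["raw"] = b""  # mutates whatever dict the caller passed in
    if chunks:
        with open(data_path, "wb") as f:
            f.write(b"".join(chunks))


def check(model: dict, path: str) -> None:
    """Fails if a tensor points at an external file that does not exist."""
    for name, tensor in model.items():
        if tensor["external"]:
            loc = os.path.join(os.path.dirname(path), tensor["external"])
            if not os.path.exists(loc):
                raise ValueError(f"Data of TensorProto ( tensor name: {name}) missing at {loc}")


class Block:
    """Mimics a Block whose __call__ saves and then validates the shared model."""

    def __init__(self, copy_before_save: bool = False):
        self.copy_before_save = copy_before_save

    def __call__(self) -> None:
        # Saving a deep copy leaves GLOBAL_MODEL untouched for later Blocks.
        model = copy.deepcopy(GLOBAL_MODEL) if self.copy_before_save else GLOBAL_MODEL
        with tempfile.TemporaryDirectory() as tmp:
            path = os.path.join(tmp, "temp.onnx")
            destructive_save(model, path)
            check(model, path)


def run_loss(num_blocks: int, copy_before_save: bool) -> None:
    """Resets the shared model, then calls `num_blocks` Blocks in sequence."""
    GLOBAL_MODEL.clear()
    GLOBAL_MODEL["fc1.weight"] = {"raw": b"\x00" * 8, "external": None}
    for _ in range(num_blocks):
        Block(copy_before_save)()
```

In this toy, `run_loss(1, copy_before_save=False)` succeeds, `run_loss(2, copy_before_save=False)` raises on the second Block, and `run_loss(2, copy_before_save=True)` succeeds because the shared model is never mutated. Deep-copying costs memory proportional to the tensor data held in memory, so whether copying is an acceptable fix for very large models is a separate question.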

To reproduce

From a local build directory:

  1. Edit test_generate_artifacts_external_data_separate_files in orttraining_test_ort_apis_onnxblock.py by changing CrossEntropyLoss to MSELoss
  2. Run pytest orttraining_test_ort_apis_onnxblock.py -k test_generate_artifacts_external_data_separate_files

See error:
onnx.onnx_cpp2py_export.checker.ValidationError: Data of TensorProto ( tensor name: fc1.weight) should be stored in onnxruntime/build/Linux/RelWithDebInfo/temp.onnx.data, but it doesn't exist or is not accessible.

Urgency

This blocks the generation of ORT training artifacts for models >2GB that use many custom loss functions, as well as all built-in losses except CrossEntropyLoss.

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.19.2

PyTorch Version

2.4.1

Execution Provider

Default CPU

Execution Provider Library Version

No response

Metadata

Assignees: No one assigned
Labels: training (issues related to ONNX Runtime training; typically submitted using template)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
