Fix training artifacts for 2GB+ models and MSELoss #22414

Merged
baijumeswani merged 1 commit into microsoft:main from
jkbeavers:22411-training-2gb-model-offline-artifacts-fail-with-mse-loss
Oct 15, 2024

Conversation

@jkbeavers
Contributor

Description

generate_artifacts fails when creating training artifacts for a model using external data and MSELoss.

Because new training Blocks are built from a single global base model, and onnx.save destroys any external data the model references, any loss block (e.g. MSELoss) that builds more than one sub-Block fails validation and raises an exception: saving the first sub-Block strips the external data the next one needs.

Fix

Saving a deep copy of the global model instead circumvents this, at the cost of temporarily holding 2x the model size in memory.
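To illustrate why the deep copy works, here is a minimal, self-contained sketch in plain Python. It does not use the actual onnx/onnxruntime API: `Model` and `destructive_save` are hypothetical stand-ins for the shared `ModelProto` and for a save that consumes in-memory data while externalizing it.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class Model:
    # Hypothetical stand-in for a ModelProto holding tensor data in memory.
    weights: dict = field(default_factory=lambda: {"w": [1.0, 2.0]})


def destructive_save(model: Model) -> None:
    # Simulates a save that externalizes the data and strips it from the
    # in-memory model, mimicking the destructive behavior described above.
    model.weights.clear()


base = Model()

# Buggy pattern: each sub-Block saves the shared base model directly, so the
# first save leaves `base` without its data and the second build fails.
#
# Fixed pattern: save a deep copy, leaving the global model intact (at the
# cost of roughly 2x peak memory while the copy exists).
destructive_save(copy.deepcopy(base))
assert base.weights == {"w": [1.0, 2.0]}  # base model still has its data
```

The same pattern applies per sub-Block: each save operates on its own copy, so building a second sub-Block from the global model still sees the original external data.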

Other Implementations

An alternative approach using less memory would load the on-disk external data before it is deleted in Block::__del__ and insert the appropriate fields into the global ModelProto.
That seems a bit brittle, though, because it couples to the specific way onnx.save destructively consumes external data. If the onnx repo offers a non-modifying save, it would be ideal to use that in Block::__call__ instead.

Motivation and Context

Fixes generate_artifacts bug reported in #22411

The use of a global base model when creating new training `Blocks`
and `onnx.save` destroying any external data meant any loss block
(e.g. `MSELoss`) that builds more than one sub-`Block` would fail
validation due to missing external data.

Saving using a deep copy of the global model circumvents this.

Fixes microsoft#22411

@byt3n33dl3 byt3n33dl3 left a comment


blocks kinda (@microsoft-github-policy-service agree company="Microsoft")

@jkbeavers
Contributor Author

@microsoft-github-policy-service agree company="RWS"

@snnn
Contributor

snnn commented Oct 15, 2024

/azp run Big Models, Linux Android Emulator QNN CI Pipeline, Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline

@snnn
Contributor

snnn commented Oct 15, 2024

/azp run Linux OpenVINO CI Pipeline, Linux QNN CI Pipeline, MacOS CI Pipeline, ONNX Runtime Web CI Pipeline, Windows ARM64 QNN CI Pipeline,

@snnn
Contributor

snnn commented Oct 15, 2024

/azp run Windows CPU CI Pipeline, Windows GPU CUDA CI Pipeline, Windows GPU DML CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows GPU TensorRT CI Pipeline, Windows x64 QNN CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline

@azure-pipelines

Azure Pipelines successfully started running 6 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 5 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 9 pipeline(s).

@WilliamTambellini
Contributor

+1

@snnn snnn added the training issues related to ONNX Runtime training; typically submitted using template label Oct 15, 2024
@baijumeswani baijumeswani merged commit a5e85a9 into microsoft:main Oct 15, 2024
@WilliamTambellini
Contributor

Thanks @snnn and @baijumeswani.
Any way for you to do a patch release ASAP?

@baijumeswani
Contributor

I think this will be included in the upcoming 1.20 release.

@WilliamTambellini
Contributor

Thanks @baijumeswani.
@snnn could you confirm it'll be in the next version at the end of the month?
