[On-Device Training] Expose Parameters through the Training API #17364
baijumeswani merged 14 commits into main
pengwa left a comment:
I took a quick look and have a few comments.
Recommendation: instead of true, false you can do and pass it as an argument. This would improve readability a lot. In reply to: 1709295692 Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs:365 in df21a2e.
Also, I suggest introducing an OrtValue-based API sometime in the near future. In reply to: 1709295692 Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs:365 in df21a2e.
Suggestion: float is a blittable type, so you can read it directly from native memory in an unsafe block. No intermediate array is needed; allocating one only creates garbage, and Marshal.Copy is slow. The same applies to int. One option is an unsafe block:
float val;
unsafe
{
val = \*(float\*)propertyValue.ToPointer();
}
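The suggestion above can be sketched as a small standalone program. This is only an illustration under stated assumptions, not the code in the pull request: the class and method names (`BlittableReadDemo`, `ReadWithMarshalCopy`, `ReadWithPointer`) are hypothetical, the native buffer is simulated with `Marshal.AllocHGlobal`, and the project must be compiled with unsafe code enabled (`AllowUnsafeBlocks`).

```csharp
using System;
using System.Runtime.InteropServices;

// Hypothetical demo class; not part of the ONNX Runtime code base.
public static class BlittableReadDemo
{
    // Copies the value through a temporary managed array.
    // The array is a heap allocation that later becomes garbage.
    public static float ReadWithMarshalCopy(IntPtr propertyValue)
    {
        var buffer = new float[1];
        Marshal.Copy(propertyValue, buffer, 0, 1);
        return buffer[0];
    }

    // Dereferences the native pointer directly. float is blittable,
    // so no marshalling or intermediate array is required.
    public static unsafe float ReadWithPointer(IntPtr propertyValue)
    {
        return *(float*)propertyValue.ToPointer();
    }

    public static void Main()
    {
        // Simulate a native buffer that holds a single float.
        IntPtr native = Marshal.AllocHGlobal(sizeof(float));
        try
        {
            Marshal.Copy(new[] { 1.5f }, 0, native, 1);
            // Both approaches read the same value from native memory.
            Console.WriteLine(ReadWithMarshalCopy(native) == ReadWithPointer(native));
        }
        finally
        {
            Marshal.FreeHGlobal(native);
        }
    }
}
```

Both helpers return the same value; the pointer version simply avoids the extra array allocation and the copy.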
In reply to: [1709298903](https://github.com/microsoft/onnxruntime/pull/17364#issuecomment-1709298903)
Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs:183 in df21a2e.
See inference code for a function that converts and copies the string to native memory at the same time, and avoids intermediate array allocation. In reply to: 1709299587 Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs:152 in df21a2e.
…baijumeswani/update-checkpoint-params
nit: Pin() is very expensive. In reply to: 1724556619 Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs:49 in 58ff40b.
Thank you for the review @yuslepukhin @AdamLouly @pengwa :)
Cherry-pick PRs: #18026 #17912 #17901 "2 lines added whitespace errors when cherry-picking" #17293 #17364 #17505 #17885
This PR contains all the cherry-picks for the patch release except:
1. The PRs marked with sdxl_llama
2. #17772 which has a merge conflict.
---------
Co-authored-by: Chi Lo <Chi.Lo@microsoft.com>
Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com>
Co-authored-by: Scott McKay <Scott.McKay@microsoft.com>
Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
Co-authored-by: Kaz Nishimura <kazssym@linuxfront.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
This pull request exposes the checkpoint parameters to users in C, C++, C# and Python.
After this pull request, users can query the current values of the parameters and update them.