
Del ort_model._modules to foward its accessing to torch_model._modules #14563

Merged: guyang3532 merged 1 commit into microsoft:main from guyang3532:fix6 on Mar 3, 2023

guyang3532 commented Feb 3, 2023

General Description

A missing '_modules' attribute in ORTModule causes load_state_dict to fail for a wrapped ORTModule.
The unit test 'test_load_state_dict_for_wrapped_ortmodule' did not catch this problem because it did not copy the state_dict,
so the two state_dicts it compared shared the same memory.
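
Why an uncopied state_dict hides the failure can be sketched in plain Python. This is only an illustration, not the actual unit test: `TinyModule`, `broken_load_state_dict`, and `roundtrip_passes` are hypothetical names, and mutable lists stand in for tensors whose storage is shared with the live parameters.

```python
import copy


class TinyModule:
    """Hypothetical stand-in for a module; lists play the role of tensors."""

    def __init__(self):
        self.params = {"weight": [1.0, 2.0]}

    def state_dict(self):
        # Return references to the live parameter storage, not copies.
        return dict(self.params)

    def broken_load_state_dict(self, state):
        # Simulates the bug: the call silently loads nothing.
        pass


def roundtrip_passes(make_copy):
    model = TinyModule()
    state = model.state_dict()
    if make_copy:
        state = copy.deepcopy(state)
    # Change the values we intend to load.
    state["weight"][0] = 42.0
    model.broken_load_state_dict(state)
    # The comparison a unit test would make after loading.
    return model.state_dict() == state


aliased_result = roundtrip_passes(make_copy=False)  # check passes despite the bug
copied_result = roundtrip_passes(make_copy=True)    # check fails, exposing the bug
print(aliased_result, copied_result)
```

Without the copy, editing the state also edits the model, so the comparison succeeds even though nothing was loaded. With a deep copy, the same broken loader is detected.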

Motivation and Context

Reference: #7847

guyang3532 changed the title from "Forward access of ort_model._modules to torch_model._modules" to "[draft]Forward access of ort_model._modules to torch_model._modules" on Feb 3, 2023
guyang3532 changed the title from "[draft]Forward access of ort_model._modules to torch_model._modules" to "[draft]set ort_model._modules to torch_model._modules" on Feb 3, 2023

guyang3532 commented Feb 3, 2023

I think a better solution would be to forward access to ORTModule._modules to TorchModule._modules, which keeps the two consistent, rather than just copying it. But I have not figured out a good implementation. Do you have any suggestions? @baijumeswani @pengwa
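
The forwarding idea can be illustrated in plain Python. The title of this PR indicates that the attribute is deleted so that access falls through to the torch model; the sketch below shows that general mechanism only, under the assumption that the wrapper has a `__getattr__` fallback. `TorchLikeModel`, `Wrapper`, and `_inner` are hypothetical names, not ONNX Runtime code.

```python
class TorchLikeModel:
    """Hypothetical inner model that owns the real `_modules` mapping."""

    def __init__(self):
        self._modules = {"fc": "child-module"}


class Wrapper:
    """Hypothetical wrapper that forwards missing attributes to the inner model."""

    def __init__(self, inner):
        self._inner = inner
        self._modules = {}   # what a base-class __init__ would normally create
        del self._modules    # drop the wrapper's own copy so lookup falls through

    def __getattr__(self, name):
        # Python calls __getattr__ only when normal attribute lookup fails.
        if name == "_inner":
            # Avoid infinite recursion if `_inner` itself is not set yet.
            raise AttributeError(name)
        return getattr(self._inner, name)


inner = TorchLikeModel()
wrapper = Wrapper(inner)

# Both names now refer to the same mapping, so they cannot drift apart.
print(wrapper._modules is inner._modules)
```

Because the wrapper holds no `_modules` entry of its own, later changes to the inner model's children are visible through the wrapper without any copying.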

baijumeswani added the "training" label (issues related to ONNX Runtime training; typically submitted using template) on Feb 3, 2023

baijumeswani commented Feb 3, 2023

ORTModule.load_state_dict already forwards the call to the underlying torch model. Does that not work?

guyang3532 force-pushed the fix6 branch 2 times, most recently from 74388e9 to 0d46c2a on February 7, 2023 09:29
@guyang3532

> ORTModule.load_state_dict already forwards the call to the underlying torch model. Does that not work?

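It does not work, for the reason you described in #7847: load_state_dict does not recursively call load_state_dict on its children. Instead, it defines its own nested function load (inside load_state_dict) that does the recursion, so the ORTModule.load_state_dict override is bypassed when the ORTModule is wrapped in another module.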
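
The behavior described above can be sketched with a simplified stand-in for torch.nn.Module. This is not PyTorch's or ONNX Runtime's actual code: `Module`, `Linear`, `ForwardingWrapper`, and `Outer` are hypothetical classes, and the sketch assumes only what the comment states, namely that the recursion is done by a nested `load` function that walks `_modules`.

```python
class Module:
    """Simplified, hypothetical stand-in for torch.nn.Module."""

    def __init__(self):
        self._parameters = {}
        self._modules = {}

    def load_state_dict(self, state):
        # Nested helper: the recursion walks `_modules` directly and never
        # calls a child's own load_state_dict method.
        def load(module, prefix=""):
            for name in module._parameters:
                key = prefix + name
                if key in state:
                    module._parameters[name] = state[key]
            for child_name, child in module._modules.items():
                load(child, prefix + child_name + ".")

        load(self)


class Linear(Module):
    def __init__(self):
        super().__init__()
        self._parameters["weight"] = 0.0


class ForwardingWrapper(Module):
    """Hypothetical ORTModule-like wrapper: it overrides load_state_dict,
    but the wrapped model is not visible through its `_modules`."""

    def __init__(self, inner):
        super().__init__()
        self._inner = inner  # not registered in self._modules

    def load_state_dict(self, state):
        self._inner.load_state_dict(state)


class Outer(Module):
    def __init__(self, child):
        super().__init__()
        self._modules["net"] = child


# Wrapped case: the outer walk reaches the wrapper, finds no children, stops.
inner = Linear()
outer = Outer(ForwardingWrapper(inner))
outer.load_state_dict({"net.weight": 3.0})
weight_after_wrapped_load = inner._parameters["weight"]

# Direct case: the wrapper's override is called and forwards correctly.
direct = Linear()
ForwardingWrapper(direct).load_state_dict({"weight": 3.0})
weight_after_direct_load = direct._parameters["weight"]

print(weight_after_wrapped_load, weight_after_direct_load)
```

In this sketch the direct call updates the weight while the wrapped call leaves it untouched, which is why forwarding load_state_dict alone is not enough and the `_modules` lookup itself has to reach the torch model.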

guyang3532 changed the title from "[draft]set ort_model._modules to torch_model._modules" to "Del ort_model._modules to foward it to torch_model._modules" on Feb 7, 2023
guyang3532 changed the title from "Del ort_model._modules to foward it to torch_model._modules" to "Del ort_model._modules to foward its accessing to torch_model._modules" on Feb 7, 2023
baijumeswani previously approved these changes on Feb 7, 2023
baijumeswani previously approved these changes on Feb 11, 2023
guyang3532 merged commit c49f250 into microsoft:main on Mar 3, 2023
mszhanyi pushed a commit that referenced this pull request Mar 9, 2023
#14563)

Missing '_modules' attribute in ORTModule will cause load_state_dict for
wrapped_ortmodule to fail.

reference:#7847