Updated _load_pretrained_model_low_mem to check if keys are in the state_dict #16643
Conversation
I am wondering what is the correct place to add a test for this function.
sgugger
left a comment
Thanks for tackling this, the fix could be a tiny bit better I believe.
@stas00 It looks like the whole low_cpu_mem_usage is not tested at present? Maybe we can take care of tests in a separate PR for both a whole and a sharded checkpoint, so this can be merged fast for the RegNet PR?
src/transformers/modeling_utils.py (Outdated)
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
    new_val = torch.nn.Parameter(new_val)
setattr(submodule, param_name, new_val)
if k in state_dict:
This test should go above on line 2165 with a continue if it's not True, to avoid looking for the param when we don't need it.
Updated. The only difference from your comment is that setattr(submodule, param_name, new_val) comes after the check for the key.
There is nothing on line 2165, are you sure you pushed your update? The goal is to avoid spending any time in this block (starting at submodule, param_name = find_submodule_and_param_name(model, k)) when there is no need to.
Apologies, updated. No need for an ugly continue when you can do everything with a positive conditional flow.
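The "positive conditional flow" being discussed can be sketched as follows. This is a hedged illustration, not the actual transformers code: `find_submodule_and_param_name` here is an assumed stand-in for the helper of the same name mentioned in the thread, and `load_present_keys` is a hypothetical name for the loop body.

```python
import torch

def find_submodule_and_param_name(model, key):
    # Assumed stand-in: resolve a flat key like "encoder.weight"
    # to (submodule, "weight").
    if "." in key:
        module_path, param_name = key.rsplit(".", 1)
        return model.get_submodule(module_path), param_name
    return model, key

def load_present_keys(model, state_dict, loaded_state_dict_keys):
    # Positive conditional flow: only resolve the param for keys that
    # this (possibly sharded) state_dict actually contains, so we never
    # spend time in find_submodule_and_param_name for absent keys.
    for k in loaded_state_dict_keys:
        if k in state_dict:
            submodule, param_name = find_submodule_and_param_name(model, k)
            new_val = state_dict[k]
            if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
                new_val = torch.nn.Parameter(new_val)
            setattr(submodule, param_name, new_val)
```

With a sharded checkpoint, calling this once per shard loads each key from whichever shard holds it, leaving the others untouched.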
Prefilter?
keys_to_load = [k for k in loaded_state_dict_keys if k in state_dict]
It won't be the same if loaded_state_dict_keys doesn't include all the state_dict keys. I'm pretty sure it does right now, but that may change. Note this warning:
transformers/src/transformers/modeling_utils.py, line 2121 (at 10131af):
"it was a quick hack to enable an urgent use so it needs to be completed to do a full support, in which case not all keys from state_dict might be loaded."
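To make the distinction concrete, here is a toy sketch (assumed data, not real checkpoint contents): prefiltering loaded_state_dict_keys processes exactly the same keys as the in-loop guard, whereas iterating state_dict itself could touch keys the caller never asked to load.

```python
# Toy data only; real checkpoints map keys to tensors, not ints.
loaded_state_dict_keys = ["weight", "bias"]    # keys the caller wants loaded
state_dict = {"weight": 1, "extra_buffer": 2}  # this shard's actual contents

# Prefilter, as suggested above ...
keys_to_load = [k for k in loaded_state_dict_keys if k in state_dict]

# ... which yields the same keys as guarding inside the loop:
guarded = [k for k in loaded_state_dict_keys if k in state_dict]

# Iterating state_dict directly would also pick up "extra_buffer",
# a key that loaded_state_dict_keys deliberately excludes.
keys_from_state_dict = list(state_dict)
```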
I only suggested the comprehension as another way to avoid too much conditional nesting. continue exists for this exact reason, as does the functional programming style.
You will have to put your continue inside an if statement. For me it is the same; feel free to suggest the change that fits your coding style preference and I will happily apply it. But let's avoid unneeded nitpicking.
I suggested a simple alternative to deep conditional nesting here: #16643 (comment)
But I'm fine with the code the way it is now as well.
Sure. What I meant is that prefiltering is the same as just iterating the loaded state_dict keys; that is the cleanest solution.
Your plan works for me, Sylvain. I will work on the low mem test then today.
stas00
left a comment
LGTM, thank you for fixing this bug, @FrancescoSaverioZuppichini
What does this PR do?
This PR checks whether each key is in the state_dict before attempting to load it. When a model has multiple (sharded) checkpoints, not all keys are in every checkpoint.
TODO