Conversation
| return cached_filenames, sharded_metadata | ||
|
|
||
|
|
||
| def load_state_dict(checkpoint_file: Union[str, os.PathLike]): |
This is isolating the bit of code where we load the state dict for the first time in current from_pretrained as I need to call this several times. There is no change in that code.
| ) | ||
|
|
||
|
|
||
| def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): |
This is isolating the bit of code that loads a state dict into a model in the current PreTrainedModel._load_state_dict_into_model on master. It takes two bits of that function: the replacement of gamma/beta at the beginning, and the actual load in the middle.
| # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the | ||
| # index of the files. | ||
| is_sharded = False | ||
| sharded_metadata = None |
We are in `from_pretrained` and at the main diff now; this is the part that needs the most attention :-)
| elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)): | ||
| # Load from a sharded PyTorch checkpoint | ||
| archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) | ||
| is_sharded = True |
This deals with a sharded checkpoint when pretrained_model_name_or_path is a local folder.
I think I would prefer the logic following this line of code to be a bit different, because a new `try: ...` nested inside an `except EntryNotFoundError` makes the logic a bit harder to follow in my opinion.
I like that here we overwrite `archive_file` with `WEIGHTS_INDEX_NAME` and set the `is_sharded` flag to `True`. Couldn't we do this for the other use case as well, i.e. when `pretrained_model_name_or_path` is a link on the Hub, by calling some kind of `is_remote_url_sharded` function?
What I mean is to do the following. Change the lines 1664- from:

    else:
        filename = WEIGHTS_NAME

to something like:

    elif is_remote_url_sharded(pretrained_model_name_or_path, filename=WEIGHTS_INDEX_NAME, revision=revision, mirror=mirror):
        filename = WEIGHTS_INDEX_NAME
        is_sharded = True
    else:
        filename = WEIGHTS_NAME

And then we can maybe reduce the following complexity a bit. It does introduce one additional call to the Hub at every `from_pretrained(...)` though, so that's maybe why the try...except is preferred?
I do think it would make the code easier to read though
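For illustration, here is a rough sketch of what such a helper could look like (the name `is_remote_url_sharded` comes from the suggestion above; building the URL with `hf_bucket_url` and probing it with a plain HEAD request is an assumption, not what the PR actually implements):

```py
import requests

from transformers.file_utils import hf_bucket_url


def is_remote_url_sharded(pretrained_model_name_or_path, filename, revision=None, mirror=None):
    # Hypothetical helper: return True if `filename` (e.g. WEIGHTS_INDEX_NAME) exists for this
    # model ID on the Hub. Costs one extra HEAD request per `from_pretrained` call.
    url = hf_bucket_url(pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror)
    try:
        return requests.head(url, allow_redirects=True, timeout=10).status_code == 200
    except requests.RequestException:
        return False
```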
| pretrained_model_name_or_path, | ||
| filename=WEIGHTS_INDEX_NAME, | ||
| revision=revision, | ||
| mirror=mirror, |
This deals with a sharded checkpoint when pretrained_model_name_or_path is a model ID on the Hub.
I like this approach in the sense that WEIGHTS_NAME will trump any index, and sharding is the fallback (the source of truth is clearly established: pytorch_model.bin > index > shards).
It also follows preexisting code nicely.
Maybe in a follow-up, disentangling the various try..except (and the associated network calls) could be done by getting from the Hub, in a single call, both the SHAs and which filenames exist or not (if that's possible). Definitely does not belong in this PR I think.
| def _fix_key(key): | ||
| if "beta" in key: | ||
| return key.replace("beta", "bias") | ||
| if "gamma" in key: | ||
| return key.replace("gamma", "weight") | ||
| return key | ||
|
|
||
| loaded_keys = [_fix_key(key) for key in loaded_keys] |
This is doing on the keys what the code moved above is doing directly on the state dict. Necessary to have missing_keys and unexpected_keys be correct.
| # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not | ||
| # matching the weights in the model. | ||
| mismatched_keys = [] | ||
| if ignore_mismatched_sizes: | ||
| for checkpoint_key in loaded_keys: | ||
| model_key = checkpoint_key | ||
| if remove_prefix_from_model: | ||
| # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. | ||
| model_key = f"{prefix}.{checkpoint_key}" | ||
| elif add_prefix_to_model: | ||
| # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. | ||
| model_key = ".".join(checkpoint_key.split(".")[1:]) | ||
|
|
||
| if ( | ||
| model_key in model_state_dict | ||
| and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape | ||
| ): | ||
| mismatched_keys.append( | ||
| (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) | ||
| ) | ||
| del state_dict[checkpoint_key] |
This part is moved below without changes, as it needs the state_dict, which we don't have yet here when the checkpoint is sharded.
| # copy state_dict so _load_from_state_dict can modify it | ||
| metadata = getattr(state_dict, "_metadata", None) | ||
| state_dict = state_dict.copy() | ||
| if metadata is not None: | ||
| state_dict._metadata = metadata | ||
|
|
||
| error_msgs = [] | ||
|
|
||
| # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants | ||
| # so we need to apply the function recursively. | ||
| def load(module: nn.Module, prefix=""): | ||
| local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | ||
| args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) | ||
| if is_deepspeed_zero3_enabled(): | ||
| import deepspeed | ||
|
|
||
| # because zero3 puts placeholders in model params, this context | ||
| # manager gathers (unpartitions) the params of the current layer, then loads from | ||
| # the state dict and then re-partitions them again | ||
| with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): | ||
| if torch.distributed.get_rank() == 0: | ||
| module._load_from_state_dict(*args) | ||
| else: | ||
| module._load_from_state_dict(*args) | ||
|
|
||
| for name, child in module._modules.items(): | ||
| if child is not None: | ||
| load(child, prefix + name + ".") | ||
|
|
Moved without changes to _load_state_dict_into_model.
| mismatched_keys = [] | ||
| if ignore_mismatched_sizes: | ||
| for checkpoint_key in loaded_keys: | ||
| model_key = checkpoint_key | ||
| if remove_prefix_from_model: | ||
| # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. | ||
| model_key = f"{prefix}.{checkpoint_key}" | ||
| elif add_prefix_to_model: | ||
| # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. | ||
| model_key = ".".join(checkpoint_key.split(".")[1:]) | ||
|
|
||
| if ( | ||
| model_key in model_state_dict | ||
| and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape | ||
| ): | ||
| mismatched_keys.append( | ||
| (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) | ||
| ) | ||
| del state_dict[checkpoint_key] |
Moved from above, as commented.
| del state_dict[checkpoint_key] | ||
|
|
||
| error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix) | ||
| else: |
Same code, but in a for loop for each shard.
|
The documentation is not available anymore as the PR was closed or merged. |
Great work, @sgugger!!! Thank you for working on this
I only did a quick pass and want to play with the code later today and will give more feedback afterwards - added a few nits meanwhile.
So one of the concerns raised by @Narsil is that a user may upload updates that are out of sync. #13548 (comment)
Should we store in the index file the sha256 values of the shards, since this is what we get from LFS, e.g. https://huggingface.co/distilgpt2/blob/main/pytorch_model.bin
SHA256: ecbb4e22dd2b9dcc43b2622e1b87ebb9361fb31e496b98ea01a38785c1dbaa01
so that once the index file has been downloaded we could validate the checkpoint integrity. We don't need to run sha256 which would be super slow - we just need to check that the index file's sha256 records match the lfs sha256 records as in the example above.
This of course would require running sha256 when creating the sharded checkpoint, which would be slow. But it won't happen for small models.
This can of course be done as a next iteration and shouldn't hold up this PR as it is now.
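For concreteness, a minimal sketch of what this could look like at save time (the `sha256` entry in the index metadata is hypothetical and not part of this PR):

```py
import hashlib
import os


def sha256_of_file(path, chunk_size=1 << 20):
    # Stream the file so huge shards don't need to fit in memory.
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha.update(chunk)
    return sha.hexdigest()


def add_shard_checksums(index, save_directory):
    # Hypothetical extra metadata: one checksum per shard file, to be compared against the
    # LFS sha256 exposed by the Hub once the index has been downloaded.
    index["metadata"]["sha256"] = {
        shard_file: sha256_of_file(os.path.join(save_directory, shard_file))
        for shard_file in sorted(set(index["weight_map"].values()))
    }
    return index
```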
src/transformers/modeling_utils.py
Outdated
| Example: | ||
|
|
||
| ```py | ||
| >>> dtype_size(torch.float32) 4""" |
input and output got wrapped together
Will fix, thanks for flagging!
src/transformers/modeling_utils.py
Outdated
| The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no | ||
| optimization made to make each sub-checkpoint as close as possible to the maximum size passed. |
I'm having a hard time parsing the "is no optimization made" part.
Do you mean to say that instead of splitting an 11GB checkpoint into 5.5+5.5 we do 10+1 split?
I mean that if at a given key we're at, say 2.5GB, and the next weight is 5GB, we won't go look for some other weight that is less than 2.5GB, just stop the current shard there and go to the next.
So we may end up with a 1 + 10 split if there is one weight that is 10GB on its own, yes.
got it. Will it then fail if there is at least one weight bigger than 10GB?
might be easier to understand immediately what is meant if we show an example with numbers:
if the limit is 10GB and we have weights of sizes
`[6GB, 6GB, 2GB, 6GB, 2GB, 2GB]`, they will get sharded as `[6GB], [6GB+2GB], [6GB+2GB+2GB]` and not `[6GB+2GB+2GB], [6GB+2GB], [6GB]`
It doesn't fail, there will just be one shard that is larger than the max_size in this case. We might add intra-weight sharding in the future, but for now this should suffice.
Also note that a user can always shard their model themselves (thom is doing that for big science for instance). save_pretrained just provides one default.
Thanks for the example, will add it to the doc!
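To make the strategy being discussed easier to follow, here is a minimal sketch of the greedy loop (illustrative only; the function name and the use of `element_size()` are assumptions, not the PR's actual code):

```py
def greedy_shard(state_dict, max_shard_size):
    # Walk the state dict in key order and start a new shard whenever adding the next
    # weight would push the current shard above `max_shard_size` (in bytes).
    shards, current, current_size = [], {}, 0
    for key, weight in state_dict.items():
        weight_size = weight.numel() * weight.element_size()
        if current and current_size + weight_size > max_shard_size:
            shards.append(current)
            current, current_size = {}, 0
        current[key] = weight
        current_size += weight_size
    if current:
        shards.append(current)
    return shards
```

With a 10GB limit and the `[6GB, 6GB, 2GB, 6GB, 2GB, 2GB]` example above, this yields `[6GB]`, `[6GB+2GB]`, `[6GB+2GB+2GB]`, and a single weight larger than the limit simply ends up alone in an oversized shard.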
src/transformers/modeling_utils.py
Outdated
|
|
||
| <Tip warning={true}> | ||
|
|
||
| If one of the model's weight is bigger that `max_size`, it will end up in its own sub-checkpoint which will have a |
| If one of the model's weight is bigger that `max_size`, it will end up in its own sub-checkpoint which will have a | |
| If one of the model's weight is bigger than `max_size`, it will end up in its own sub-checkpoint which will have a |
but it can't be bigger than 30GB since then the hub will fail.
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
…mers into checkpoint_sharding
src/transformers/modeling_utils.py
Outdated
| raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.") | ||
|
|
||
|
|
||
| def dtype_size(dtype): |
(nit) since both bytes and bits are pretty common, maybe it's worth calling it `dtype_byte_size`?
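For context, one way such a renamed helper could look; this is a sketch, not necessarily the PR's implementation:

```py
import re

import torch


def dtype_byte_size(dtype):
    # Number of bytes used to store one element of `dtype`, e.g. 4 for torch.float32.
    # torch.bool is counted as one bit per element here.
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)$", str(dtype))  # "torch.float32" -> "32"
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    return int(bit_search.groups()[0]) // 8
```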
| shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(shards):05d}.bin") | ||
| save_function(shard, os.path.join(save_directory, shard_file)) | ||
| for key in shard.keys(): | ||
| weight_map[key] = shard_file |
nice, the keys here are actually the PyTorch parameter names, no? That's cool - then the user knows which layer names are in which shard
Yes, those are the full parameter names as in the state dict :-)
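To make that concrete, the index file maps every parameter name to the shard that stores it; a made-up example of its content (names, sizes and shard count are illustrative):

```py
index = {
    "metadata": {"total_size": 28_512_339_968},
    "weight_map": {
        "embeddings.word_embeddings.weight": "pytorch_model-00001-of-00003.bin",
        "encoder.layer.0.attention.self.query.weight": "pytorch_model-00001-of-00003.bin",
        "encoder.layer.23.output.dense.weight": "pytorch_model-00003-of-00003.bin",
        # ... one entry per key of the state dict
    },
}
```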
| except EntryNotFoundError: | ||
| raise EnvironmentError( | ||
| f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is " | ||
| "required according to the checkpoint index." |
| "required according to the checkpoint index." | |
| f"required according to the checkpoint index: {os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME}." |
(nit) this should be the correct path to the weight index no?
In this code path pretrained_model_name_or_path is a model ID, so that won't work.
| mismatched_keys = [] | ||
| if ignore_mismatched_sizes: | ||
| for checkpoint_key in loaded_keys: | ||
| model_key = checkpoint_key |
(nit) maybe worth refactoring this into a new function:

    model_key = checkpoint_key
    if remove_prefix_from_model:
        # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
        model_key = f"{prefix}.{checkpoint_key}"
    elif add_prefix_to_model:
        # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
        model_key = ".".join(checkpoint_key.split(".")[1:])

    if (
        model_key in model_state_dict
        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
    ):
        mismatched_keys.append(
            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
        )
        del state_dict[checkpoint_key]

->

    if not match_shape(checkpoint_key, state_dict):
        mismatched_keys.append(
            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
        )
        del state_dict[checkpoint_key]
Yes to refactoring, but I'd leave it for the followup PR, I have more refactoring to do :-)
| # Finally, check the model can be reloaded | ||
| new_model = BertModel.from_pretrained(tmp_dir) | ||
| for p1, p2 in zip(model.parameters(), new_model.parameters()): | ||
| self.assertTrue(torch.allclose(p1, p2)) |
This test is very nice as it is very high level.
However, I feel it deserves to be split into multiple smaller tests (We can keep the model loaded only once).
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir, max_shard_size="50kB")
        self.assertEqual(os.listdir(tmp_dir), ["pytorch_model-00001-of-00002.bin", "pytorch_model-00002-of-00002.bin", "pytorch_model.bin.index.json"])
        self.assertEqual(os.path.getsize("pytorch_model-00001-of-00002.bin"), 12_230_400)
        self.assertEqual(os.path.getsize("pytorch_model-00002-of-00002.bin"), 62_230_400)
        # this file is larger than the expected 50kB, so it must be a single weight
        state_dict = torch.load(shard_file)
        self.assertEqual(len(state_dict), 1)
        self.assertEqual(json.loads("pytorch_model.bin.index.json"), {"metadata": {"total_size": 74_000_000}, "weights_map": {....}})
        # Maybe the weight map is too big and a calculated check is more acceptable here, this is just for brevity
        # This is great as is.
        new_model = BertModel.from_pretrained(tmp_dir)
        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.allclose(p1, p2))

    # Another test or separated logic
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir, max_shard_size="100kB")
        self.assertEqual(os.listdir(tmp_dir), ["pytorch_model.bin"])
        # No indexing.
        self.assertEqual(os.path.getsize("pytorch_model.bin"), 74_000_00)

Just as a tendency, I feel like tests with static (hardcoded) expected values tend to provide more value, as they tend to fail more often (calculated values move both with the code and the test, and sometimes both contain the bug, masking the error because the test is still green).
They are also more readable (for instance in my example, I know that the no sharding behavior is sound for a 72k model file, and that 62 + 12 = 74 so it looks consistent). We also know for a fact that shards > 50k are tested against (in your test code I am pretty sure it does, but we can only know for sure by putting a debugger and checking the actual values)
This is more a rule of thumb, but I think it could apply quite well here.
The problem with hard-coded values is that tests suddenly fail on master and on every contributor's PR when the model/dataset was changed. That's why the test is more convoluted, but also more resilient to changes.
I agree it would be more readable your way, but I've been burned too many times to consider it. Will add something more explicit in the additional tests required by Patrick.
The problem with hard-coded values is that tests suddenly fail on master and on every contributor's PR when the model/dataset was changed. That's why the test is more convoluted, but also more resilient to changes.
I understand, I consider it's actually a good thing, and modifying anything in hf-internal-testing should be considered a modification of tests (and should be treated with as much care by everyone), adding new models there should be free enough that adding should be preferred over modifying in my very personal opinion.
Now, you are the one getting burned, and so I definitely understand the argument that you don't want to live it.
The current test is good as it is.
Maybe when we have more granular permissions on the hub we could freeze (to prevent this issue) some models for this kind of test ?
tests/test_modeling_common.py
Outdated
| for p1, p2 in zip(model.parameters(), new_model.parameters()): | ||
| self.assertTrue(torch.equal(p1, p2)) | ||
|
|
||
| @require_torch |
(nice to have)
Could it make sense to also test the function save_and_shard_checkpoint separately?
patrickvonplaten
left a comment
Super clean implementation! For me the code is very readable and understandable which is great given the complexity of the feature.
Two comments:
- Think the code could be made a bit easier to read by doing a quick check whether the HF Hub path is a sharded checkpoint before jumping into the try...except cascade. This would however add an additional call to the Hub for every `from_pretrained(...)` call, so I totally understand if it's not worth it here! Think it would make the code cleaner, but not sure how expensive an additional call to the Hub is here.
- Could be cool to maybe test `save_and_shard_checkpoint` with its edge cases, e.g. all weights fit in one shard under `max_size`, shards always stay below `max_size` except if one weight is bigger, ...
LysandreJik
left a comment
Very nice PR, thanks for working on it @sgugger. I've highlighted a few edge cases which I think should be handled before merging.
- There seems to be some naming which is `max_shard_size`, and some other which is `max_size`. I would prefer everything to be `max_shard_size` for coherence.
- kibibytes vs kilobytes (see comment below)?
- We can now have two separate models that live in the same folder, and each could potentially be loaded/saved at the same time. This isn't ideal as we could theoretically save two models consecutively, the latter being larger than the former, but when loading the recently-saved model, you would end up loading the old model instead. We could take the following actions:
  - Manually remove the previous checkpoint before saving the current one if the number of shards is different (:smiley:)
  - Raise an error when saving a model in a folder that has a different amount of files (:confused:)
  - Print a warning that the resulting repo might be incoherent (:confused:)
- I think the `push_to_hub` method could greatly benefit from having the `max_shard_size` parameter and passing it to `save_pretrained`. (Currently works through `.save_pretrained('xxx', max_shard_size='yyy')`.)
  - This will also imply a few additional changes to remove the previous checkpoint from the repository: since not all files may be overwritten, the additional remaining files should be removed manually (implemented something similar here huggingface/datasets#3098)
src/transformers/modeling_utils.py
Outdated
| if size.upper().endswith("GB"): | ||
| return int(size[:-2]) * (2**30) |
Looking forward to seeing TB added here as well 😃
src/transformers/modeling_utils.py
Outdated
| def save_and_shard_checkpoint( | ||
| save_directory: Union[str, os.PathLike], | ||
| state_dict: Dict[str, torch.Tensor], | ||
| max_size: Union[int, str] = "5GB", |
Would prefer this to be `max_shard_size` (especially given that the docstring says so)
src/transformers/modeling_utils.py
Outdated
| if size.upper().endswith("KB"): | ||
| return int(size[:-2]) * (2**10) |
I think here it's a bit misleading to have the calculations happen in binary vs in decimal. This means that when specifying 1KB, right now it will save files that are 1024 bytes (kibibytes).
I find it a bit misleading since this is the system that's used across the file: putting for example 100MB will save files that are ~= 104MB when read by the filesystem (which is decimal by default).
I would either:

- Clarify that here we really mean binary, by specifying that we are working with `MiB` values
- Update so that we are calculating with `1e3`, `1e6`, etc.
- Handle both `MB` and `MiB` and shard accordingly
Will handle both, it's easy enough.
src/transformers/modeling_utils.py
Outdated
| else: | ||
| raise ValueError from e |
(nit) This is the error that will get raised if a shard is missing. I'd favor a more explicit error that the model was probably not saved correctly, or simply that a shard is missing. Right now one would get:
FileNotFoundError: [Errno 2] No such file or directory: 'here/pytorch_model-00005-of-00005.bin'
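Something along these lines could make the failure clearer (a sketch with made-up helper and wording, not the PR's code):

```py
import os


def check_local_shards(folder, shard_filenames):
    # Raise a clearer error than the raw FileNotFoundError when a shard listed in the
    # checkpoint index is missing from the local folder.
    for shard_filename in shard_filenames:
        if not os.path.isfile(os.path.join(folder, shard_filename)):
            raise EnvironmentError(
                f"{folder} is missing {shard_filename}, which is required by the checkpoint index. "
                "The checkpoint was probably saved incorrectly or only partially downloaded."
            )
```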
Narsil
left a comment
Looks good to me.
I tried to share some comments to maybe improve it a little further.
src/transformers/modeling_utils.py
Outdated
| total_size = 0 | ||
|
|
||
| for key, weight in state_dict.items(): | ||
| weight_size = weight.numel() * dtype_size(weight.dtype) |
comment: This function is actually slightly wrong as tensors take a little more than this, as they need storage for byteorder + shape info + stride ( + it's a float when using bool).
`sys.getsizeof(tensor.storage())` should work more precisely (and seems to work on torch==1.6.0: https://pytorch.org/docs/1.6.0/tensors.html?highlight=storage#torch.Tensor.storage)
Not sure how this behaves on non-full views (tensor slices, where the buffer could contain holes)
Not worth modifying current code, just sharing
src/transformers/modeling_utils.py
Outdated
| Directory to which to save. Will be created if it doesn't exist. | ||
| state_dict (dictionary of `torch.Tensor`): | ||
| The state dictionary of the model to save. | ||
| max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): |
The doc says 10 while the default argument says 5
Thanks for flagging!
src/transformers/modeling_utils.py
Outdated
| metadata = {"total_size": total_size} | ||
| index = {"metadata": metadata, "weight_map": weight_map} | ||
| with open(save_index_file, "w", encoding="utf-8") as f: | ||
| content = json.dumps(index, indent=2, sort_keys=True) + "\n" |
Any particular reason not to use `json.dump(index, f, indent=2, sort_keys=True)`? And why the extra newline?
Current code is perfectly fine.
Just using the same as in configuration/feature extractor json dumps we have in the rest of the library. I'm assuming the last line is added to play nicely with some editors, but not 100% sure. Will leave it as is for consistency
Makes sense, just wanted to make sure I wasn't missing something.
I remember you taught me that encoding="utf-8" was super important for Windows, which I have kept in mind ever since when I open files :)
src/transformers/modeling_utils.py
Outdated
|
|
||
| logger.info(f"Model weights saved in {output_model_file}") | ||
| # Actually save the `state_dict`, with sharding if the model is too big. | ||
| save_and_shard_checkpoint(save_directory, state_dict, max_size=max_shard_size, save_function=save_function) |
Design-wise, I think the previous function separated the logic better; feeding lambdas to functions makes code harder to debug in my experience.
Is something like:

    # Maybe split the state dict across shards, if we do shard, then index is returned and needs to be saved too
    shards, index, index_filename = shard_checkpoint(state_dict, max_size=max_shard_size)
    if index is not None:
        save_index(index, os.path.join(save_directory, index_filename))  # Or the code directly, it's small enough
    # If no sharding happens `shards` will be a dict with a single entry.
    for shard_file, shard_weights in shards.items():
        save_function(shard_weights, os.path.join(save_directory, shard_file))

viable? Would make testing `shard_checkpoint` slightly easier.
In general I feel like if we treat the default (non-sharded) case as shards being `{"pytorch_model.bin": state_dict}`, it could make the code slightly easier to follow (not everywhere though, to make as little change as possible).
LysandreJik
left a comment
This looks good to me. Thanks for doing the heavy lifting, @sgugger!
| if isinstance(size, int): | ||
| return size | ||
| if size.upper().endswith("GIB"): | ||
| return int(size[:-3]) * (2**30) | ||
| if size.upper().endswith("MIB"): | ||
| return int(size[:-3]) * (2**20) | ||
| if size.upper().endswith("KIB"): | ||
| return int(size[:-3]) * (2**10) | ||
| if size.upper().endswith("GB"): | ||
| return int(size[:-2]) * (10**9) | ||
| if size.upper().endswith("MB"): | ||
| return int(size[:-2]) * (10**6) | ||
| if size.upper().endswith("KB"): | ||
| return int(size[:-2]) * (10**3) | ||
| raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.") |
| # Clean the folder from a previous save | ||
| for filename in os.listdir(save_directory): | ||
| full_filename = os.path.join(save_directory, filename) | ||
| if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename): | ||
| os.remove(full_filename) |
tests/test_modeling_common.py
Outdated
| with tempfile.TemporaryDirectory() as tmp_dir: | ||
| with tempfile.TemporaryDirectory() as tmp_dir: | ||
| # We use the same folder for various sizes to make sure a new save erases the old checkpoint. | ||
| for max_size in ["50kB", "100kB", "200kB"]: |
Would be down to try with the kibibytes too (50KiB for example)
| max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): | ||
| The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size | ||
| lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). | ||
|
|
||
| <Tip warning={true}> | ||
|
|
||
| If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard | ||
| which will be bigger than `max_shard_size`. | ||
|
|
||
| </Tip> |
|
@sgugger is it possible to apply sharding to an existing pretrained model that is bigger than 10GB (let's say T0pp-like), by using `from_pretrained` -> `save_pretrained` to save a sharded version, and then `from_pretrained` again, which should then load the sharded checkpoint? Thanks |
|
Yes, that's exactly what you should do! |
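For reference, the flow asked about above would look roughly like this (the model name and shard size are only examples):

```py
from transformers import AutoModelForSeq2SeqLM

# Download the original (non-sharded) checkpoint, then save it again with sharding enabled.
model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp")
model.save_pretrained("t0pp-sharded", max_shard_size="10GB")

# Reloading from the folder goes through the index file and loads the shards one by one.
model = AutoModelForSeq2SeqLM.from_pretrained("t0pp-sharded")
```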
|
Hahaha amazing! Where is it? |
|
@sgugger we could possibly upload the sharded checkpoint in addition to the current checkpoint in the same repo branch no? I thought that's what we wanted to do to preserve backward compat while still upgrading those big models |
|
If you put them in the same branch, `from_pretrained` will load the non-sharded checkpoint, since `pytorch_model.bin` takes precedence over the index. |
|
I see. should we consider changing this behavior? we've seen that git-branches are not super practical for those large models (cc @osanseviero ) |
|
(yep, i know this is going to require to call something like /api/models/xxx at the start of `from_pretrained`) |
|
or maybe we can say that we'll just push sharded models from now on (for example the bigscience models will be sharded-only) I think that's actually fine |
|
or we can default to the sharded model for the new version of |
|
yep that was what i was suggesting but implementation wise is a bit more complex (and affects all model repos not just the sharded ones) |
| weight_map = {} | ||
| shards = {} | ||
| for idx, shard in enumerate(sharded_state_dicts): | ||
| shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") |
FYI in datasets we start at 0, not 1. Same as TF, dask and apache beam
What does this PR do?
This PR introduces the ability to create and load sharded checkpoints. It introduces a new argument in
`save_pretrained` that controls the maximum size of a checkpoint before it is auto-sharded into smaller parts (which defaults to 10GB after internal discussion; that should work well with the Hub and with environments with low RAM like Colab). When the model's total size is less than this maximum size, it's saved exactly like before. When the model size is bigger, while traversing the state dict, each time a new weight tips the size above that threshold, a new shard is created. Therefore each shard is usually smaller than the max size, but if an individual weight is bigger than this threshold, it will spawn a shard containing only itself, which will be bigger than the max size.
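As an illustration of the save side (the model, the small shard size and the resulting file names below are just an example):

```py
import os

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-cased")
# A deliberately small max_shard_size to force sharding for the example.
model.save_pretrained("bert-sharded", max_shard_size="200MB")
print(sorted(os.listdir("bert-sharded")))
# e.g. ['config.json', 'pytorch_model-00001-of-00003.bin', ..., 'pytorch_model.bin.index.json']
```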
On the `from_pretrained` side, a bit of refactoring was necessary to make the API deal with several state dict files. The main part is isolating the code that loads a state dict into a model in a separate function, so I can call it for each shard. I'm leaving comments on the PR to facilitate the review, and I will follow up with another PR that refactors `from_pretrained` even more for cleanup, but with no change to the actual code.

cc @julien-c @thomwolf @stas00 @Narsil who interacted in the RFC.
Linked issue: #13548