
Checkpoint sharding #16343

Merged
sgugger merged 21 commits into main from checkpoint_sharding
Mar 25, 2022

Conversation

@sgugger (Collaborator) commented Mar 22, 2022

What does this PR do?

This PR introduces the ability to create and load sharded checkpoints. It adds a new argument to save_pretrained that controls the maximum size a checkpoint can reach before being auto-sharded into smaller parts (defaulting to 10GB after internal discussion, which should work well with the Hub and with low-RAM environments like Colab).

When the model's total size is less than this maximum, it is saved exactly as before. When it is bigger, a new shard is started each time a weight tips the running size above the threshold while traversing the state dict. Each shard is therefore usually smaller than the max size, but if an individual weight is itself bigger than the threshold, it spawns a shard containing only itself, which will exceed the max size.
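The greedy traversal described above can be sketched in a few lines of plain Python. This is only an illustration: the function name is hypothetical, and where the real implementation walks a torch state dict and measures tensors in bytes, the sketch takes a plain mapping of parameter names to sizes.

```python
def shard_by_size(weight_sizes, max_shard_size):
    """Greedy sharding sketch: walk the state dict in order and start
    a new shard whenever adding the next weight would push the current
    shard above max_shard_size.

    weight_sizes: dict mapping parameter name -> size (e.g. in bytes).
    Returns a list of dicts, one per shard.
    """
    shards = [{}]
    current_size = 0
    for name, size in weight_sizes.items():
        # Only start a new shard if the current one is non-empty; a
        # single weight larger than the limit still gets its own shard,
        # which will simply be bigger than max_shard_size.
        if current_size + size > max_shard_size and current_size > 0:
            shards.append({})
            current_size = 0
        shards[-1][name] = size
        current_size += size
    return shards
```

Note that the traversal never looks ahead for a smaller weight that would fill the remaining room; it simply closes the current shard and moves on, which is why shards are usually under, but not close to, the limit.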

On the from_pretrained side, a bit of refactoring was necessary to make the API deal with several state dict files. The main part is isolating the code that loads a state dict into a model into a separate function, so I can call it for each shard. I'm leaving comments on the PR to facilitate the review, and I will follow up with another PR that refactors from_pretrained even more, purely for cleanup, with no change to the actual code.

cc @julien-c @thomwolf @stas00 @Narsil who interacted in the RFC.

Linked issue: #13548

return cached_filenames, sharded_metadata


def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
Collaborator Author

This is isolating the bit of code where we load the state dict for the first time in current from_pretrained as I need to call this several times. There is no change in that code.

)


def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
Collaborator Author

This is isolating the bit of code that loads a state dict into a model in the current PreTrainedModel._load_state_dict_into_model on master. It takes two bits of that function: the replacement of gamma/beta at the beginning, and the actual load in the middle.

# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
sharded_metadata = None
Collaborator Author

We are in from_pretrained and the main diff now, this is the part that needs the most attention :-)

elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
is_sharded = True
Collaborator Author

This deals with a sharded checkpoint when pretrained_model_name_or_path is a local folder.

@patrickvonplaten (Contributor), Mar 23, 2022

I think I would prefer the logic following this line of code to be a bit different, because a new try: ... block inside an except EntryNotFoundError makes the logic a bit harder to follow in my opinion.

I like here that we overwrite the archive_file to the WEIGHTS_INDEX_NAME and set the is_sharded flag to True. Couldn't we do this for the other use case as well, i.e. when the pretrained_model_name_or_path is a link on the Hub by calling some kind of is_remote_url_sharded function?

What I mean is to do the following. Change the lines 1664- from:

                else:
                    filename = WEIGHTS_NAME

to something like

                elif is_remote_url_sharded(pretrained_model_name_or_path, filename=WEIGHTS_INDEX_NAME, revision=revision, mirror=mirror):
                    filename = WEIGHTS_INDEX_NAME
                    is_sharded = True
                else:
                    filename = WEIGHTS_NAME

And then we can maybe reduce the following complexity a bit. It does introduce one additional call to the Hub at every from_pretrained(...) though, so maybe that's why the try...except is preferred?

I do think it would make the code easier to read though

pretrained_model_name_or_path,
filename=WEIGHTS_INDEX_NAME,
revision=revision,
mirror=mirror,
Collaborator Author

This deals with a sharded checkpoint when pretrained_model_name_or_path is a model ID on the Hub.

Contributor

I like this approach in the sense that WEIGHTS_NAME will trump any index, and sharding is the fallback (the source of truth is clearly established: pytorch_model.bin > index > shards).
It also follows the preexisting code nicely.

Maybe in a follow-up, the various try..except blocks (and the associated network calls) could be disentangled by getting, in a single call to the Hub, both the SHAs and which filenames exist (if that's possible). It definitely does not belong in this PR, I think.

Comment on lines +1921 to +1928
def _fix_key(key):
if "beta" in key:
return key.replace("beta", "bias")
if "gamma" in key:
return key.replace("gamma", "weight")
return key

loaded_keys = [_fix_key(key) for key in loaded_keys]
Collaborator Author

This is doing on the keys what the code moved above is doing directly on the state dict. Necessary to have missing_keys and unexpected_keys be correct.

Comment on lines -1605 to -1631
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if remove_prefix_from_model:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key = f"{prefix}.{checkpoint_key}"
elif add_prefix_to_model:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(checkpoint_key.split(".")[1:])

if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
Collaborator Author

This part is moved below without changes, as it needs the state_dict, which we don't have yet here when the checkpoint is sharded.

Comment on lines -1645 to -1679
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata

error_msgs = []

# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
if is_deepspeed_zero3_enabled():
import deepspeed

# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
if torch.distributed.get_rank() == 0:
module._load_from_state_dict(*args)
else:
module._load_from_state_dict(*args)

for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")

Collaborator Author

Moved without changes to _load_state_dict_into_model.

Comment on lines +1984 to +2002
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if remove_prefix_from_model:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key = f"{prefix}.{checkpoint_key}"
elif add_prefix_to_model:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(checkpoint_key.split(".")[1:])

if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
Collaborator Author

Moved from above, as commented.

del state_dict[checkpoint_key]

error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
else:
Collaborator Author

Same code, but in a for loop for each shard.
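The per-shard loop can be sketched as below. The two loader callables stand in for the PR's load_state_dict and _load_state_dict_into_model; they are passed as parameters only to keep the sketch self-contained and runnable without the rest of the codebase.

```python
def load_sharded_checkpoint(model, shard_files, load_state_dict, load_into_model, start_prefix=""):
    """Sketch of the sharded loading loop: each shard's state dict is
    loaded, applied to the model, then dropped, so peak memory stays
    around one shard instead of the full checkpoint."""
    error_msgs = []
    for shard_file in shard_files:
        state_dict = load_state_dict(shard_file)
        error_msgs += load_into_model(model, state_dict, start_prefix)
        del state_dict  # free the shard before reading the next one
    return error_msgs
```

Accumulating error_msgs across shards mirrors the single-file code path, where _load_state_dict_into_model returns the errors from one load.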

@sgugger sgugger requested review from LysandreJik and patrickvonplaten and removed request for patrickvonplaten March 22, 2022 19:13
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 22, 2022

The documentation is not available anymore as the PR was closed or merged.

@stas00 (Contributor) left a comment

Great work, @sgugger!!! Thank you for working on this

I only did a quick pass and want to play with the code later today and will give more feedback afterwards - added a few nits meanwhile.

So one of the concerns raised by @Narsil is that a user may upload updates that are out of sync. #13548 (comment)

Should we store in the index file the sha256 values of the shards, since this is what we get from LFS, e.g. https://huggingface.co/distilgpt2/blob/main/pytorch_model.bin

SHA256: ecbb4e22dd2b9dcc43b2622e1b87ebb9361fb31e496b98ea01a38785c1dbaa01

so that once the index file has been downloaded, we could validate the checkpoint integrity. We don't need to run sha256 ourselves, which would be super slow; we just need to check that the index file's sha256 records match the LFS sha256 records, as in the example above.

This of course would require running sha256 when creating the sharded checkpoint, which would be slow. But it won't happen for small models.

This of course can be done as a next iteration and shouldn't impede this PR as it is now.
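For the cost at save time, a streaming hash like the following sketch (the helper name is hypothetical) keeps things manageable even on multi-GB shards, since nothing is ever loaded fully into memory:

```python
import hashlib

def sha256_of_file(path, chunk_size=1 << 20):
    """Hash a (potentially multi-GB) shard file in 1MB chunks, so the
    whole checkpoint never has to fit in memory at once."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()
```

The resulting hex digest is directly comparable to the SHA256 that the Hub reports for an LFS file, as in the distilgpt2 example above.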

Example:

```py
>>> dtype_size(torch.float32) 4"""
```
Contributor

input and output got wrapped together

Collaborator Author

Will fix, thanks for flagging!

Comment on lines +200 to +201
The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
optimization made to make each sub-checkpoint as close as possible to the maximum size passed.
Contributor

I'm having a hard time parsing the " is no optimization made"

Do you mean to say that instead of splitting an 11GB checkpoint into 5.5+5.5 we do 10+1 split?

Collaborator Author

I mean that if at a given key we're at, say, 2.5GB, and the next weight is 5GB, we won't go looking for some other weight that is less than 2.5GB; we just stop the current shard there and move on to the next one.

So we may end up with a 1 + 10 split if there is one weight that is 10GB on its own, yes.

Contributor

Got it. Will it then fail if there is at least one weight bigger than 10GB?

@stas00 (Contributor), Mar 22, 2022

might be easier to understand immediately what is meant if we show an example with numbers:

if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]

@sgugger (Collaborator, Author), Mar 22, 2022

It doesn't fail, there will just be one shard that is larger than the max_size in this case. We might add intra-weight sharding in the future, but for now this should suffice.

Also note that a user can always shard their model themselves (thom is doing that for big science for instance). save_pretrained just provides one default.

Thanks for the example, will add it to the doc!


<Tip warning={true}>

If one of the model's weight is bigger that `max_size`, it will end up in its own sub-checkpoint which will have a
Contributor

Suggested change
If one of the model's weight is bigger that `max_size`, it will end up in its own sub-checkpoint which will have a
If one of the model's weight is bigger than `max_size`, it will end up in its own sub-checkpoint which will have a

Contributor

but it can't be bigger than 30GB since then the hub will fail.

sgugger and others added 4 commits March 22, 2022 16:53
raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")


def dtype_size(dtype):
Contributor

(nit) Since both bytes and bits are pretty common, maybe worth calling it dtype_byte_size?
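A minimal sketch of such a dtype_byte_size, operating on the dtype's name rather than a torch.dtype object (an assumption made here only to keep the sketch dependency-free; bool and other dtypes without a bit count in their name are deliberately left unhandled):

```python
import re

def dtype_byte_size(dtype_name):
    """Bytes per element, inferred from the trailing bit count in the
    dtype name (e.g. "float32" -> 4, "float16" -> 2, "int8" -> 1).
    Dtype names without a trailing bit width (e.g. "bool") raise."""
    match = re.search(r"(\d+)$", dtype_name)
    if match is None:
        raise ValueError(f"`{dtype_name}` has no trailing bit width.")
    return int(match.group(1)) // 8
```

The shard size bookkeeping then multiplies this per-element size by the tensor's number of elements.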

shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(shards):05d}.bin")
save_function(shard, os.path.join(save_directory, shard_file))
for key in shard.keys():
weight_map[key] = shard_file
Contributor

Nice, the keys are the actual PyTorch parameter names here, no? That's cool; then the user knows which layer names are in which shard.

Collaborator Author

Yes, those are the full parameter names as in the state dict :-)
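A toy illustration of the resulting weight map and index (the shard filenames, parameter names, and the total_size placeholder are all made up for the example):

```python
import json

# Hypothetical shards: shard filename -> the parameter names it contains.
shards = {
    "pytorch_model-00001-of-00002.bin": ["embeddings.weight", "encoder.layer.0.weight"],
    "pytorch_model-00002-of-00002.bin": ["encoder.layer.1.weight", "lm_head.weight"],
}

# Each key of the weight map is a full parameter name, exactly as it
# appears in the state dict; the value is the shard file holding it.
weight_map = {}
for shard_file, keys in shards.items():
    for key in keys:
        weight_map[key] = shard_file

index = {"metadata": {"total_size": 0}, "weight_map": weight_map}
index_json = json.dumps(index, indent=2, sort_keys=True) + "\n"
print(index_json)
```

This is the shape of the pytorch_model.bin.index.json file the PR writes, which is what lets from_pretrained fetch only the shards it needs.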

except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
"required according to the checkpoint index."
Contributor

Suggested change
"required according to the checkpoint index."
f"required according to the checkpoint index: {os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)}."

Contributor

(nit) this should be the correct path to the weight index no?

Collaborator Author

In this code path pretrained_model_name_or_path is a model ID, so that won't work.

mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
Contributor

(nit) Maybe worth refactoring this into a new function:

                        model_key = checkpoint_key
                        if remove_prefix_from_model:
                            # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
                            model_key = f"{prefix}.{checkpoint_key}"
                        elif add_prefix_to_model:
                            # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
                            model_key = ".".join(checkpoint_key.split(".")[1:])

                        if (
                            model_key in model_state_dict
                            and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                        ):
                            mismatched_keys.append(
                                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                            )
                            del state_dict[checkpoint_key]

->

if not match_shape(checkpoint_key, state_dict):
                        mismatched_keys.append(
                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                        )
                        del state_dict[checkpoint_key]

Collaborator Author

Yes to refactoring, but I'd leave it for the followup PR, I have more refactoring to do :-)

# Finally, check the model can be reloaded
new_model = BertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
Contributor

Super cool test!

Contributor

This test is very nice as it is very high level.
However, I feel it deserves to be split into multiple smaller tests (We can keep the model loaded only once).

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="50kB")
    self.assertEqual(
        sorted(os.listdir(tmp_dir)),
        ["pytorch_model-00001-of-00002.bin", "pytorch_model-00002-of-00002.bin", "pytorch_model.bin.index.json"],
    )

    self.assertEqual(os.path.getsize(os.path.join(tmp_dir, "pytorch_model-00001-of-00002.bin")), 12_230_400)
    self.assertEqual(os.path.getsize(os.path.join(tmp_dir, "pytorch_model-00002-of-00002.bin")), 62_230_400)
    # this file is larger than the expected 50kB, so it must be a single weight
    state_dict = torch.load(os.path.join(tmp_dir, "pytorch_model-00002-of-00002.bin"))
    self.assertEqual(len(state_dict), 1)

    with open(os.path.join(tmp_dir, "pytorch_model.bin.index.json")) as f:
        index = json.load(f)
    self.assertEqual(index, {"metadata": {"total_size": 74_000_000}, "weight_map": {....}})
    # Maybe the weight map is too big here and a calculated check is more acceptable, for brevity only

    # This is great as is.
    new_model = BertModel.from_pretrained(tmp_dir)
    for p1, p2 in zip(model.parameters(), new_model.parameters()):
        self.assertTrue(torch.allclose(p1, p2))


# Another test or separated logic
with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="100kB")
    self.assertEqual(os.listdir(tmp_dir), ["pytorch_model.bin"])
    # No indexing.
    self.assertEqual(os.path.getsize(os.path.join(tmp_dir, "pytorch_model.bin")), 74_000_000)

Just as a tendency, I feel like tests with static (hardcoded) expected values tend to provide more value, as they tend to fail more often (calculated expected values move both with the code and the test; sometimes both contain the bug, masking the error because the test stays green).

They are also more readable (for instance, in my example I know that the no-sharding behavior is sound for a 72k model file, and that 62 + 12 = 74, so it looks consistent). We also know for a fact that shards > 50k are tested against (in your test code I am pretty sure they are, but we can only know for sure by putting in a debugger and checking the actual values).

This is more a rule of thumb, but I think it could apply quite well here.

Collaborator Author

The problem with hard-coded values is that tests suddenly fail on master and on every contributor's PR when the model/dataset was changed. That's why the test is more convoluted, but also more resilient to changes.

I agree it would be more readable your way, but I've been burned too many times to consider it. Will add something more explicit in the additional tests required by Patrick.

Contributor

The problem with hard-coded values is that tests suddenly fail on master and on every contributor's PR when the model/dataset was changed. That's why the test is more convoluted, but also more resilient to changes.

I understand; I consider it's actually a good thing, and modifying anything in hf-internal-testing should be considered a modification of tests (and should be treated with as much care by everyone). Adding new models there should be free enough that adding should be preferred over modifying, in my very personal opinion.

Now, you are the one getting burned, so I definitely understand the argument that you don't want to relive it.
The current test is good as it is.
Maybe when we have more granular permissions on the Hub, we could freeze some models used by this kind of test (to prevent this issue)?

for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))

@require_torch
Contributor

(nice to have)

Could it make sense to also test the function save_and_shard_checkpoint separately?

@patrickvonplaten (Contributor) left a comment

Super clean implementation! For me the code is very readable and understandable which is great given the complexity of the feature.

Two comments:

  • I think the code could be made a bit easier to read by doing a quick check whether the HF Hub path is a sharded checkpoint before jumping into the try...except cascade. This would however add an additional call to the Hub for every from_pretrained(...) call, so I totally understand if it's not worth it here! I think it would make the code cleaner, but I'm not sure how expensive an additional call to the Hub is.
  • Could be cool to also test save_and_shard_checkpoint with its edge cases, e.g. everything fits in one shard given max_size, shards always stay below max_size except when one weight is bigger, ...

@LysandreJik (Member) left a comment

Very nice PR, thanks for working on it @sgugger. I've highlighted a few edge cases which I think should be handled before merging.

  • Some places use max_shard_size and others max_size. I would prefer max_shard_size everywhere for coherence.
  • kibibytes vs kilobytes (see comment below)?
  • We can now have two separate models that live in the same folder, and each could potentially be loaded/saved at the same time. This isn't ideal: we could theoretically save two models consecutively, the latter being larger than the former, but when loading the recently-saved model, you would end up loading the old model instead. We could take one of the following actions:
    • Manually remove the previous checkpoint before saving the current one if the number of shards is different (:smiley:)
    • Raise an error when saving a model in a folder that has a different number of files (:confused:)
    • Print a warning that the resulting repo might be incoherent (:confused:)
  • I think the push_to_hub method could greatly benefit from having the max_shard_size parameter and passing it to save_pretrained. (Currently works through .save_pretrained('xxx', max_shard_size='yyy').)
    • This will also imply a few additional changes to remove the previous checkpoint from the repository: since not all files may be overwritten, the remaining extra files should be removed manually (implemented something similar here huggingface/datasets#3098)

Comment on lines +169 to +170
if size.upper().endswith("GB"):
return int(size[:-2]) * (2**30)
Member

Looking forward to seeing TB added here as well 😃

def save_and_shard_checkpoint(
save_directory: Union[str, os.PathLike],
state_dict: Dict[str, torch.Tensor],
max_size: Union[int, str] = "5GB",
Member

Would prefer this to be max_shard_size (especially given that the docstring says so)

Comment on lines +173 to +174
if size.upper().endswith("KB"):
return int(size[:-2]) * (2**10)
Member

I think here it's a bit misleading to have the calculations happen in binary vs in decimal. This means that when specifying 1KB, right now it will save files that are 1024 bytes (kibibytes).

I find it a bit misleading since this is the system used across the file: putting for example 100MB will save files that show up as ~104MB when read by the filesystem (which is decimal by default).

I would either:

  • Clarify that here we really mean binary, by specifying that we are working with MiB values
  • Update so that we are calculating with 1e3, 1e6, etc.
  • Handle both MB and MiB and shard accordingly
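The third option could look like the following sketch (the helper name is hypothetical): decimal units use powers of 1000 and binary units powers of 1024, so the caller gets exactly what the suffix promises.

```python
def convert_size_to_int(size):
    """Parse "5GB"/"5GiB"-style strings into a number of bytes.
    Decimal units (KB, MB, GB) use powers of 1000; binary units
    (KiB, MiB, GiB) use powers of 1024. Plain ints pass through."""
    if isinstance(size, int):
        return size
    units = {
        "KIB": 2**10, "MIB": 2**20, "GIB": 2**30,
        "KB": 10**3, "MB": 10**6, "GB": 10**9,
    }
    upper = size.upper()
    for suffix, factor in units.items():
        if upper.endswith(suffix):
            return int(size[: -len(suffix)]) * factor
    raise ValueError("`size` must be an integer or a string like '5GB' or '5GiB'.")
```

Checking the three-letter binary suffixes before the two-letter decimal ones keeps the matching unambiguous.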

Collaborator Author

Will handle both, it's easy enough.

Comment on lines +391 to +392
else:
raise ValueError from e
Member

(nit) This is the error that will get raised if a shard is missing. I'd favor a more explicit error saying that the model was probably not saved correctly, or simply that a shard is missing. Right now one would get:

FileNotFoundError: [Errno 2] No such file or directory: 'here/pytorch_model-00005-of-00005.bin'

@Narsil (Contributor) left a comment

Looks good to me.

I tried to share some comments to maybe improve it a little further.

total_size = 0

for key, weight in state_dict.items():
weight_size = weight.numel() * dtype_size(weight.dtype)
Contributor

comment: This function is actually slightly wrong, as tensors take a little more than this: they need storage for byte order + shape info + stride (+ it's a float when using bool).

sys.getsizeof(tensor.storage())

Should work more precisely (and seems to work on torch==1.6.0 ( https://pytorch.org/docs/1.6.0/tensors.html?highlight=storage#torch.Tensor.storage)
Not sure how this behaves on non-full views (tensor slices, where the buffer could contain holes)

Not worth modifying current code, just sharing

Directory to which to save. Will be created if it doesn't exist.
state_dict (dictionary of `torch.Tensor`):
The state dictionary of the model to save.
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
Contributor

The doc says 10GB while the default argument says 5GB

Collaborator Author

Thanks for flagging!

metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
Contributor

Any particular reasons to not use

json.dump(index, f, indent=2, sort_keys=True)

? And why the extra newline ?

Current code is perfectly fine.

Collaborator Author

Just using the same as in configuration/feature extractor json dumps we have in the rest of the library. I'm assuming the last line is added to play nicely with some editors, but not 100% sure. Will leave it as is for consistency

Contributor

Makes sense, just wanted to make sure I wasn't missing something.
I remember you taught me that encoding="utf-8" was super important for Windows, which I have kept in mind ever since when opening files :)


logger.info(f"Model weights saved in {output_model_file}")
# Actually save the `state_dict`, with sharding if the model is too big.
save_and_shard_checkpoint(save_directory, state_dict, max_size=max_shard_size, save_function=save_function)
Contributor

Design-wise, I think the previous function separated the logic better; feeding lambdas to functions makes code harder to debug in my experience.

Is something like:

# Maybe split the state dict across shards, if we do shard, then index is returned and needs to be saved too
shards, index, index_filename = shard_checkpoint(state_dict, max_size=max_shard_size)
if index is not None:
    save_index(index, os.path.join(save_directory, index_filename)) # Or the code directly it's small enough
# If no sharding happens `shards` will be a dict with a single entry.
for shard_file, shard_weights in shards.items():
    save_function(shard_weights, os.path.join(save_directory, shard_file))

viable ?
Would make testing shard_checkpoints slightly easier.

In general I feel like if we treat shards being splitted by default being {"pytorch_model.bin": state_dict} it could make the code slightly easier to follow (not everywhere though, to make as little change as possible).
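For intuition, the greedy traversal described in the PR description ("each time a new weight tips the size above that threshold, a new shard is created") can be sketched like this. This is a simplified illustration, not the library function: weights are represented by their byte sizes rather than actual tensors.

```python
def shard_checkpoint_sketch(weight_sizes, max_size):
    """Greedily group weights into shards of (usually) at most max_size bytes.

    A shard only exceeds max_size when a single weight is itself bigger than
    the threshold, in which case it gets a shard of its own.
    """
    shards = []
    current, current_size = {}, 0
    for name, size in weight_sizes.items():
        # Adding this weight would tip the shard over the limit: close the shard.
        if current and current_size + size > max_size:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = size
        current_size += size
    if current:
        shards.append(current)
    return shards


# Two small weights share a shard; the oversized one spawns a shard of its own.
print(shard_checkpoint_sketch({"a": 4, "b": 4, "huge": 25, "c": 9}, max_size=10))
# → [{'a': 4, 'b': 4}, {'huge': 25}, {'c': 9}]
```

The real implementation traverses the actual `state_dict` and measures tensors, but the grouping decision is the same.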

pretrained_model_name_or_path,
filename=WEIGHTS_INDEX_NAME,
revision=revision,
mirror=mirror,
Contributor

I like this approach in the sense that WEIGHTS_NAME will trump any index, and sharding is the fallback (the source of truth is clearly established: pytorch_model.bin > index > shards).
It also follows preexisting code nicely.

Maybe in a follow-up, disentangling the various try..except blocks (and the associated network calls) could be done by getting, in a single call to the Hub, both the SHAs and which filenames exist (if that's possible). It definitely does not belong in this PR I think.

# Finally, check the model can be reloaded
new_model = BertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
Contributor

This test is very nice as it is very high level.
However, I feel it deserves to be split into multiple smaller tests (we can keep the model loaded only once).

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="50kB")
    self.assertEqual(
        sorted(os.listdir(tmp_dir)),
        ["pytorch_model-00001-of-00002.bin", "pytorch_model-00002-of-00002.bin", "pytorch_model.bin.index.json"],
    )

    self.assertEqual(os.path.getsize(os.path.join(tmp_dir, "pytorch_model-00001-of-00002.bin")), 12_230_400)
    self.assertEqual(os.path.getsize(os.path.join(tmp_dir, "pytorch_model-00002-of-00002.bin")), 62_230_400)
    # This shard is larger than the expected 50kB, so it must contain a single weight.
    state_dict = torch.load(os.path.join(tmp_dir, "pytorch_model-00002-of-00002.bin"))
    self.assertEqual(len(state_dict), 1)

    with open(os.path.join(tmp_dir, "pytorch_model.bin.index.json")) as f:
        index = json.load(f)
    self.assertEqual(index, {"metadata": {"total_size": 74_000_000}, "weights_map": {....}})
    # Maybe the weights map is too big here and doing the calculated check is more acceptable, for brevity only.

    # This is great as is.
    new_model = BertModel.from_pretrained(tmp_dir)
    for p1, p2 in zip(model.parameters(), new_model.parameters()):
        self.assertTrue(torch.allclose(p1, p2))


# Another test or separated logic
with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="100kB")
    self.assertEqual(os.listdir(tmp_dir), ["pytorch_model.bin"])
    # No index file.
    self.assertEqual(os.path.getsize(os.path.join(tmp_dir, "pytorch_model.bin")), 74_000_000)

Just as a tendency, I feel like tests with static (hardcoded) expected values tend to provide more value, as they tend to fail more often (calculated values move both with the code and the test, and sometimes both contain the bug, masking the error because the test is still green).

They are also more readable (for instance in my example, I know that the no-sharding behavior is sound for a 74k model file, and that 62 + 12 = 74 so it looks consistent). We also know for a fact that shards > 50k are tested against (in your test code I am pretty sure they are, but we can only know for sure by putting a debugger in and checking the actual values).

This is more a rule of thumb, but I think it could apply quite well here.

Member

@LysandreJik left a comment

This looks good to me. Thanks for doing the heavy lifting, @sgugger!

Comment on lines +169 to +183
if isinstance(size, int):
return size
if size.upper().endswith("GIB"):
return int(size[:-3]) * (2**30)
if size.upper().endswith("MIB"):
return int(size[:-3]) * (2**20)
if size.upper().endswith("KIB"):
return int(size[:-3]) * (2**10)
if size.upper().endswith("GB"):
return int(size[:-2]) * (10**9)
if size.upper().endswith("MB"):
return int(size[:-2]) * (10**6)
if size.upper().endswith("KB"):
return int(size[:-2]) * (10**3)
raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
Member

Love it!
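To make the quoted conversion rules concrete, here is a standalone copy of the function with a few example values (the name `convert_file_size_to_int` is assumed from context; the body is reproduced from the diff above):

```python
def convert_file_size_to_int(size):
    """Convert an int (bytes) or a string like "5GB"/"5GiB" into a number of bytes."""
    if isinstance(size, int):
        return size
    # Binary units (powers of 2) are checked first so "GiB" is not caught by "GB".
    if size.upper().endswith("GIB"):
        return int(size[:-3]) * (2**30)
    if size.upper().endswith("MIB"):
        return int(size[:-3]) * (2**20)
    if size.upper().endswith("KIB"):
        return int(size[:-3]) * (2**10)
    # Decimal units (powers of 10).
    if size.upper().endswith("GB"):
        return int(size[:-2]) * (10**9)
    if size.upper().endswith("MB"):
        return int(size[:-2]) * (10**6)
    if size.upper().endswith("KB"):
        return int(size[:-2]) * (10**3)
    raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")


print(convert_file_size_to_int("10GB"))   # → 10000000000
print(convert_file_size_to_int("50kB"))   # → 50000
print(convert_file_size_to_int("50KiB"))  # → 51200
```

Note that units are case-insensitive thanks to the `upper()` calls, so `"50kB"` and `"50KB"` parse identically.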

Comment on lines +1364 to +1368
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename):
os.remove(full_filename)
Member

Cool, very nice
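To see what that prefix match covers, here is a self-contained rerun of the same logic (the helper name `clean_previous_save` is mine, not from the PR):

```python
import os
import tempfile

WEIGHTS_NAME = "pytorch_model.bin"


def clean_previous_save(save_directory):
    """Remove stale weight files (sharded or not) left by a previous save."""
    for filename in os.listdir(save_directory):
        full_filename = os.path.join(save_directory, filename)
        # WEIGHTS_NAME[:-4] == "pytorch_model": matches the single file,
        # the numbered shards, and the index alike.
        if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename):
            os.remove(full_filename)


with tempfile.TemporaryDirectory() as tmp_dir:
    for name in ["pytorch_model-00001-of-00002.bin", "pytorch_model.bin.index.json", "config.json"]:
        open(os.path.join(tmp_dir, name), "w").close()
    clean_previous_save(tmp_dir)
    remaining = os.listdir(tmp_dir)  # only config.json survives
```

This matters because a re-save with a different `max_shard_size` can produce fewer shards than before; without the cleanup, stale shards from the previous save would linger next to the new ones.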

with tempfile.TemporaryDirectory() as tmp_dir:
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
for max_size in ["50kB", "100kB", "200kB"]:
Member

Would be down to try with the kibibytes too (50KiB for example)

Collaborator Author

Sure!

Comment on lines +2235 to +2244
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
lower than this threshold. If expressed as a string, it needs to be digits followed by a unit (like `"5MB"`).

<Tip warning={true}>

If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
which will be bigger than `max_shard_size`.

</Tip>
Member

Love that!

@sgugger sgugger merged commit b473617 into main Mar 25, 2022
@sgugger sgugger deleted the checkpoint_sharding branch March 25, 2022 15:59
@loretoparisi
Contributor

@sgugger is it possible to apply sharding to a current pretrained model that is bigger than 10GB (let's say T0pp-like), using

from_pretrained -> save_pretrained

to save a sharded version, and then again

from_pretrained

which should then load the sharded checkpoint?

Thanks

@sgugger
Collaborator Author

sgugger commented Apr 7, 2022

Yes, that's exactly what you should do!
We'll also create a new branch of the t0pp checkpoint with a sharded checkpoint (we can't do it on the main branch or it would break compatibility with older versions of Transformers).

@loretoparisi
Contributor

Hahaha amazing! Where is it?

@julien-c
Member

julien-c commented Apr 8, 2022

@sgugger we could possibly upload the sharded checkpoint in addition to the current checkpoint in the same repo branch, no?

I thought that's what we wanted to do to preserve backward compatibility while still upgrading those big models.

@sgugger
Collaborator Author

sgugger commented Apr 8, 2022

If you put them in the same branch, from_pretrained will only download the full checkpoint as it's the first in order of priority.

@julien-c
Member

julien-c commented Apr 8, 2022

I see. Should we consider changing this behavior? We've seen that git branches are not super practical for those large models (cc @osanseviero)

@julien-c
Member

julien-c commented Apr 8, 2022

(Yep, I know this is going to require calling something like /api/models/xxx at the start of from_pretrained... :) )

@julien-c
Member

julien-c commented Apr 8, 2022

Or maybe we can say that we'll just push sharded models from now on (for example the bigscience models will be sharded-only). I think that's actually fine.

@thomwolf
Member

thomwolf commented Apr 8, 2022

Or we can default to the sharded model for the new version of transformers?

@julien-c
Member

julien-c commented Apr 8, 2022

Yep, that's what I was suggesting, but implementation-wise it's a bit more complex (and it affects all model repos, not just the sharded ones).

weight_map = {}
shards = {}
for idx, shard in enumerate(sharded_state_dicts):
shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
Member

FYI, in datasets we start at 0, not 1. Same as TF, Dask, and Apache Beam.
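Concretely, the quoted naming line yields filenames like the following (1-indexed and zero-padded to five digits, as merged):

```python
WEIGHTS_NAME = "pytorch_model.bin"


def shard_filename(idx, num_shards):
    # Mirrors the quoted replace(): shard indices are 1-based, padded to 5 digits.
    return WEIGHTS_NAME.replace(".bin", f"-{idx + 1:05d}-of-{num_shards:05d}.bin")


print([shard_filename(i, 2) for i in range(2)])
# → ['pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin']
```

The zero-padding keeps shard files lexicographically sorted in directory listings, which the 0-based schemes mentioned above share as well.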
