 
 import inspect
 import itertools
+import json
 import os
 import re
 from collections import OrderedDict
 
 import safetensors
 import torch
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, split_torch_state_dict_into_shards
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
 from .. import __version__
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     _add_variant,
+    _get_checkpoint_shard_files,
     _get_model_file,
     deprecate,
     is_accelerate_available,
 )
 from .model_loading_utils import (
     _determine_device_map,
+    _fetch_index_file,
     _load_state_dict_into_model,
     load_model_dict_into_meta,
     load_state_dict,
 
 logger = logging.get_logger(__name__)
 
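+# Matches the stem of a sharded weight file once its extension is stripped,
+# e.g. "diffusion_pytorch_model-00001-of-00005".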
+_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
 
 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
@@ -263,6 +270,7 @@ def save_pretrained(
         save_function: Optional[Callable] = None,
         safe_serialization: bool = True,
         variant: Optional[str] = None,
+        max_shard_size: Union[int, str] = "5GB",
         push_to_hub: bool = False,
         **kwargs,
     ):
@@ -285,6 +293,10 @@ def save_pretrained(
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            max_shard_size (`int` or `str`, defaults to `"5GB"`):
+                The maximum size for a checkpoint before being sharded. Each shard will then be smaller than
+                this size. If expressed as a string, it needs to be digits followed by a unit (like `"5GB"`).
+                If expressed as an integer, the unit is bytes.
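+                For example, passing `max_shard_size="2GB"` keeps every saved shard file below 2GB.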
             push_to_hub (`bool`, *optional*, defaults to `False`):
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
@@ -296,6 +308,14 @@ def save_pretrained(
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        weights_name = _add_variant(weights_name, variant)
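+        # Derive a shard filename pattern from the (possibly variant-suffixed) weights name, e.g.
+        # "diffusion_pytorch_model.safetensors" -> "diffusion_pytorch_model{suffix}.safetensors".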
+        weight_name_split = weights_name.split(".")
+        if len(weight_name_split) in [2, 3]:
+            weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
+        else:
+            raise ValueError(f"Invalid {weights_name} provided.")
+
         os.makedirs(save_directory, exist_ok=True)
 
         if push_to_hub:
@@ -317,18 +337,58 @@ def save_pretrained(
         # Save the model
         state_dict = model_to_save.state_dict()
 
-        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
-        weights_name = _add_variant(weights_name, variant)
-
         # Save the model
-        if safe_serialization:
-            safetensors.torch.save_file(
-                state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
+        )
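+        # split_torch_state_dict_into_shards only plans the shard layout (which tensor lands in
+        # which file); nothing is written to disk until the loop below.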
+
+        # Clean the folder from a previous save
+        if is_main_process:
+            for filename in os.listdir(save_directory):
+                if filename in state_dict_split.filename_to_tensors.keys():
+                    continue
+                full_filename = os.path.join(save_directory, filename)
+                if not os.path.isfile(full_filename):
+                    continue
+                weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
+                weights_without_ext = weights_without_ext.replace("{suffix}", "")
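+                # e.g. the pattern "diffusion_pytorch_model{suffix}.safetensors" reduces to the
+                # bare stem "diffusion_pytorch_model"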
+                filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
+                # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+                if (
+                    filename.startswith(weights_without_ext)
+                    and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
+                ):
+                    os.remove(full_filename)
+
+        for filename, tensors in state_dict_split.filename_to_tensors.items():
+            shard = {tensor: state_dict[tensor] for tensor in tensors}
+            filepath = os.path.join(save_directory, filename)
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other
+                # distributed joyfulness), but for now this is enough.
+                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+            else:
+                torch.save(shard, filepath)
+
+        if state_dict_split.is_sharded:
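+            # The index maps every tensor name to the shard file that stores it, e.g.
+            # {"weight_map": {"conv_in.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", ...}}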
+            index = {
+                "metadata": state_dict_split.metadata,
+                "weight_map": state_dict_split.tensor_to_filename,
+            }
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split into {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each "
+                f"parameter has been saved in the index located at {save_index_file}."
             )
         else:
-            torch.save(state_dict, Path(save_directory, weights_name).as_posix())
-
-        logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")
+            path_to_weights = os.path.join(save_directory, weights_name)
+            logger.info(f"Model weights saved in {path_to_weights}")
 
         if push_to_hub:
             # Create a new empty model card and eventually tag it
@@ -566,6 +626,32 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
             **kwargs,
         )
 
+        # Determine if we're loading from a directory of sharded checkpoints.
+        is_sharded = False
+        index_file = None
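+        # The presence of a weights index file (found locally or on the Hub) is what marks a
+        # checkpoint as sharded.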
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        index_file = _fetch_index_file(
+            is_local=is_local,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            subfolder=subfolder or "",
+            use_safetensors=use_safetensors,
+            cache_dir=cache_dir,
+            variant=variant,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            user_agent=user_agent,
+            commit_hash=commit_hash,
+        )
+        if index_file is not None and index_file.is_file():
+            is_sharded = True
+
+        if is_sharded and from_flax:
+            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
+
         # load model
         model_file = None
         if from_flax:
@@ -590,7 +676,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
 
             model = load_flax_checkpoint_in_pytorch_model(model, model_file)
         else:
-            if use_safetensors:
+            if is_sharded:
+                sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
+                    pretrained_model_name_or_path,
+                    index_file,
+                    cache_dir=cache_dir,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder or "",
+                )
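+                # sharded_ckpt_cached_folder is the local directory holding the downloaded shards;
+                # it is handed to accelerate.load_checkpoint_and_dispatch below in place of a single file.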
+
+            elif use_safetensors and not is_sharded:
                 try:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
@@ -606,11 +706,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
                         user_agent=user_agent,
                         commit_hash=commit_hash,
                     )
+
                 except IOError as e:
+                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                     if not allow_pickle:
-                        raise e
-                    pass
-            if model_file is None:
+                        raise
+                    logger.warning(
+                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                    )
+
+            if model_file is None and not is_sharded:
                 model_file = _get_model_file(
                     pretrained_model_name_or_path,
                     weights_name=_add_variant(WEIGHTS_NAME, variant),
@@ -632,7 +737,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
                 model = cls.from_config(config, **unused_kwargs)
 
             # if device_map is None, load the state dict and move the params from meta device to the cpu
-            if device_map is None:
+            if device_map is None and not is_sharded:
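+                # sharded checkpoints skip this path and are dispatched via accelerate further below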
                 param_device = "cpu"
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)
@@ -670,7 +775,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
                 try:
                     accelerate.load_checkpoint_and_dispatch(
                         model,
-                        model_file,
+                        model_file if not is_sharded else sharded_ckpt_cached_folder,
                         device_map,
                         max_memory=max_memory,
                         offload_folder=offload_folder,