diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 30fce977797e..2694a1ad8367 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1029,7 +1029,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         keep_in_fp32_modules = []

         is_sharded = False
-        resolved_archive_file = None
+        resolved_model_file = None

         # Determine if we're loading from a directory of sharded checkpoints.
         sharded_metadata = None
@@ -1064,7 +1064,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         # load model
         if from_flax:
-            resolved_archive_file = _get_model_file(
+            resolved_model_file = _get_model_file(
                 pretrained_model_name_or_path,
                 weights_name=FLAX_WEIGHTS_NAME,
                 cache_dir=cache_dir,
@@ -1082,11 +1082,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             # Convert the weights
             from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

-            model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
+            model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
         else:
             # in the case it is sharded, we have already the index
             if is_sharded:
-                resolved_archive_file, sharded_metadata = _get_checkpoint_shard_files(
+                resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
                     pretrained_model_name_or_path,
                     index_file,
                     cache_dir=cache_dir,
@@ -1100,7 +1100,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 )
             elif use_safetensors:
                 try:
-                    resolved_archive_file = _get_model_file(
+                    resolved_model_file = _get_model_file(
                         pretrained_model_name_or_path,
                         weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                         cache_dir=cache_dir,
@@ -1123,8 +1123,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                         "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
                     )

-            if resolved_archive_file is None and not is_sharded:
-                resolved_archive_file = _get_model_file(
+            if resolved_model_file is None and not is_sharded:
+                resolved_model_file = _get_model_file(
                     pretrained_model_name_or_path,
                     weights_name=_add_variant(WEIGHTS_NAME, variant),
                     cache_dir=cache_dir,
@@ -1139,8 +1139,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     dduf_entries=dduf_entries,
                 )

-        if not isinstance(resolved_archive_file, list):
-            resolved_archive_file = [resolved_archive_file]
+        if not isinstance(resolved_model_file, list):
+            resolved_model_file = [resolved_model_file]

         # set dtype to instantiate the model under:
         # 1. If torch_dtype is not None, we use that dtype
@@ -1168,7 +1168,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             if not is_sharded:
                 # Time to load the checkpoint
                 state_dict = load_state_dict(
-                    resolved_archive_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries
+                    resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries
                 )
                 # We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
                 model._fix_state_dict_keys_on_load(state_dict)
@@ -1200,7 +1200,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             ) = cls._load_pretrained_model(
                 model,
                 state_dict,
-                resolved_archive_file,
+                resolved_model_file,
                 pretrained_model_name_or_path,
                 loaded_keys,
                 ignore_mismatched_sizes=ignore_mismatched_sizes,
@@ -1361,7 +1361,7 @@ def _load_pretrained_model(
         cls,
         model,
         state_dict: OrderedDict,
-        resolved_archive_file: List[str],
+        resolved_model_file: List[str],
         pretrained_model_name_or_path: Union[str, os.PathLike],
         loaded_keys: List[str],
         ignore_mismatched_sizes: bool = False,
@@ -1415,13 +1415,13 @@ def _load_pretrained_model(

         if state_dict is not None:
             # load_state_dict will manage the case where we pass a dict instead of a file
-            # if state dict is not None, it means that we don't need to read the files from resolved_archive_file also
-            resolved_archive_file = [state_dict]
+            # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
+            resolved_model_file = [state_dict]

-        if len(resolved_archive_file) > 1:
-            resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
+        if len(resolved_model_file) > 1:
+            resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")

-        for shard_file in resolved_archive_file:
+        for shard_file in resolved_model_file:
             state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)

         def _find_mismatched_keys(
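
For reference, a minimal standalone sketch of the loading pattern this rename touches: `resolved_model_file` may hold a single checkpoint path, a list of shard paths, or (inside `_load_pretrained_model`) an already-loaded state dict, and all three are normalized to a list before the shard loop. The helper names below are illustrative stand-ins, not the diffusers API; the real `load_state_dict` also handles safetensors, mmap control, and DDUF entries.

from typing import Dict, Iterator, List, Union

import torch


def _load_state_dict(source: Union[str, Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # Simplified stand-in for diffusers' load_state_dict: pass an in-memory
    # state dict through unchanged, otherwise read the checkpoint from disk.
    if isinstance(source, dict):
        return source
    return torch.load(source, map_location="cpu")


def iter_shard_state_dicts(
    resolved_model_file: Union[str, List[str], Dict[str, torch.Tensor]],
) -> Iterator[Dict[str, torch.Tensor]]:
    # Mirror the normalization in from_pretrained / _load_pretrained_model:
    # a single file or an in-memory state dict becomes a one-element list so
    # sharded and unsharded checkpoints share the same loading loop.
    if not isinstance(resolved_model_file, list):
        resolved_model_file = [resolved_model_file]
    for shard_file in resolved_model_file:
        yield _load_state_dict(shard_file)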