Commit d1c4a61

rename resolved_archive_file to resolved_model_file

SunMarc committed Feb 14, 2025
1 parent abd3a91 commit d1c4a61
Showing 1 changed file with 17 additions and 17 deletions.
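
The change is a pure rename inside the model-loading path; the public `from_pretrained` API is untouched. As a quick illustration (the repo id below is just an example, not part of this commit), loading any model still works as before:

import torch
from diffusers import UNet2DConditionModel

# Any ModelMixin subclass exercises the renamed code path in modeling_utils.py.
# The checkpoint id is illustrative.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
)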
src/diffusers/models/modeling_utils.py (34 changes: 17 additions & 17 deletions)
@@ -1029,7 +1029,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         keep_in_fp32_modules = []

         is_sharded = False
-        resolved_archive_file = None
+        resolved_model_file = None

         # Determine if we're loading from a directory of sharded checkpoints.
         sharded_metadata = None
@@ -1064,7 +1064,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         # load model
         if from_flax:
-            resolved_archive_file = _get_model_file(
+            resolved_model_file = _get_model_file(
                 pretrained_model_name_or_path,
                 weights_name=FLAX_WEIGHTS_NAME,
                 cache_dir=cache_dir,
@@ -1082,11 +1082,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             # Convert the weights
             from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

-            model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
+            model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
         else:
             # in the case it is sharded, we have already the index
             if is_sharded:
-                resolved_archive_file, sharded_metadata = _get_checkpoint_shard_files(
+                resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
                     pretrained_model_name_or_path,
                     index_file,
                     cache_dir=cache_dir,
@@ -1100,7 +1100,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 )
             elif use_safetensors:
                 try:
-                    resolved_archive_file = _get_model_file(
+                    resolved_model_file = _get_model_file(
                         pretrained_model_name_or_path,
                         weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                         cache_dir=cache_dir,
@@ -1123,8 +1123,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                         "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
                     )

-            if resolved_archive_file is None and not is_sharded:
-                resolved_archive_file = _get_model_file(
+            if resolved_model_file is None and not is_sharded:
+                resolved_model_file = _get_model_file(
                     pretrained_model_name_or_path,
                     weights_name=_add_variant(WEIGHTS_NAME, variant),
                     cache_dir=cache_dir,
@@ -1139,8 +1139,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     dduf_entries=dduf_entries,
                 )

-        if not isinstance(resolved_archive_file, list):
-            resolved_archive_file = [resolved_archive_file]
+        if not isinstance(resolved_model_file, list):
+            resolved_model_file = [resolved_model_file]

         # set dtype to instantiate the model under:
         # 1. If torch_dtype is not None, we use that dtype
@@ -1168,7 +1168,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         if not is_sharded:
             # Time to load the checkpoint
             state_dict = load_state_dict(
-                resolved_archive_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries
+                resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries
             )
             # We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
             model._fix_state_dict_keys_on_load(state_dict)
@@ -1200,7 +1200,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         ) = cls._load_pretrained_model(
             model,
             state_dict,
-            resolved_archive_file,
+            resolved_model_file,
             pretrained_model_name_or_path,
             loaded_keys,
             ignore_mismatched_sizes=ignore_mismatched_sizes,
@@ -1361,7 +1361,7 @@ def _load_pretrained_model(
         cls,
         model,
         state_dict: OrderedDict,
-        resolved_archive_file: List[str],
+        resolved_model_file: List[str],
         pretrained_model_name_or_path: Union[str, os.PathLike],
         loaded_keys: List[str],
         ignore_mismatched_sizes: bool = False,
@@ -1415,13 +1415,13 @@ def _load_pretrained_model(

         if state_dict is not None:
             # load_state_dict will manage the case where we pass a dict instead of a file
-            # if state dict is not None, it means that we don't need to read the files from resolved_archive_file also
-            resolved_archive_file = [state_dict]
+            # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
+            resolved_model_file = [state_dict]

-        if len(resolved_archive_file) > 1:
-            resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
+        if len(resolved_model_file) > 1:
+            resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")

-        for shard_file in resolved_archive_file:
+        for shard_file in resolved_model_file:
             state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)

         def _find_mismatched_keys(
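For context, the renamed variable holds either a single checkpoint path or a list of shard paths; the hunks at lines 1139 and 1415 normalize the single-file case into a list so both cases share one loading loop. A minimal standalone sketch of that pattern (the function name is assumed for illustration, not the diffusers implementation):

from typing import List, Union

def normalize_model_files(resolved_model_file: Union[str, List[str]]) -> List[str]:
    # Wrap a single checkpoint path in a list so sharded and unsharded
    # checkpoints can be consumed by the same loop.
    if not isinstance(resolved_model_file, list):
        resolved_model_file = [resolved_model_file]
    return resolved_model_file

for shard_file in normalize_model_files("model.safetensors"):
    print(f"would load state dict from {shard_file}")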
