Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#282 from JunnYu/fix_ppdiffusers_from_…
Browse files Browse the repository at this point in the history
…pretrained

[PPDiffusers] update from_pretrained patch
  • Loading branch information
LokeZhou authored Nov 7, 2023
2 parents 484bfaa + a73a5ab commit bd22cf8
Showing 1 changed file with 7 additions and 21 deletions.
28 changes: 7 additions & 21 deletions ppdiffusers/ppdiffusers/patches/ppnlp_patch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,17 +1122,7 @@ def _find_mismatched_keys(
raw_save_pretrained = PretrainedModel.save_pretrained

@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path,
*args,
from_hf_hub=False,
subfolder=None,
paddle_dtype=None,
from_diffusers=None,
variant=None,
**kwargs
):
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
try:
if cls.constructed_from_pretrained_config() and (
hasattr(cls, "smart_convert") or hasattr(cls, "register_load_torch_hook")
Expand All @@ -1141,26 +1131,22 @@ def from_pretrained(
cls,
pretrained_model_name_or_path,
*args,
from_hf_hub=from_hf_hub,
subfolder=subfolder,
paddle_dtype=paddle_dtype,
from_diffusers=from_diffusers,
variant=variant,
**kwargs,
)
except Exception:
pass

dtype = kwargs.pop("dtype", paddle_dtype)
# pop `from_diffusers`
kwargs.pop("from_diffusers", None)
# pop `paddle_dtype`
dtype = kwargs.pop("dtype", kwargs.pop("paddle_dtype", None))
if isinstance(dtype, paddle.dtype):
dtype = str(dtype).replace("paddle.", "")
if dtype is not None:
kwargs["dtype"] = dtype
return raw_from_pretrained(
cls,
pretrained_model_name_or_path,
*args,
from_hf_hub=from_hf_hub,
subfolder=subfolder,
dtype=dtype,
**kwargs,
)

Expand Down

0 comments on commit bd22cf8

Please sign in to comment.