diff --git a/bindings/python/convert.py b/bindings/python/convert.py
index f537df28..5813abd9 100644
--- a/bindings/python/convert.py
+++ b/bindings/python/convert.py
@@ -161,8 +161,11 @@ def check_final_model(model_id: str, folder: str):
         shutil.copy(config, os.path.join(folder, "config.json"))
     config = AutoConfig.from_pretrained(folder)
 
-    _, (pt_model, pt_infos) = infer_framework_load_model(model_id, config, output_loading_info=True)
-    _, (sf_model, sf_infos) = infer_framework_load_model(folder, config, output_loading_info=True)
+    import transformers
+
+    class_ = getattr(transformers, config.architectures[0])
+    (pt_model, pt_infos) = class_.from_pretrained(folder, output_loading_info=True)
+    (sf_model, sf_infos) = class_.from_pretrained(folder, output_loading_info=True)
 
     if pt_infos != sf_infos:
         error_string = create_diff(pt_infos, sf_infos)
@@ -199,7 +202,19 @@ def check_final_model(model_id: str, folder: str):
         sf_model = sf_model.cuda()
         kwargs = {k: v.cuda() for k, v in kwargs.items()}
 
-    pt_logits = pt_model(**kwargs)[0]
+    try:
+        pt_logits = pt_model(**kwargs)[0]
+    except Exception as e:
+        try:
+            # Musicgen special exception.
+            decoder_input_ids = torch.ones((input_ids.shape[0] * pt_model.decoder.num_codebooks, 1), dtype=torch.long)
+            if torch.cuda.is_available():
+                decoder_input_ids = decoder_input_ids.cuda()
+
+            kwargs["decoder_input_ids"] = decoder_input_ids
+            pt_logits = pt_model(**kwargs)[0]
+        except Exception:
+            raise e
     sf_logits = sf_model(**kwargs)[0]
 
     torch.testing.assert_close(sf_logits, pt_logits)
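
For reference, here is a minimal standalone sketch of the two ideas the patch relies on: resolving the concrete model class from `config.architectures[0]` and loading with `output_loading_info=True`, and retrying the forward pass with dummy `decoder_input_ids` for models such as Musicgen that require explicit decoder inputs. The helper names (`load_via_architecture`, `forward_logits`), the way `kwargs` is built from a tokenizer, and the example folder path are illustrative assumptions, not part of the patch.

```python
import torch
import transformers
from transformers import AutoConfig


def load_via_architecture(path: str):
    """Resolve the model class named in config.architectures[0] and load it,
    returning the model plus the loading-info dict (missing/unexpected keys)
    so two checkpoints of the same architecture can be compared."""
    config = AutoConfig.from_pretrained(path)
    class_ = getattr(transformers, config.architectures[0])
    model, infos = class_.from_pretrained(path, output_loading_info=True)
    return model, infos


def forward_logits(model, kwargs):
    """Run a forward pass; if it fails because the model needs explicit
    decoder inputs (as Musicgen does), retry with dummy decoder_input_ids
    of shape (batch_size * num_codebooks, 1), mirroring the patch above."""
    try:
        return model(**kwargs)[0]
    except Exception as e:
        try:
            input_ids = kwargs["input_ids"]
            decoder_input_ids = torch.ones(
                (input_ids.shape[0] * model.decoder.num_codebooks, 1),
                dtype=torch.long,
                device=input_ids.device,
            )
            kwargs = {**kwargs, "decoder_input_ids": decoder_input_ids}
            return model(**kwargs)[0]
        except Exception:
            raise e


if __name__ == "__main__":
    # Hypothetical local folder holding a converted checkpoint plus config.json.
    folder = "path/to/converted-checkpoint"
    model, infos = load_via_architecture(folder)
    # Building kwargs with a tokenizer is a simplification; convert.py derives
    # the inputs from the model's pipeline task instead.
    tokenizer = transformers.AutoTokenizer.from_pretrained(folder)
    kwargs = dict(tokenizer("Hello, world", return_tensors="pt"))
    logits = forward_logits(model, kwargs)
    print(logits.shape, infos["missing_keys"])
```

The nested try/except preserves the original failure: if supplying `decoder_input_ids` does not help, the first exception is re-raised rather than the second, which matches the control flow in the diff.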