From acdf28faae59b6dffa2bc1e649abb094d6edcec4 Mon Sep 17 00:00:00 2001 From: David Xue Date: Wed, 1 May 2024 14:50:55 -0400 Subject: [PATCH 1/3] add safetensors to model not found error for default use_safetensors=None case --- src/transformers/modeling_utils.py | 8 ++++---- tests/test_modeling_utils.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1ed8040f88c5..6f44ec9c3818 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3263,8 +3263,8 @@ def from_pretrained( ) else: raise EnvironmentError( - f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}," - f" {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory" + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory" f" {pretrained_model_name_or_path}." ) elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): @@ -3410,8 +3410,8 @@ def from_pretrained( else: raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or" - f" {FLAX_WEIGHTS_NAME}." + f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." ) except EnvironmentError: # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 16d8e9e1293d..8696f437b76f 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1000,6 +1000,22 @@ def test_use_safetensors(self): all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*")) self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files)) self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files)) + + # test no model file found when use_safetensors=None (default when safetensors package available) + with self.assertRaises(OSError) as missing_model_file_error: + BertModel.from_pretrained("hf-internal-testing/config-no-model") + + self.assertTrue("does not appear to have a file named pytorch_model.bin, model.safetensors," in str(missing_model_file_error.exception)) + + with self.assertRaises(OSError) as missing_model_file_error: + with tempfile.TemporaryDirectory() as tmp_dir: + with open(os.path.join(tmp_dir, "config.json"), "w") as f: + f.write("{}") + f.close() + BertModel.from_pretrained(tmp_dir) + + self.assertTrue("Error no file named pytorch_model.bin, model.safetensors", str(missing_model_file_error.exception)) + @require_safetensors def test_safetensors_save_and_load(self): From c03da57711b2581ea6b02ff850bf99ae0e7575ea Mon Sep 17 00:00:00 2001 From: David Xue Date: Wed, 1 May 2024 16:02:16 -0400 Subject: [PATCH 2/3] format code w/ ruff --- tests/test_modeling_utils.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 8696f437b76f..8e02b62c5c71 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1000,13 +1000,16 @@ def test_use_safetensors(self): all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*")) self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files)) self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files)) - + # test no model file found when use_safetensors=None (default when safetensors package available) with self.assertRaises(OSError) as missing_model_file_error: BertModel.from_pretrained("hf-internal-testing/config-no-model") - - self.assertTrue("does not appear to have a file named pytorch_model.bin, model.safetensors," in str(missing_model_file_error.exception)) - + + self.assertTrue( + "does not appear to have a file named pytorch_model.bin, model.safetensors," + in str(missing_model_file_error.exception) + ) + with self.assertRaises(OSError) as missing_model_file_error: with tempfile.TemporaryDirectory() as tmp_dir: with open(os.path.join(tmp_dir, "config.json"), "w") as f: @@ -1014,8 +1017,9 @@ def test_use_safetensors(self): f.close() BertModel.from_pretrained(tmp_dir) - self.assertTrue("Error no file named pytorch_model.bin, model.safetensors", str(missing_model_file_error.exception)) - + self.assertTrue( + "Error no file named pytorch_model.bin, model.safetensors", str(missing_model_file_error.exception) + ) @require_safetensors def test_safetensors_save_and_load(self): From 81ad2cb92808a1546eb02045cfb411f6d4d662c1 Mon Sep 17 00:00:00 2001 From: David Xue Date: Wed, 1 May 2024 16:43:08 -0400 Subject: [PATCH 3/3] fix assert true typo --- tests/test_modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 8e02b62c5c71..f98e1a2a2391 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1018,7 +1018,7 @@ def test_use_safetensors(self): BertModel.from_pretrained(tmp_dir) self.assertTrue( - "Error no file named pytorch_model.bin, model.safetensors", str(missing_model_file_error.exception) + "Error no file named pytorch_model.bin, model.safetensors" in str(missing_model_file_error.exception) ) @require_safetensors