From 551962700ad48e532587a8b1ad32ca820ddf2356 Mon Sep 17 00:00:00 2001
From: Lee Miller
Date: Sun, 10 Nov 2024 13:53:08 -0700
Subject: [PATCH 1/3] reworking for simpler caching logic

---
 tests/test_wordllama.py | 507 +++++++++++++++++++++++++++++++---------
 wordllama/wordllama.py  | 378 +++++++++++++++----------------
 2 files changed, 584 insertions(+), 301 deletions(-)

diff --git a/tests/test_wordllama.py b/tests/test_wordllama.py
index 8329e2d..985185d 100644
--- a/tests/test_wordllama.py
+++ b/tests/test_wordllama.py
@@ -1,5 +1,5 @@
 import unittest
-from unittest.mock import patch, MagicMock, mock_open, create_autospec
+from unittest.mock import patch, MagicMock, mock_open, call
 from pathlib import Path
 import numpy as np
 from tokenizers import Tokenizer
@@ -15,18 +15,19 @@

 class TestWordLlama(unittest.TestCase):
-
     def setUp(self):
         self.config_name = "l2_supercat"
         self.dim = 256
         self.binary = False
         self.trunc_dim = None

-        tokenizer_inference_config = create_autospec(TokenizerInferenceConfig)
+        # Mock TokenizerInferenceConfig
+        tokenizer_inference_config = MagicMock(spec=TokenizerInferenceConfig)
         tokenizer_inference_config.use_local_config = True
         tokenizer_inference_config.config_filename = "tokenizer_config.json"

-        model_config = create_autospec(WordLlamaModel)
+        # Mock WordLlamaModel
+        model_config = MagicMock(spec=WordLlamaModel)
         model_config.n_vocab = 32000
         model_config.dim = 4096
         model_config.n_layers = 12
@@ -34,7 +35,8 @@ def setUp(self):
         model_config.hf_model_id = "dummy-model-id"
         model_config.pad_token = 0

-        tokenizer_config = create_autospec(TokenizerConfig)
+        # Mock TokenizerConfig
+        tokenizer_config = MagicMock(spec=TokenizerConfig)
         tokenizer_config.inference = tokenizer_inference_config
         tokenizer_config.return_tensors = True
         tokenizer_config.return_attention_mask = True
@@ -43,14 +45,17 @@ def setUp(self):
         tokenizer_config.truncation = True
         tokenizer_config.add_special_tokens = True

-        training_config = create_autospec(TrainingConfig)
+        # Mock TrainingConfig
+        training_config = MagicMock(spec=TrainingConfig)
         training_config.learning_rate = 0.001
         training_config.batch_size = 32
         training_config.epochs = 10

-        matryoshka_config = create_autospec(MatryoshkaConfig)
+        # Mock MatryoshkaConfig
+        matryoshka_config = MagicMock(spec=MatryoshkaConfig)
         matryoshka_config.dims = [64, 128, 256, 512, 1024]

+        # Assemble WordLlamaConfig
         self.config = WordLlamaConfig(
             model=model_config,
             tokenizer=tokenizer_config,
@@ -58,140 +63,416 @@ def setUp(self):
             matryoshka=matryoshka_config,
         )

-    @patch("wordllama.wordllama.Path.open", new_callable=mock_open)
-    @patch("wordllama.wordllama.Path.exists", autospec=True)
-    @patch("wordllama.wordllama.Path.mkdir", autospec=True)
     @patch("wordllama.wordllama.requests.get", autospec=True)
-    def test_download_file_from_hf(
-        self, mock_get, mock_mkdir, mock_exists, mock_file_open
+    @patch("wordllama.wordllama.Path.mkdir", autospec=True)
+    @patch("wordllama.wordllama.Path.exists", autospec=True)
+    @patch("wordllama.wordllama.Path.open", new_callable=mock_open)
+    def test_resolve_file_downloads_if_not_found(
+        self, mock_file_open, mock_exists, mock_mkdir, mock_get
     ):
-        mock_exists.return_value = False
+        """
+        Test that resolve_file downloads the file from Hugging Face
+        when it does not exist in project root or cache.
+ """ + # Setup mocks + # First, project_root_path.exists() returns False + # Then, cache_path.exists() returns False + # Therefore, it should attempt to download + mock_exists.side_effect = [False, False] + + # Mock the GET request mock_response = MagicMock() mock_response.iter_content.return_value = [b"chunk1", b"chunk2"] mock_response.raise_for_status = MagicMock() mock_get.return_value = mock_response - WordLlama.download_file_from_hf( - repo_id="dummy-repo", - filename="dummy-file", + # Call resolve_file for weights + weights_path = WordLlama.resolve_file( + config_name=self.config_name, + dim=self.dim, + binary=self.binary, + file_type="weights", cache_dir=Path("/dummy/cache"), - force_download=True, - token="dummy-token", + disable_download=False, ) + + # Assert that the file was attempted to be downloaded mock_get.assert_called_once_with( - "https://huggingface.co/dummy-repo/resolve/main/dummy-file", - headers={"Authorization": "Bearer dummy-token"}, + "https://huggingface.co/dleemiller/word-llama-l2-supercat/resolve/main/l2_supercat_256.safetensors", stream=True, ) - mock_file_open.assert_called_once_with("wb") - mock_file_open().write.assert_any_call(b"chunk1") - mock_file_open().write.assert_any_call(b"chunk2") - @patch( - "wordllama.wordllama.WordLlama.download_file_from_hf", - return_value=Path("/dummy/cache/l2_supercat_256.safetensors"), - ) - @patch("wordllama.wordllama.Path.exists", side_effect=[False, True, True]) - def test_check_and_download_model(self, mock_exists, mock_download): - weights_file_path = WordLlama.check_and_download_model( - config_name=self.config_name, - dim=self.dim, - binary=self.binary, - weights_dir=Path("/dummy/weights"), - cache_dir=Path("/dummy/cache"), + # Assert that mkdir was called to create the cache directory + mock_mkdir.assert_called_once_with( + Path("/dummy/cache/weights"), parents=True, exist_ok=True ) + + # Assert that the file was written to cache_path + mock_file_open.assert_called_once_with("wb") + handle = mock_file_open() + handle.write.assert_has_calls([call(b"chunk1"), call(b"chunk2")]) + + # Assert the returned path is correct self.assertEqual( - weights_file_path, Path("/dummy/cache/weights/l2_supercat_256.safetensors") + weights_path, Path("/dummy/cache/weights/l2_supercat_256.safetensors") ) - @patch( - "wordllama.wordllama.WordLlama.download_file_from_hf", - return_value=Path("/dummy/cache/tokenizers/tokenizer_config.json"), - ) - @patch("wordllama.wordllama.Path.exists", side_effect=[False, True, True]) - def test_check_and_download_tokenizer(self, mock_exists, mock_download): - tokenizer_file_path = WordLlama.check_and_download_tokenizer( - config_name=self.config_name + @patch.object(WordLlama, "resolve_file", autospec=True) + def test_load_with_default_cache_dir(self, mock_resolve_file): + """ + Test that load uses the default cache directory when cache_dir is not provided. 
+ """ + # Setup mock for resolve_file + default_cache_dir = WordLlama.DEFAULT_CACHE_DIR + weights_path = default_cache_dir / "weights" / "l2_supercat_256.safetensors" + tokenizer_path = ( + default_cache_dir / "tokenizers" / "l2_supercat_tokenizer_config.json" ) - self.assertEqual( - tokenizer_file_path, Path("/dummy/cache/tokenizers/tokenizer_config.json") + mock_resolve_file.side_effect = [weights_path, tokenizer_path] + + # Mock tokenizer and weights loading + with patch( + "wordllama.wordllama.WordLlama.load_tokenizer", + return_value=MagicMock(spec=Tokenizer), + ) as mock_load_tokenizer, patch( + "wordllama.wordllama.safe_open", autospec=True + ) as mock_safe_open: + # Mock the tensor returned by safe_open + mock_tensor = MagicMock() + mock_tensor.__getitem__.return_value = np.random.rand(256, 4096) + mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = ( + mock_tensor + ) + + # Call load without specifying cache_dir + model = WordLlama.load( + config=self.config, + binary=self.binary, + dim=self.dim, + trunc_dim=self.trunc_dim, + cache_dir=default_cache_dir, + ) + + # Assert resolve_file was called twice: once for weights, once for tokenizer + expected_calls = [ + call( + config_name=self.config, + dim=self.dim, + binary=self.binary, + file_type="weights", + cache_dir=default_cache_dir, + disable_download=False, + ), + call( + config_name=self.config, + dim=self.dim, + binary=False, + file_type="tokenizer", + cache_dir=default_cache_dir, + disable_download=False, + ), + ] + mock_resolve_file.assert_has_calls(expected_calls, any_order=False) + self.assertEqual(mock_resolve_file.call_count, 2) + + # Assert load_tokenizer was called with correct path + mock_load_tokenizer.assert_called_once_with( + tokenizer_path, + self.config, + ) + + # Assert safe_open was called with the weights path + mock_safe_open.assert_called_once_with( + weights_path, + framework="np", + device="cpu", + ) + + # Assert the returned model is an instance of WordLlamaInference + self.assertIsInstance(model, WordLlamaInference) + + @patch.object(WordLlama, "resolve_file", autospec=True) + def test_load_with_custom_cache_dir(self, mock_resolve_file): + """ + Test that load correctly handles various custom cache_dir inputs. 
+ """ + # Define different cache_dir inputs + cache_dirs = { + "tilde": "~/tmp_cache", + "relative": "tmp", + "relative_dot": "./tmp", + "absolute": "/tmp/cache_dir", + } + + # Expected resolved paths + expected_resolved_dirs = { + "tilde": Path("~/tmp_cache").expanduser(), + "relative": Path("tmp").resolve(), + "relative_dot": Path("./tmp").resolve(), + "absolute": Path("/tmp/cache_dir"), + } + + for key, cache_dir_input in cache_dirs.items(): + with self.subTest(cache_dir=key): + # Reset mocks + mock_resolve_file.reset_mock() + + # Setup mock for resolve_file + weights_path = ( + expected_resolved_dirs[key] + / "weights" + / "l2_supercat_256.safetensors" + ) + tokenizer_path = ( + expected_resolved_dirs[key] + / "tokenizers" + / "l2_supercat_tokenizer_config.json" + ) + mock_resolve_file.side_effect = [weights_path, tokenizer_path] + + # Mock tokenizer and weights loading + with patch( + "wordllama.wordllama.WordLlama.load_tokenizer", + return_value=MagicMock(spec=Tokenizer), + ) as mock_load_tokenizer, patch( + "wordllama.wordllama.safe_open", autospec=True + ) as mock_safe_open: + # Mock the tensor returned by safe_open + mock_tensor = MagicMock() + mock_tensor.__getitem__.return_value = np.random.rand(256, 4096) + mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = ( + mock_tensor + ) + + # Call load with custom cache_dir + model = WordLlama.load( + config=self.config_name, + cache_dir=cache_dir_input, + binary=self.binary, + dim=self.dim, + trunc_dim=self.trunc_dim, + ) + + # Assert resolve_file was called twice with the correct cache_dir + expected_calls = [ + call( + # WordLlama, + config_name=self.config_name, + dim=self.dim, + binary=self.binary, + file_type="weights", + cache_dir=expected_resolved_dirs[key], + disable_download=False, + ), + call( + # WordLlama, + config_name=self.config_name, + dim=self.dim, + binary=False, + file_type="tokenizer", + cache_dir=expected_resolved_dirs[key], + disable_download=False, + ), + ] + mock_resolve_file.assert_has_calls(expected_calls, any_order=False) + self.assertEqual(mock_resolve_file.call_count, 2) + + # Assert load_tokenizer was called with correct path + mock_load_tokenizer.assert_called_once_with( + tokenizer_path, + self.config, + ) + + # Assert safe_open was called with the weights path + mock_safe_open.assert_called_once_with( + weights_path, + framework="np", + device="cpu", + ) + + # Assert the returned model is an instance of WordLlamaInference + self.assertIsInstance(model, WordLlamaInference) + + @patch.object(WordLlama, "resolve_file", autospec=True) + def test_load_with_disable_download(self, mock_resolve_file): + """ + Test that load raises FileNotFoundError when files are missing and downloads are disabled. 
+ """ + # Setup mocks to simulate files not existing and downloads disabled + mock_resolve_file.side_effect = FileNotFoundError("File not found") + + # Call load with disable_download=True and expect FileNotFoundError + with self.assertRaises(FileNotFoundError): + WordLlama.load( + config=self.config, + cache_dir=Path("/dummy/cache"), + binary=self.binary, + dim=self.dim, + trunc_dim=self.trunc_dim, + disable_download=True, + ) + + # Assert resolve_file was called twice: once for weights, once for tokenizer + expected_calls = [ + call( + config_name=self.config, + dim=self.dim, + binary=self.binary, + file_type="weights", + cache_dir=Path("/dummy/cache"), + disable_download=True, + ) + ] + mock_resolve_file.assert_has_calls(expected_calls, any_order=False) + self.assertEqual(mock_resolve_file.call_count, 1) + + @patch.object(WordLlama, "resolve_file", autospec=True) + def test_load_with_truncated_dimension(self, mock_resolve_file): + """ + Test that load correctly handles trunc_dim parameter. + """ + # Setup mock for resolve_file + weights_path = Path("/dummy/cache/weights/l2_supercat_256.safetensors") + tokenizer_path = Path( + "/dummy/cache/tokenizers/l2_supercat_tokenizer_config.json" ) + mock_resolve_file.side_effect = [weights_path, tokenizer_path] + + # Mock tokenizer and weights loading + with patch( + "wordllama.wordllama.WordLlama.load_tokenizer", + return_value=MagicMock(spec=Tokenizer), + ) as mock_load_tokenizer, patch( + "wordllama.wordllama.safe_open", autospec=True + ) as mock_safe_open: + # Mock the tensor returned by safe_open + mock_tensor = MagicMock() + mock_tensor.__getitem__.return_value = np.random.rand(256, 4096) + mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = ( + mock_tensor + ) + + # Call load with trunc_dim + model = WordLlama.load( + config=self.config, + cache_dir=Path("/dummy/cache"), + binary=self.binary, + dim=self.dim, + trunc_dim=128, + ) + + # Assert resolve_file was called twice + expected_calls = [ + call( + config_name=self.config, + dim=self.dim, + binary=self.binary, + file_type="weights", + cache_dir=Path("/dummy/cache"), + disable_download=False, + ), + call( + config_name=self.config, + dim=self.dim, + binary=False, + file_type="tokenizer", + cache_dir=Path("/dummy/cache"), + disable_download=False, + ), + ] + mock_resolve_file.assert_has_calls(expected_calls, any_order=False) + self.assertEqual(mock_resolve_file.call_count, 2) + + # Assert load_tokenizer was called with correct path + mock_load_tokenizer.assert_called_once_with( + tokenizer_path, + self.config, + ) + + # Assert safe_open was called with the weights path + mock_safe_open.assert_called_once_with( + weights_path, + framework="np", + device="cpu", + ) + + # Assert the returned model is an instance of WordLlamaInference + self.assertIsInstance(model, WordLlamaInference) + + # Assert that the embedding was truncated + mock_tensor.__getitem__.assert_called_with((slice(None), slice(None, 128))) @patch( "wordllama.wordllama.Tokenizer.from_pretrained", return_value=MagicMock(spec=Tokenizer), ) - @patch( - "wordllama.wordllama.tokenizer_from_file", - return_value=MagicMock(spec=Tokenizer), - ) - @patch("wordllama.wordllama.Path.exists", return_value=True) - def test_load_tokenizer( - self, mock_exists, mock_tokenizer_from_file, mock_from_pretrained - ): - tokenizer = WordLlama.load_tokenizer( - Path("/dummy/cache/tokenizers/tokenizer_config.json"), self.config - ) - mock_tokenizer_from_file.assert_called_once_with( - 
Path("/dummy/cache/tokenizers/tokenizer_config.json") + @patch.object(WordLlama, "resolve_file", autospec=True) + def test_load_tokenizer_fallback(self, mock_resolve_file, mock_from_pretrained): + """ + Test that load_tokenizer falls back to Hugging Face if local config is not found. + """ + # Setup mocks + # First call for weights, second call for tokenizer + weights_path = Path("/dummy/cache/weights/l2_supercat_256.safetensors") + tokenizer_path = Path( + "/dummy/cache/tokenizers/l2_supercat_tokenizer_config.json" ) - mock_from_pretrained.assert_not_called() - self.assertIsInstance(tokenizer, Tokenizer) + mock_resolve_file.side_effect = [weights_path, tokenizer_path] - @patch( - "wordllama.wordllama.WordLlama.check_and_download_model", - return_value=Path("/dummy/cache/l2_supercat_256.safetensors"), - ) - @patch( - "wordllama.wordllama.WordLlama.check_and_download_tokenizer", - return_value=Path("/dummy/cache/tokenizers/l2_supercat_tokenizer_config.json"), - ) - @patch( - "wordllama.wordllama.WordLlama.load_tokenizer", - return_value=MagicMock(spec=Tokenizer), - ) - @patch("wordllama.wordllama.safe_open", autospec=True) - def test_load( - self, - mock_safe_open, - mock_load_tokenizer, - mock_check_tokenizer, - mock_check_model, - ): - mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = ( - np.random.rand(256, 4096) - ) + # Simulate tokenizer config does not exist by patching Path.exists + with patch( + "wordllama.wordllama.Path.exists", side_effect=[False, False] + ), patch("wordllama.wordllama.safe_open", autospec=True) as mock_safe_open: + # Mock the tensor returned by safe_open + mock_tensor = MagicMock() + mock_tensor.__getitem__.return_value = np.random.rand(256, 4096) + mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = ( + mock_tensor + ) - model = WordLlama.load( - config=self.config_name, - weights_dir=Path("/dummy/weights"), - cache_dir=Path("/dummy/cache"), - binary=self.binary, - dim=self.dim, - trunc_dim=self.trunc_dim, - ) + # Call load + model = WordLlama.load( + config=self.config, + cache_dir=Path("/dummy/cache"), + binary=self.binary, + dim=self.dim, + trunc_dim=self.trunc_dim, + ) - self.assertIsInstance(model, WordLlamaInference) - mock_check_model.assert_called_once_with( - config_name="l2_supercat", - dim=256, - binary=False, - weights_dir=Path("/dummy/weights"), - cache_dir=Path("/dummy/cache"), - disable_download=False, - ) - mock_check_tokenizer.assert_called_once_with( - config_name="l2_supercat", - cache_dir=Path('/dummy/cache'), - disable_download=False, - ) - mock_load_tokenizer.assert_called_once() - mock_safe_open.assert_called_once_with( - Path("/dummy/cache/l2_supercat_256.safetensors"), - framework="np", - device="cpu", - ) - self.assertIsInstance(model, WordLlamaInference) + # Assert resolve_file was called twice: weights and tokenizer + expected_calls = [ + call( + config_name=self.config, + dim=self.dim, + binary=self.binary, + file_type="weights", + cache_dir=Path("/dummy/cache"), + disable_download=False, + ), + call( + config_name=self.config, + dim=self.dim, + binary=False, + file_type="tokenizer", + cache_dir=Path("/dummy/cache"), + disable_download=False, + ), + ] + mock_resolve_file.assert_has_calls(expected_calls, any_order=False) + self.assertEqual(mock_resolve_file.call_count, 2) + + # Assert Tokenizer.from_pretrained was called since local config was not found + mock_from_pretrained.assert_called_once_with("dummy-model-id") + + # Assert safe_open was called with the weights path + 
+            mock_safe_open.assert_called_once_with(
+                weights_path,
+                framework="np",
+                device="cpu",
+            )
+
+            # Assert the returned model is an instance of WordLlamaInference
+            self.assertIsInstance(model, WordLlamaInference)

 if __name__ == "__main__":

diff --git a/wordllama/wordllama.py b/wordllama/wordllama.py
index 2a941e1..0d3f2de 100644
--- a/wordllama/wordllama.py
+++ b/wordllama/wordllama.py
@@ -9,7 +9,7 @@

 from .inference import WordLlamaInference
 from .config import Config, WordLlamaConfig
-from .tokenizers import tokenizer_from_file
+

 logger = logging.getLogger(__name__)

@@ -23,6 +23,11 @@ class ModelURI:

 class WordLlama:
+    """
+    The WordLlama class is responsible for managing model weights and tokenizer configurations.
+    It handles the resolution of file paths, caching, and downloading from Hugging Face repositories.
+    """
+
     l2_supercat = ModelURI(
         repo_id="dleemiller/word-llama-l2-supercat",
         available_dims=[64, 128, 256, 512, 1024],
@@ -37,6 +42,8 @@ class WordLlama:
         tokenizer_config="l3_supercat_tokenizer_config.json",
     )

+    DEFAULT_CACHE_DIR = Path.home() / ".cache" / "wordllama"
+
     @staticmethod
     def get_filename(config_name: str, dim: int, binary: bool = False) -> str:
         """
@@ -45,268 +52,263 @@ def get_filename(config_name: str, dim: int, binary: bool = False) -> str:
         Args:
             config_name (str): The name of the configuration.
             dim (int): The dimensionality of the model.
-            binary (bool): Whether the file is binary.
+            binary (bool): Indicates whether the file is binary.

         Returns:
-            str: The generated filename.
+            str: The constructed filename.
         """
-        suffix = "" if not binary else "_binary"
+        suffix = "_binary" if binary else ""
         return f"{config_name}_{dim}{suffix}.safetensors"

     @staticmethod
-    def get_cache_dir(
-        is_tokenizer_config: bool = False,
-        base_cache_dir: Optional[Path] = None
-    ) -> Path:
+    def get_tokenizer_filename(config_name: str) -> str:
         """
-        Get the cache directory path for weights or tokenizer configuration.
+        Retrieve the tokenizer configuration filename based on the configuration name.

         Args:
-            is_tokenizer_config (bool, optional): If True, return the tokenizer cache directory.
+            config_name (str): The name of the configuration.

         Returns:
-            Path: The path to the cache directory.
+            str: The tokenizer configuration filename.
         """
-        if base_cache_dir is None:
-            base_cache_dir = Path.home() / ".cache" / "wordllama"
-        return base_cache_dir / ("tokenizers" if is_tokenizer_config else "weights")
+        model_uri = getattr(WordLlama, config_name)
+        return model_uri.tokenizer_config

-    @staticmethod
-    def download_file_from_hf(
-        repo_id: str,
-        filename: str,
+    @classmethod
+    def get_file_path(
+        cls,
+        config_name: str,
+        dim: int,
+        binary: bool,
+        file_type: str,  # 'weights' or 'tokenizer'
         cache_dir: Optional[Path] = None,
-        force_download: bool = False,
-        token: Optional[str] = None,
     ) -> Path:
         """
-        Download a file from a Hugging Face model repository and cache it locally.
+        Determine the directory path for weights or tokenizer files.

         Args:
-            repo_id (str): The repository ID on Hugging Face (e.g., 'user/repo').
-            filename (str): The name of the file to download.
-            cache_dir (Path, optional): The directory to cache the downloaded file. Defaults to the appropriate cache directory.
-            force_download (bool, optional): If True, force download the file even if it exists in the cache.
-            token (str, optional): The Hugging Face token for accessing private repositories.
+            config_name (str): The configuration name.
+            dim (int): The dimensionality of the model.
+            binary (bool): Indicates whether the weights file is binary.
+            file_type (str): Specifies the type of file ('weights' or 'tokenizer').
+            cache_dir (Path, optional): Custom cache directory. Defaults to None.

         Returns:
-            Path: The path to the cached file.
+            Path: The resolved directory path for the specified file type.
         """
-        cache_dir = WordLlama.get_cache_dir(base_cache_dir=cache_dir)
-
-        cache_dir.mkdir(parents=True, exist_ok=True)
-        cached_file_path = cache_dir / filename
-
-        if not force_download and cached_file_path.exists():
-            logger.debug(f"File {filename} exists in cache. Using cached version.")
-            return cached_file_path
-
-        url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
-        headers = {"Authorization": f"Bearer {token}"} if token else {}
-
-        logger.info(f"Downloading {filename} from {url}")
-
-        response = requests.get(url, headers=headers, stream=True)
-        response.raise_for_status()
-
-        with cached_file_path.open("wb") as f:
-            for chunk in response.iter_content(chunk_size=8192):
-                f.write(chunk)
-
-        logger.debug(f"File {filename} downloaded and cached at {cached_file_path}")
-
-        return cached_file_path
+        cache_dir = cache_dir or cls.DEFAULT_CACHE_DIR
+        sub_dir = "tokenizers" if file_type == "tokenizer" else "weights"
+        return cache_dir / sub_dir

     @classmethod
-    def check_and_download_model(
+    def resolve_file(
         cls,
         config_name: str,
         dim: int,
-        binary: bool = False,
-        weights_dir: Optional[Path] = None,
+        binary: bool,
+        file_type: str,  # 'weights' or 'tokenizer'
         cache_dir: Optional[Path] = None,
         disable_download: bool = False,
     ) -> Path:
         """
-        Check if model weights exist locally, if not, download them.
+        Resolve the file path by checking the project root and cache directories.
+        If the file is not found, download it from Hugging Face to the cache directory.

         Args:
-            config_name (str): The name of the model configuration.
-            dim (int): The dimensionality of the model.
-            binary (bool, optional): Whether the file is binary. Defaults to False.
-            weights_dir (Path, optional): Directory where weight files are stored. If None, defaults to 'weights' directory in the current module directory.
-            cache_dir (Path, optional): Directory where cached files are stored. Defaults to the appropriate cache directory.
-            disable_download (bool, optional): Disable downloads for models not in cache.
+            config_name (str): The name of the configuration.
+            dim (int): The dimensionality of the model (irrelevant for tokenizers).
+            binary (bool): Indicates whether the weights file is binary (irrelevant for tokenizers).
+            file_type (str): Specifies the type of file ('weights' or 'tokenizer').
+            cache_dir (Path, optional): Custom cache directory. Defaults to None.
+            disable_download (bool): If True, prevents downloading files not found locally.

         Returns:
-            Path: Path to the model weights file.
-        """
-        if weights_dir is None:
-            weights_dir = Path(__file__).parent / "weights"
+            Path: The resolved file path.

-        cache_dir = WordLlama.get_cache_dir(base_cache_dir=cache_dir)
+        Raises:
+            FileNotFoundError: If the file is not found locally and downloads are disabled.
+            ValueError: If an invalid file_type is provided.
+ """ + if file_type == "weights": + filename = cls.get_filename(config_name, dim, binary) + elif file_type == "tokenizer": + filename = cls.get_tokenizer_filename(config_name) + else: + raise ValueError("file_type must be either 'weights' or 'tokenizer'.") - filename = cls.get_filename(config_name=config_name, dim=dim, binary=binary) - weights_file_path = weights_dir / filename + project_root_path = Path(__file__).parent / "wordllama" / file_type / filename + cache_path = ( + cls.get_file_path(config_name, dim, binary, file_type, cache_dir) / filename + ) - if not weights_file_path.exists(): - logger.debug( - f"Weights file '{filename}' not found in '{weights_dir}'. Checking cache directory..." - ) - weights_file_path = cache_dir / filename - if not weights_file_path.exists(): - if disable_download: - raise FileNotFoundError( - f"Weights file '{filename}' not found and downloads are disabled." - ) - - model_uri = getattr(cls, config_name) - if binary: - assert ( - dim in model_uri.binary_dims - ), f"Dimension must be one of {model_uri.binary_dims}" - else: - assert ( - dim in model_uri.available_dims - ), f"Dimension must be one of {model_uri.available_dims}" + # Check in project root directory + if project_root_path.exists(): + logger.debug(f"Found {file_type} file in project root: {project_root_path}") + return project_root_path - logger.debug( - f"Weights file '{filename}' not found in cache directory '{cache_dir}'. Downloading..." - ) - weights_file_path = cls.download_file_from_hf( - repo_id=model_uri.repo_id, filename=filename, cache_dir=cache_dir - ) + # Check in cache directory + if cache_path.exists(): + logger.debug(f"Found {file_type} file in cache: {cache_path}") + return cache_path - if not weights_file_path.exists(): + if disable_download: raise FileNotFoundError( - f"Weights file '{weights_file_path}' not found in directory '{weights_dir}' or cache '{cache_dir}'." + f"{file_type.capitalize()} file '{filename}' not found in project root or cache, and downloads are disabled." ) - return weights_file_path - - @classmethod - def check_and_download_tokenizer( - cls, - config_name: str, - cache_dir: Optional[Path] = None, - disable_download: bool = False, - ) -> Path: - """ - Check if tokenizer configuration exists locally, if not, download it. - - Args: - config_name (str): The name of the model configuration. - tokenizer_filename (str): The filename of the tokenizer configuration. - cache_dir (Path, optional): Directory where cached files are stored. Defaults to the appropriate cache directory. - disable_download (bool, optional): Disable downloading for models not in cache. - - Returns: - Path: Path to the tokenizer configuration file. - """ + # Download from Hugging Face model_uri = getattr(cls, config_name) + repo_id = model_uri.repo_id + download_dir = cache_path.parent + download_dir.mkdir(parents=True, exist_ok=True) - cache_dir = cls.get_cache_dir(is_tokenizer_config=True, base_cache_dir=cache_dir) - - tokenizer_file_path = cache_dir / model_uri.tokenizer_config - - if not tokenizer_file_path.exists(): - if disable_download: - raise FileNotFoundError( - f"Weights file '{tokenizer_file_path}' not found and downloads are disabled." - ) - - logger.debug( - f"Tokenizer config '{model_uri.tokenizer_config}' not found in cache directory '{cache_dir}'. Downloading..." - ) + url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}" + logger.info( + f"Downloading {file_type} file '{filename}' from Hugging Face repository '{repo_id}'." 
+ ) - tokenizer_file_path = cls.download_file_from_hf( - repo_id=model_uri.repo_id, - filename=model_uri.tokenizer_config, - cache_dir=cache_dir, + try: + response = requests.get(url, stream=True) + response.raise_for_status() + except requests.RequestException as e: + logger.error( + f"Failed to download {file_type} file '{filename}' from '{url}': {e}" ) - - if not tokenizer_file_path.exists(): raise FileNotFoundError( - f"Tokenizer config file '{tokenizer_file_path}' not found in cache '{cache_dir}'." - ) + f"Failed to download {file_type} file '{filename}' from '{url}'." + ) from e + + with cache_path.open("wb") as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: # Filter out keep-alive chunks + f.write(chunk) - return tokenizer_file_path + logger.debug(f"Downloaded {file_type} file and cached at {cache_path}") + return cache_path @classmethod def load( cls, config: Union[str, WordLlamaConfig] = "l2_supercat", - weights_dir: Optional[Path] = None, - cache_dir: Optional[Path] = None, + cache_dir: Optional[Union[Path, str]] = None, binary: bool = False, dim: int = 256, trunc_dim: Optional[int] = None, disable_download: bool = False, ) -> WordLlamaInference: """ - Load the WordLlama model. + Load the WordLlama model by resolving and loading the necessary weights and tokenizer files. + + The order of operations for loading files is as follows: + - **Weights:** + 1. Check in project root (`project_root / wordllama / weights / ...`). + 2. If not found, check in `cache_dir / weights / ...`. + 3. If still not found, download from Hugging Face to `cache_dir / weights / ...`. + - **Tokenizer:** + 1. Check in project root (`project_root / wordllama / tokenizers / ...`). + 2. If not found, check in `cache_dir / tokenizers / ...`. + 3. If still not found, download from Hugging Face to `cache_dir / tokenizers / ...`. Args: - config (Union[str, WordLlamaConfig], optional): The configuration object or the name of the configuration to load. Defaults to "l2_supercat". - weights_dir (Optional[Path], optional): Directory where weight files are stored. If None, defaults to 'weights' directory in the current module directory. Defaults to None. - cache_dir (Optional[Path], optional): Directory where cached files are stored. Defaults to ~/.cache/wordllama/weights. Defaults to None. - binary (bool, optional): Whether to load the binary version of the weights. Defaults to False. - dim (int, optional): The dimensionality of the model to load. Options: [64, 128, 256, 512, 1024]. Defaults to 256. - trunc_dim (Optional[int], optional): The dimension to truncate the model to. Must be less than or equal to 'dim'. Defaults to None. - disable_download(bool, optional): Turn off downloading models from huggingface when local model is not cached. + config (Union[str, WordLlamaConfig], optional): + The configuration name or an instance of WordLlamaConfig to load. + Defaults to "l2_supercat". + cache_dir (Optional[Path], optional): + The directory to use for caching files. + If None, defaults to `~/.cache/wordllama`. + Can be set to a custom path as needed. + binary (bool, optional): + Indicates whether to load the binary version of the weights. + Defaults to False. + dim (int, optional): + The dimensionality of the model to load. + Must be one of the available dimensions specified in the configuration. + Defaults to 256. + trunc_dim (Optional[int], optional): + The dimension to truncate the model to. + Must be less than or equal to 'dim' and one of the available dimensions. + Defaults to None. 
+            disable_download (bool, optional):
+                If True, prevents downloading files from Hugging Face if they are not found locally.
+                Defaults to False.

         Returns:
-            WordLlamaInference: The loaded WordLlama model.
+            WordLlamaInference: An instance of WordLlamaInference containing the loaded model.

         Raises:
-            ValueError: If the configuration is not found or dimensions are invalid.
-            FileNotFoundError: If the weights file is not found in either the weights directory or cache directory.
+            ValueError:
+                - If the provided configuration is invalid or not found.
+                - If the specified dimensions are invalid.
+            FileNotFoundError:
+                - If the required files are not found locally and downloads are disabled.
+                - If downloading fails due to network issues or invalid URLs.
         """
+        # Resolve configuration
         if isinstance(config, str):
-            config_obj = Config._configurations.get(config, None)
+            config_obj = Config._configurations.get(config)
             if config_obj is None:
                 raise ValueError(f"Configuration '{config}' not found.")
+            config_name = config
         elif isinstance(config, WordLlamaConfig):
             config_obj = config
+            # config_name = getattr(config_obj, "name", None)
+            config_name = config
+            if config_name is None:
+                raise ValueError(
+                    "WordLlamaConfig instance must have a 'name' attribute."
+                )
         else:
             raise ValueError(
                 "Invalid configuration type provided. It must be either a string or an instance of WordLlamaConfig."
             )

-        assert (
-            dim in config_obj.matryoshka.dims
-        ), f"Model dimension must be one of matryoshka dims in config: {config_obj.matryoshka.dims}"
+        # Validate dimensions
+        if dim not in config_obj.matryoshka.dims:
+            raise ValueError(
+                f"Model dimension must be one of {config_obj.matryoshka.dims}"
+            )

         if trunc_dim is not None:
-            assert (
-                trunc_dim <= dim
-            ), f"Cannot truncate to dimension lower than model dimension ({trunc_dim} > {dim})"
-            assert trunc_dim in config_obj.matryoshka.dims
-
-        # Check and download model weights
-        weights_file_path = cls.check_and_download_model(
-            config_name=config,
+            if trunc_dim > dim:
+                raise ValueError(
+                    f"Cannot truncate to a higher dimension ({trunc_dim} > {dim})"
+                )
+            if trunc_dim not in config_obj.matryoshka.dims:
+                raise ValueError(
+                    f"Truncated dimension must be one of {config_obj.matryoshka.dims}"
+                )
+
+        if cache_dir and isinstance(cache_dir, str):
+            cache_dir = Path(cache_dir).expanduser()  # Expand ~ to the home dir
+            cache_dir = cache_dir.resolve(strict=False)  # Resolve to absolute path
+
+        # Resolve and load weights
+        weights_file_path = cls.resolve_file(
+            config_name=config_name,
             dim=dim,
             binary=binary,
-            weights_dir=weights_dir,
+            file_type="weights",
             cache_dir=cache_dir,
             disable_download=disable_download,
         )

-        # Check and download tokenizer config if necessary
-        tokenizer_file_path = cls.check_and_download_tokenizer(
-            config_name=config,
+        # Resolve and load tokenizer
+        tokenizer_file_path = cls.resolve_file(
+            config_name=config_name,
+            dim=dim,
+            binary=False,
+            file_type="tokenizer",
             cache_dir=cache_dir,
             disable_download=disable_download,
         )

-        # Load the tokenizer
+        # Load tokenizer
         tokenizer = cls.load_tokenizer(tokenizer_file_path, config_obj)

-        # Load the model weights
+        # Load model weights
         with safe_open(weights_file_path, framework="np", device="cpu") as f:
             embedding = f.get_tensor("embedding.weight")
-            if trunc_dim:  # truncate dimension
-                embedding = embedding[:, 0:trunc_dim]
+            if trunc_dim:
+                embedding = embedding[:, :trunc_dim]

         logger.debug(f"Loading weights from: {weights_file_path}")
         return WordLlamaInference(embedding, config_obj, tokenizer, binary=binary)
@@ -314,30 +316,30 @@ def load(
     @staticmethod
     def load_tokenizer(tokenizer_file_path: Path, config: WordLlamaConfig) -> Tokenizer:
         """
-        Load the tokenizer from a local file or from the Hugging Face repository.
-        First, it checks the default path, then the cache directory.
+        Load the tokenizer from a local configuration file or fallback to Hugging Face.
+
+        The method first attempts to load the tokenizer using the local configuration if specified.
+        If the local configuration is not found or not used, it falls back to loading the tokenizer
+        from the Hugging Face repository.

         Args:
             tokenizer_file_path (Path): The path to the tokenizer configuration file.
-            config (WordLlamaConfig): The configuration for the WordLlama model.
+            config (WordLlamaConfig): The configuration object containing tokenizer settings.

         Returns:
-            Tokenizer: An instance of the Tokenizer class.
+            Tokenizer: An instance of the Tokenizer class initialized with the appropriate configuration.
         """
-        if (
-            config.tokenizer.inference is not None
-            and config.tokenizer.inference.use_local_config
-        ):
-            # Check in the default path
+        if config.tokenizer.inference and config.tokenizer.inference.use_local_config:
             if tokenizer_file_path.exists():
                 logger.debug(
-                    f"Loading tokenizer from default path: {tokenizer_file_path}"
+                    f"Loading tokenizer from local config: {tokenizer_file_path}"
+                )
+                return Tokenizer.from_file(str(tokenizer_file_path))
+            else:
+                warnings.warn(
+                    f"Local tokenizer config not found at {tokenizer_file_path}. Falling back to Hugging Face."
                 )
-                return tokenizer_from_file(tokenizer_file_path)
-
-        warnings.warn(
-            f"Tokenizer config file not found in both default and cache paths. Falling back to Hugging Face model: {config.model.hf_model_id}"
-        )
-        # Load from Hugging Face if local config is not used or not found
+        # Fallback to Hugging Face
+        logger.debug(f"Loading tokenizer from Hugging Face: {config.model.hf_model_id}")
         return Tokenizer.from_pretrained(config.model.hf_model_id)

From d3074f84165092388180a6ede2568d8c7995ef18 Mon Sep 17 00:00:00 2001
From: Lee Miller
Date: Sun, 10 Nov 2024 13:57:34 -0700
Subject: [PATCH 2/3] fixing test

---
 tests/test_wordllama.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_wordllama.py b/tests/test_wordllama.py
index 985185d..c76fb78 100644
--- a/tests/test_wordllama.py
+++ b/tests/test_wordllama.py
@@ -246,7 +246,7 @@ def test_load_with_custom_cache_dir(self, mock_resolve_file):
                     # Call load with custom cache_dir
                     model = WordLlama.load(
-                        config=self.config_name,
+                        config=self.config,
                         cache_dir=cache_dir_input,
                         binary=self.binary,
                         dim=self.dim,
@@ -257,7 +257,7 @@ def test_load_with_custom_cache_dir(self, mock_resolve_file):
                     expected_calls = [
                         call(
                             # WordLlama,
-                            config_name=self.config_name,
+                            config_name=self.config,
                             dim=self.dim,
                             binary=self.binary,
                             file_type="weights",
@@ -266,7 +266,7 @@ def test_load_with_custom_cache_dir(self, mock_resolve_file):
                         ),
                         call(
                             # WordLlama,
-                            config_name=self.config_name,
+                            config_name=self.config,
                             dim=self.dim,
                             binary=False,
                             file_type="tokenizer",

From 97d314d6e43110017b6b24c2a5192e184c2a3fbf Mon Sep 17 00:00:00 2001
From: Lee Miller
Date: Sun, 10 Nov 2024 14:31:16 -0700
Subject: [PATCH 3/3] cleanup and linting

---
 setup.py                        |  1 -
 tests/test_inference.py         |  1 +
 tests/test_vector_similarity.py |  1 -
 tests/test_wordllama.py         | 35 +++++++++++++++++-----------------
 wordllama/config/__init__.py    |  2 ++
 wordllama/wordllama.py          | 26 +++++++-----------------
 6 files changed, 28 insertions(+), 38 deletions(-)

diff --git a/setup.py b/setup.py
index 161beee..63c9be6 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,6 @@
 from setuptools import setup, Extension
 from Cython.Build import cythonize
 import numpy as np
-import platform

 numpy_include = np.get_include()

diff --git a/tests/test_inference.py b/tests/test_inference.py
index 15a654f..5f3cc3e 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -70,6 +70,7 @@ def mock_encode_batch(texts, *args, **kwargs):
         matryoshka_config = MatryoshkaConfig(dims=[1024, 512, 256, 128, 64])

         self.config = WordLlamaConfig(
+            config_name="test",
             model=model_config,
             tokenizer=tokenizer_config,
             training=training_config,

diff --git a/tests/test_vector_similarity.py b/tests/test_vector_similarity.py
index 50d36c9..e066d96 100644
--- a/tests/test_vector_similarity.py
+++ b/tests/test_vector_similarity.py
@@ -1,5 +1,4 @@
 import unittest
-from unittest.mock import patch, MagicMock
 import numpy as np

 from wordllama.algorithms import vector_similarity, binarize_and_packbits

diff --git a/tests/test_wordllama.py b/tests/test_wordllama.py
index c76fb78..e8f040d 100644
--- a/tests/test_wordllama.py
+++ b/tests/test_wordllama.py
@@ -57,6 +57,7 @@ def setUp(self):

         # Assemble WordLlamaConfig
         self.config = WordLlamaConfig(
+            config_name="test",
             model=model_config,
             tokenizer=tokenizer_config,
             training=training_config,
@@ -156,20 +157,20 @@ def test_load_with_default_cache_dir(self, mock_resolve_file):
             # Assert resolve_file was called twice: once for weights, once for tokenizer
             expected_calls = [
                 call(
-                    config_name=self.config,
+                    config_name="test",
                     dim=self.dim,
                     binary=self.binary,
                     file_type="weights",
                     cache_dir=default_cache_dir,
-                    disable_download=False,
+                    disable_download=True,
                 ),
                 call(
-                    config_name=self.config,
+                    config_name="test",
                     dim=self.dim,
                     binary=False,
                     file_type="tokenizer",
                     cache_dir=default_cache_dir,
-                    disable_download=False,
+                    disable_download=True,
                 ),
             ]
             mock_resolve_file.assert_has_calls(expected_calls, any_order=False)
@@ -257,21 +258,21 @@ def test_load_with_custom_cache_dir(self, mock_resolve_file):
                     expected_calls = [
                         call(
                             # WordLlama,
-                            config_name=self.config,
+                            config_name="test",
                             dim=self.dim,
                             binary=self.binary,
                             file_type="weights",
                             cache_dir=expected_resolved_dirs[key],
-                            disable_download=False,
+                            disable_download=True,
                         ),
                         call(
                             # WordLlama,
-                            config_name=self.config,
+                            config_name="test",
                             dim=self.dim,
                             binary=False,
                             file_type="tokenizer",
                             cache_dir=expected_resolved_dirs[key],
-                            disable_download=False,
+                            disable_download=True,
                         ),
                     ]
                     mock_resolve_file.assert_has_calls(expected_calls, any_order=False)
@@ -315,7 +316,7 @@ def test_load_with_disable_download(self, mock_resolve_file):
         # Assert resolve_file was called once (for weights) before the error was raised
         expected_calls = [
             call(
-                config_name=self.config,
+                config_name="test",
                 dim=self.dim,
                 binary=self.binary,
                 file_type="weights",
@@ -364,20 +365,20 @@ def test_load_with_truncated_dimension(self, mock_resolve_file):
             # Assert resolve_file was called twice
             expected_calls = [
                 call(
-                    config_name=self.config,
+                    config_name="test",
                     dim=self.dim,
                     binary=self.binary,
                     file_type="weights",
                     cache_dir=Path("/dummy/cache"),
-                    disable_download=False,
+                    disable_download=True,
                 ),
                 call(
-                    config_name=self.config,
+                    config_name="test",
                     dim=self.dim,
                     binary=False,
                     file_type="tokenizer",
                     cache_dir=Path("/dummy/cache"),
-                    disable_download=False,
+                    disable_download=True,
                 ),
             ]
             mock_resolve_file.assert_has_calls(expected_calls, any_order=False)
@@ -442,20 +443,20 @@ def test_load_tokenizer_fallback(self, mock_resolve_file, mock_from_pretrained):
             # Assert resolve_file was called twice: weights and tokenizer
             expected_calls = [
                 call(
-                    config_name=self.config,
+                    config_name="test",
                     dim=self.dim,
                     binary=self.binary,
                     file_type="weights",
                     cache_dir=Path("/dummy/cache"),
-                    disable_download=False,
+                    disable_download=True,
                 ),
                 call(
-                    config_name=self.config,
+                    config_name="test",
                     dim=self.dim,
                     binary=False,
                     file_type="tokenizer",
                     cache_dir=Path("/dummy/cache"),
-                    disable_download=False,
+                    disable_download=True,
                 ),
             ]
             mock_resolve_file.assert_has_calls(expected_calls, any_order=False)

diff --git a/wordllama/config/__init__.py b/wordllama/config/__init__.py
index b04801d..4d332ae 100644
--- a/wordllama/config/__init__.py
+++ b/wordllama/config/__init__.py
@@ -47,6 +47,7 @@ class WordLlamaModel(BaseModel):

 class WordLlamaConfig(BaseModel):
+    config_name: str
     model: WordLlamaModel
     tokenizer: TokenizerConfig
     training: TrainingConfig
@@ -71,6 +72,7 @@ def load_configurations() -> Dict[str, WordLlamaConfig]:
     for config_file in config_dir.glob("*.toml"):
         config_data = toml.load(config_file)
         config_name = config_file.stem  # Filename without extension
+        config_data["config_name"] = config_name
         configs[config_name] = WordLlamaConfig(**config_data)

     return configs

diff --git a/wordllama/wordllama.py b/wordllama/wordllama.py
index 0d3f2de..0cf956c 100644
--- a/wordllama/wordllama.py
+++ b/wordllama/wordllama.py
@@ -77,9 +77,6 @@ def get_tokenizer_filename(config_name: str) -> str:
     @classmethod
     def get_file_path(
         cls,
-        config_name: str,
-        dim: int,
-        binary: bool,
         file_type: str,  # 'weights' or 'tokenizer'
         cache_dir: Optional[Path] = None,
     ) -> Path:
@@ -87,9 +84,6 @@ def get_file_path(
         Determine the directory path for weights or tokenizer files.

         Args:
-            config_name (str): The configuration name.
-            dim (int): The dimensionality of the model.
-            binary (bool): Indicates whether the weights file is binary.
             file_type (str): Specifies the type of file ('weights' or 'tokenizer').
             cache_dir (Path, optional): Custom cache directory. Defaults to None.

@@ -137,9 +131,7 @@ def resolve_file(
             raise ValueError("file_type must be either 'weights' or 'tokenizer'.")

         project_root_path = Path(__file__).parent / "wordllama" / file_type / filename
-        cache_path = (
-            cls.get_file_path(config_name, dim, binary, file_type, cache_dir) / filename
-        )
+        cache_path = cls.get_file_path(file_type, cache_dir) / filename

         # Check in project root directory
         if project_root_path.exists():
@@ -199,19 +191,18 @@ def load(
         """
         Load the WordLlama model by resolving and loading the necessary weights and tokenizer files.

-        The order of operations for loading files is as follows:
-        - **Weights:**
+        Weights:
             1. Check in project root (`project_root / wordllama / weights / ...`).
             2. If not found, check in `cache_dir / weights / ...`.
             3. If still not found, download from Hugging Face to `cache_dir / weights / ...`.
-        - **Tokenizer:**
+        Tokenizer:
             1. Check in project root (`project_root / wordllama / tokenizers / ...`).
             2. If not found, check in `cache_dir / tokenizers / ...`.
             3. If still not found, download from Hugging Face to `cache_dir / tokenizers / ...`.

         Args:
             config (Union[str, WordLlamaConfig], optional):
-                The configuration name or an instance of WordLlamaConfig to load.
+                The configuration name or a custom instance of WordLlamaConfig to load.
                 Defaults to "l2_supercat".
             cache_dir (Optional[Path], optional):
                 The directory to use for caching files.
@@ -251,12 +242,9 @@ def load(
             config_name = config
         elif isinstance(config, WordLlamaConfig):
             config_obj = config
-            # config_name = getattr(config_obj, "name", None)
-            config_name = config
-            if config_name is None:
-                raise ValueError(
-                    "WordLlamaConfig instance must have a 'name' attribute."
-                )
+            config_name = getattr(config, "config_name")
+            disable_download = True  # disable for custom config
+            logger.debug("Downloads are disabled for custom configs.")
         else:
             raise ValueError(
                 "Invalid configuration type provided. It must be either a string or an instance of WordLlamaConfig."
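
With this series applied, WordLlama.load resolves each file by checking the project root first, then the cache directory, and only then downloading from Hugging Face. A minimal usage sketch of the reworked API (assuming `WordLlama` is importable from the package top level, as the tests above use it; the config name, dimensions, and keyword arguments come from the patches in this series):

    from wordllama import WordLlama  # assumed top-level export

    # Default resolution: project root -> ~/.cache/wordllama -> Hugging Face download.
    wl = WordLlama.load(config="l2_supercat", dim=256)

    # Custom cache directory; "~" and relative paths are expanded and resolved.
    wl = WordLlama.load(config="l2_supercat", dim=256, cache_dir="~/tmp_cache")

    # Truncate matryoshka embeddings at load time (must be one of the configured dims).
    wl = WordLlama.load(config="l2_supercat", dim=256, trunc_dim=128)

    # Offline mode: raises FileNotFoundError instead of downloading missing files.
    wl = WordLlama.load(config="l2_supercat", dim=256, disable_download=True)

Note that after PATCH 3, passing a custom WordLlamaConfig instance forces disable_download=True, so custom configs must ship their weights and tokenizer files locally.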