Skip to content

Commit

Permalink
Merge pull request #41 from erickpeirson/main
Browse files Browse the repository at this point in the history
Use custom cache dir for tokenizer download, too
  • Loading branch information
dleemiller authored Nov 7, 2024
2 parents 3e6d38d + 61d49e4 commit d8810b8
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 11 deletions.
6 changes: 4 additions & 2 deletions tests/test_wordllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_check_and_download_model(self, mock_exists, mock_download):
cache_dir=Path("/dummy/cache"),
)
self.assertEqual(
weights_file_path, Path("/dummy/cache/l2_supercat_256.safetensors")
weights_file_path, Path("/dummy/cache/weights/l2_supercat_256.safetensors")
)

@patch(
Expand Down Expand Up @@ -181,7 +181,9 @@ def test_load(
disable_download=False,
)
mock_check_tokenizer.assert_called_once_with(
config_name="l2_supercat", disable_download=False
config_name="l2_supercat",
cache_dir=Path('/dummy/cache'),
disable_download=False,
)
mock_load_tokenizer.assert_called_once()
mock_safe_open.assert_called_once_with(
Expand Down
27 changes: 18 additions & 9 deletions wordllama/wordllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ def get_filename(config_name: str, dim: int, binary: bool = False) -> str:
return f"{config_name}_{dim}{suffix}.safetensors"

@staticmethod
def get_cache_dir(is_tokenizer_config: bool = False) -> Path:
def get_cache_dir(
is_tokenizer_config: bool = False,
base_cache_dir: Optional[Path] = None
) -> Path:
"""
Get the cache directory path for weights or tokenizer configuration.
Expand All @@ -64,7 +67,8 @@ def get_cache_dir(is_tokenizer_config: bool = False) -> Path:
Returns:
Path: The path to the cache directory.
"""
base_cache_dir = Path.home() / ".cache" / "wordllama"
if base_cache_dir is None:
base_cache_dir = Path.home() / ".cache" / "wordllama"
return base_cache_dir / ("tokenizers" if is_tokenizer_config else "weights")

@staticmethod
Expand All @@ -88,8 +92,7 @@ def download_file_from_hf(
Returns:
Path: The path to the cached file.
"""
if cache_dir is None:
cache_dir = WordLlama.get_cache_dir()
cache_dir = WordLlama.get_cache_dir(base_cache_dir=cache_dir)

cache_dir.mkdir(parents=True, exist_ok=True)
cached_file_path = cache_dir / filename
Expand Down Expand Up @@ -141,8 +144,7 @@ def check_and_download_model(
if weights_dir is None:
weights_dir = Path(__file__).parent / "weights"

if cache_dir is None:
cache_dir = cls.get_cache_dir()
cache_dir = WordLlama.get_cache_dir(base_cache_dir=cache_dir)

filename = cls.get_filename(config_name=config_name, dim=dim, binary=binary)
weights_file_path = weights_dir / filename
Expand Down Expand Up @@ -172,7 +174,7 @@ def check_and_download_model(
                f"Weights file '{filename}' not found in cache directory '{cache_dir}'. Downloading..."
)
weights_file_path = cls.download_file_from_hf(
repo_id=model_uri.repo_id, filename=filename
repo_id=model_uri.repo_id, filename=filename, cache_dir=cache_dir
)

if not weights_file_path.exists():
Expand All @@ -184,21 +186,27 @@ def check_and_download_model(

@classmethod
def check_and_download_tokenizer(
cls, config_name: str, disable_download: bool = False
cls,
config_name: str,
cache_dir: Optional[Path] = None,
disable_download: bool = False,
) -> Path:
"""
Check if tokenizer configuration exists locally, if not, download it.
Args:
config_name (str): The name of the model configuration.
tokenizer_filename (str): The filename of the tokenizer configuration.
cache_dir (Path, optional): Directory where cached files are stored. Defaults to the appropriate cache directory.
disable_download (bool, optional): Disable downloading for models not in cache.
Returns:
Path: Path to the tokenizer configuration file.
"""
model_uri = getattr(cls, config_name)
cache_dir = cls.get_cache_dir(is_tokenizer_config=True)

cache_dir = cls.get_cache_dir(is_tokenizer_config=True, base_cache_dir=cache_dir)

tokenizer_file_path = cache_dir / model_uri.tokenizer_config

if not tokenizer_file_path.exists():
Expand Down Expand Up @@ -287,6 +295,7 @@ def load(
# Check and download tokenizer config if necessary
tokenizer_file_path = cls.check_and_download_tokenizer(
config_name=config,
cache_dir=cache_dir,
disable_download=disable_download,
)

Expand Down

0 comments on commit d8810b8

Please sign in to comment.