Model config tests #627

Open · wants to merge 15 commits into base: dev
25 changes: 25 additions & 0 deletions tests/unit/test_model_configurations.py
@@ -0,0 +1,25 @@
from functools import lru_cache

import pytest

from transformer_lens import loading
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


@lru_cache(maxsize=None)
def get_cached_config(model_name: str) -> HookedTransformerConfig:
    """Retrieve the configuration of a pretrained model.

    Args:
        model_name (str): Name of the pretrained model.

    Returns:
        HookedTransformerConfig: Configuration of the pretrained model.
    """
    return loading.get_pretrained_model_config(model_name)


@pytest.mark.parametrize("model_name", loading.DEFAULT_MODEL_ALIASES)
def test_model_configurations(model_name: str):
    """Tests that all of the model configurations are in fact loaded (i.e. are not None)."""
    assert get_cached_config(model_name) is not None, f"Configuration for {model_name} is None"
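For orientation, a minimal standalone sketch of what the new test exercises: asking the loading module for a pretrained config and inspecting a few of its fields. The model name "gpt2" and the printed fields are illustrative choices, not part of this diff.

```python
# Minimal sketch (not part of the PR): load one pretrained config directly and
# inspect a few fields. "gpt2" is an illustrative model name; any alias from
# loading.DEFAULT_MODEL_ALIASES should work the same way.
from transformer_lens import loading

cfg = loading.get_pretrained_model_config("gpt2")
print(cfg.n_layers, cfg.d_model, cfg.rotary_base)
```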
2 changes: 1 addition & 1 deletion transformer_lens/HookedTransformerConfig.py
@@ -245,7 +245,7 @@ class HookedTransformerConfig:
tokenizer_prepends_bos: Optional[bool] = None
n_key_value_heads: Optional[int] = None
post_embedding_ln: bool = False
-rotary_base: int = 10000
+rotary_base: float = 10000.0
trust_remote_code: bool = False
rotary_adjacent_pairs: bool = False
load_in_4bit: bool = False
2 changes: 1 addition & 1 deletion transformer_lens/components/abstract_attention.py
@@ -469,7 +469,7 @@ def calculate_sin_cos_rotary(
self,
rotary_dim: int,
n_ctx: int,
-base: int = 10000,
+base: float = 10000,
dtype: torch.dtype = torch.float32,
) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
"""
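For context on why `base` is widened to a float here: rotary embeddings derive their per-dimension angular frequencies by raising `base` to fractional negative powers, so the value is used as a real number throughout. A rough sketch of the usual RoPE sin/cos computation follows; it assumes the standard formulation and is not the exact body of `calculate_sin_cos_rotary`.

```python
import torch


def sin_cos_rotary_sketch(rotary_dim: int, n_ctx: int, base: float = 10000.0):
    # Standard RoPE frequencies: base ** (-2i / rotary_dim) for each pair of dims.
    # Sketch of the usual formulation, not the library's exact implementation.
    inv_freq = base ** (-torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim)
    positions = torch.arange(n_ctx, dtype=torch.float32)
    angles = positions[:, None] * inv_freq[None, :]  # [n_ctx, rotary_dim // 2]
    # Repeat each frequency for its paired dimension so the result is [n_ctx, rotary_dim].
    angles = torch.repeat_interleave(angles, 2, dim=-1)
    return torch.sin(angles), torch.cos(angles)
```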
2 changes: 1 addition & 1 deletion transformer_lens/loading_from_pretrained.py
@@ -777,7 +777,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
"rotary_dim": 4096 // 32,
"final_rms": True,
"gated_mlp": True,
"rotary_base": 1000000,
"rotary_base": 1000000.0,
}
if "python" in official_model_name.lower():
# The vocab size of python version of CodeLlama-7b is 32000
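The float convention then flows from the converted config dict into `HookedTransformerConfig`. Below is a hypothetical, heavily trimmed config dict in the style of `convert_hf_model_config`; all values are illustrative and do not correspond to any real model.

```python
# Hypothetical, trimmed config dict in the style of convert_hf_model_config.
# All values are illustrative; the point is that rotary_base is carried as a float.
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

cfg_dict = {
    "n_layers": 2,
    "d_model": 128,
    "n_ctx": 256,
    "d_head": 32,
    "n_heads": 4,
    "d_vocab": 1000,
    "act_fn": "gelu",
    "positional_embedding_type": "rotary",
    "rotary_dim": 32,
    "rotary_base": 1000000.0,  # float, matching the change above
}
cfg = HookedTransformerConfig(**cfg_dict)
print(type(cfg.rotary_base), cfg.rotary_base)
```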