diff --git a/tests/unit/test_fast_attn.py b/tests/unit/test_fast_attn.py
new file mode 100644
index 000000000..ae831fd61
--- /dev/null
+++ b/tests/unit/test_fast_attn.py
@@ -0,0 +1,41 @@
+import pytest
+import torch
+
+from transformer_lens.HookedTransformer import HookedTransformer
+
+
+class TestFastAttn:
+    prompt = """
+    This is a library for doing mechanistic interpretability of GPT-2 Style language models.
+    The goal of mechanistic interpretability is to take a trained model and reverse engineer
+    the algorithms the model learned during training from its weights.
+    """
+
+    # fixtures
+    @pytest.fixture(scope="class", params=["gpt2", "facebook/opt-125m"])
+    def model_name(self, request):
+        return request.param
+
+    @pytest.fixture(scope="class")
+    def model(self, model_name):
+        return HookedTransformer.from_pretrained(model_name)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device")
+    def test_logits_and_cache(self, model):
+        model.to("cuda")
+        model.cfg.use_fast_attn = True
+        fast_logits, fast_cache = model.run_with_cache(self.prompt)
+        model.cfg.use_fast_attn = False
+        slow_logits, slow_cache = model.run_with_cache(self.prompt)
+
+        assert torch.allclose(
+            fast_logits, slow_logits, rtol=5e-1, atol=5e-1
+        ), "Logits mismatch"
+
+        # Fast cache should be missing the attn scores and pattern keys
+        assert len(fast_cache) < len(slow_cache)
+
+        for k, v in fast_cache.items():
+            assert torch.allclose(
+                v, slow_cache[k], rtol=5e-1, atol=5e-1
+            ), f"Cache mismatch for {k}"
diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py
index 1e1e595ed..6fc2f1e5d 100644
--- a/transformer_lens/HookedTransformerConfig.py
+++ b/transformer_lens/HookedTransformerConfig.py
@@ -159,6 +159,11 @@ class HookedTransformerConfig:
             must also be set. Set to None if not using MoE.
         experts_per_token (int, *optional*): The number of experts to use for each pass in the MoE
             layer. If set, num_experts must also be set. Set to None if not using MoE.
+        use_fast_attn (bool): Whether to use torch.nn.functional.scaled_dot_product_attention. This
+            implementation includes FlashAttention-2 as well as two other alternative (potentially
+            faster) attention implementations, and PyTorch attempts to automatically select the most
+            optimal one based on the inputs. Note that using these implementations means losing some
+            intermediate hooks (i.e. hook_attn_scores and hook_pattern). Defaults to False.
     """

     n_layers: int
@@ -214,6 +219,7 @@ class HookedTransformerConfig:
     load_in_4bit: bool = False
     num_experts: Optional[int] = None
     experts_per_token: Optional[int] = None
+    use_fast_attn: bool = False

     def __post_init__(self):
         if self.n_heads == -1:
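
For reviewers, a minimal usage sketch of the new flag (not part of the diff itself): it assumes standard TransformerLens hook names such as `blocks.0.attn.hook_pattern` and `blocks.0.attn.hook_z`, and simply toggles `use_fast_attn` on `model.cfg` after loading, the same way the test above exercises both code paths on the same weights.

```python
# Usage sketch for use_fast_attn (illustrative; assumes the standard
# TransformerLens hook naming, e.g. "blocks.0.attn.hook_pattern").
import torch

from transformer_lens.HookedTransformer import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
if torch.cuda.is_available():
    model.to("cuda")

# Enable the scaled_dot_product_attention-backed path added by this change.
model.cfg.use_fast_attn = True
logits, cache = model.run_with_cache("Hello, fast attention!")

# Per the config docstring, the fused attention path does not expose the
# intermediate scores/pattern hooks, so they should be absent from the cache;
# other activations are still recorded as usual.
print("blocks.0.attn.hook_pattern" in cache.cache_dict)  # expected: False
print("blocks.0.attn.hook_z" in cache.cache_dict)        # expected: True
```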