Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Support Flash Attention #501

Draft
wants to merge 5 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions tests/unit/test_fast_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest
import torch

from transformer_lens.HookedTransformer import HookedTransformer


class TestFastAttn:
prompt = """
This is a library for doing mechanistic interpretability of GPT-2 Style language models.
The goal of mechanistic interpretability is to take a trained model and reverse engineer
the algorithms the model learned during training from its weights.
"""

# fixtures
@pytest.fixture(scope="class", params=["gpt2", "facebook/opt-125m"])
def model_name(self, request):
return request.param

@pytest.fixture(scope="class")
def model(self, model_name):
return HookedTransformer.from_pretrained(model_name)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device")
def test_logits_and_cache(self, model):
model.to("cuda")
model.cfg.use_fast_attn = True
fast_logits, fast_cache = model.run_with_cache(self.prompt)
model.cfg.use_fast_attn = False
slow_logits, slow_cache = model.run_with_cache(self.prompt)

assert torch.allclose(
fast_logits, slow_logits, rtol=5e-1, atol=5e-1
), "Logits mismatch"

# Fast cache should be missing Attn Scores and Pattern Keys
assert len(fast_cache) < len(slow_cache)

for k, v in fast_cache.items():
assert torch.allclose(
v, slow_cache[k], rtol=5e-1, atol=5e-1
), f"Cache mismatch for {k}"
6 changes: 6 additions & 0 deletions transformer_lens/HookedTransformerConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ class HookedTransformerConfig:
must also be set. Set to None if not using MoE.
experts_per_token (int, *optional*): The number of experts to use for each pass in the MoE layer. If set,
num_experts must also be set. Set to None if not using MoE.
use_fast_attn (bool): Whether to use torch.nn.functional.scaled_dot_product_attention. This
implementation includes FlashAttention-2, as well as, two other alternative (potentially faster) attention
implmentations. PyTorch attempts to automatically select the most optimal implementation
based on inputs. Note, using these implementations will mean loss of some intermediate hooks
(ie. hook_attn_scores and hook_pattern). Defaults to False.
"""

n_layers: int
Expand Down Expand Up @@ -214,6 +219,7 @@ class HookedTransformerConfig:
load_in_4bit: bool = False
num_experts: Optional[int] = None
experts_per_token: Optional[int] = None
use_fast_attn: bool = False

def __post_init__(self):
if self.n_heads == -1:
Expand Down
Loading