From 218e59463eba7753ee67f4fde947012043411423 Mon Sep 17 00:00:00 2001
From: cmathw
Date: Tue, 30 Jan 2024 21:18:54 +0000
Subject: [PATCH 1/4] add flash attn mvp

---
 tests/unit/test_fast_attn.py                | 37 +++++++++
 transformer_lens/HookedTransformerConfig.py |  6 ++
 transformer_lens/components.py              | 89 ++++++++++++++-------
 3 files changed, 102 insertions(+), 30 deletions(-)
 create mode 100644 tests/unit/test_fast_attn.py

diff --git a/tests/unit/test_fast_attn.py b/tests/unit/test_fast_attn.py
new file mode 100644
index 000000000..d4c310ca9
--- /dev/null
+++ b/tests/unit/test_fast_attn.py
@@ -0,0 +1,37 @@
+import pytest
+import torch
+
+from transformer_lens.HookedTransformer import HookedTransformer
+
+
+class TestFastAttn:
+    prompt = """
+    This is a library for doing mechanistic interpretability of GPT-2 Style language models.
+    The goal of mechanistic interpretability is to take a trained model and reverse engineer
+    the algorithms the model learned during training from its weights.
+    """
+
+    # fixtures
+    @pytest.fixture(scope="class", params=["gpt2", "facebook/opt-125m"])
+    def model_name(self, request):
+        return request.param
+
+    @pytest.fixture(scope="class")
+    def model(self, model_name):
+        return HookedTransformer.from_pretrained(model_name)
+
+    # tests
+    def test_logits_and_cache(self, model_name):
+        model = HookedTransformer.from_pretrained(model_name)
+        model.cfg.use_fast_attn = True
+        fast_logits, fast_cache = model.run_with_cache(self.prompt)
+        model.cfg.use_fast_attn = False
+        slow_logits, slow_cache = model.run_with_cache(self.prompt)
+
+        assert torch.allclose(fast_logits, slow_logits, rtol=5e-1, atol=5e-1), "Logits mismatch"
+
+        # Fast cache should be missing Attn Scores and Pattern Keys
+        assert len(fast_cache) < len(slow_cache)
+
+        for k, v in fast_cache.items():
+            assert torch.allclose(v, slow_cache[k], rtol=5e-1, atol=5e-1), f"Cache mismatch for {k}"
diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py
index 501f6e881..c5b7fcb72 100644
--- a/transformer_lens/HookedTransformerConfig.py
+++ b/transformer_lens/HookedTransformerConfig.py
@@ -151,6 +151,11 @@ class HookedTransformerConfig:
             Only for models that use Grouped Query Attention.
         post_embedding_ln (bool): Whether to apply layer normalization after embedding the
            tokens. Defaults to False.
+        use_fast_attn (bool): Whether to use torch.nn.functional.scaled_dot_product_attention.
+            This covers FlashAttention-2 as well as two other (potentially faster) attention
+            backends, and PyTorch attempts to automatically select the fastest implementation
+            for the given inputs. Note that using these implementations means some intermediate
+            hooks (i.e. hook_attn_scores and hook_pattern) are not available. Defaults to False.
""" n_layers: int @@ -203,6 +208,7 @@ class HookedTransformerConfig: rotary_base: int = 10000 trust_remote_code: bool = False rotary_adjacent_pairs: bool = False + use_fast_attn: bool = False def __post_init__(self): if self.n_heads == -1: diff --git a/transformer_lens/components.py b/transformer_lens/components.py index 942ec2819..939f4e9a1 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -553,39 +553,48 @@ def forward( q = q.to(torch.float32) k = k.to(torch.float32) - attn_scores = self.calculate_attention_scores( - q, k - ) # [batch, head_index, query_pos, key_pos] - - if self.cfg.positional_embedding_type == "alibi": - query_ctx = attn_scores.size(-2) - # The key context length is the number of positions in the past - this includes all positions in the cache - key_ctx = attn_scores.size(-1) - - # only recompute when necessary to increase efficiency. - if self.alibi is None or key_ctx > self.alibi.size(-1): - self.alibi = Attention.create_alibi_bias( - self.cfg.n_heads, key_ctx, self.cfg.device - ) + if self.cfg.use_fast_attn: + z = self.calculate_z_with_sdpa(q, k, v) # [batch, pos, head_index, d_head] + + else: + attn_scores = self.calculate_attention_scores( + q, k + ) # [batch, head_index, query_pos, key_pos] - attn_scores += self.alibi[ - :, :query_ctx, :key_ctx - ] # [batch, head_index, query_pos, key_pos] + if self.cfg.positional_embedding_type == "alibi": + query_ctx = attn_scores.size(-2) + # The key context length is the number of positions in the past - this includes all positions in the cache + key_ctx = attn_scores.size(-1) - if self.cfg.attention_dir == "causal": - # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. - attn_scores = self.apply_causal_mask( - attn_scores, kv_cache_pos_offset, attention_mask + # only recompute when necessary to increase efficiency. + if self.alibi is None or key_ctx > self.alibi.size(-1): + self.alibi = Attention.create_alibi_bias( + self.cfg.n_heads, key_ctx, self.cfg.device + ) + + attn_scores += self.alibi[ + :, :query_ctx, :key_ctx + ] # [batch, head_index, query_pos, key_pos] + + if self.cfg.attention_dir == "causal": + # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. 
+                attn_scores = self.apply_causal_mask(
+                    attn_scores, kv_cache_pos_offset, attention_mask
+                )  # [batch, head_index, query_pos, key_pos]
+            if additive_attention_mask is not None:
+                attn_scores += additive_attention_mask
+
+            attn_scores = self.hook_attn_scores(attn_scores)
+            pattern = F.softmax(attn_scores, dim=-1)
+            pattern = torch.where(
+                torch.isnan(pattern), torch.zeros_like(pattern), pattern
+            )
+            pattern = self.hook_pattern(
+                pattern
             )  # [batch, head_index, query_pos, key_pos]
-        if additive_attention_mask is not None:
-            attn_scores += additive_attention_mask
-
-        attn_scores = self.hook_attn_scores(attn_scores)
-        pattern = F.softmax(attn_scores, dim=-1)
-        pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
-        pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
-        pattern = pattern.to(self.cfg.dtype)
-        z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
+            pattern = pattern.to(self.cfg.dtype)
+            z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
+
         if not self.cfg.use_attn_result:
             out = (
                 (
@@ -689,6 +698,26 @@ def calculate_attention_scores(
         )
         return attn_scores
 
+    def calculate_z_with_sdpa(
+        self,
+        q: Float[torch.Tensor, "batch query_pos head_index d_head"],
+        k: Float[torch.Tensor, "batch key_pos head_index d_head"],
+        v: Float[torch.Tensor, "batch key_pos head_index d_head"],
+    ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
+        # scaled_dot_product_attention expects [batch, head_index, pos, d_head]; casting to float16 lets PyTorch use its fused FlashAttention kernels
+        convert_to_sdpa_format = lambda tensor: einops.rearrange(
+            tensor, "batch pos head_index d_head -> batch head_index pos d_head"
+        ).to(torch.float16)
+
+        convert_from_sdpa_format = lambda tensor: einops.rearrange(
+            tensor, "batch head_index pos d_head -> batch pos head_index d_head"
+        ).to(q.dtype)
+
+        query, key, value = map(convert_to_sdpa_format, [q, k, v])
+        z = F.scaled_dot_product_attention(query, key, value, is_causal=True)
+        z = self.hook_z(convert_from_sdpa_format(z))
+        return z
+
     def calculate_z_scores(
         self,
         v: Float[torch.Tensor, "batch key_pos head_index d_head"],

From 187180e4ad836d71ec7a2f1ec0582218183a9507 Mon Sep 17 00:00:00 2001
From: cmathw
Date: Tue, 30 Jan 2024 21:48:46 +0000
Subject: [PATCH 2/4] formatting

---
 tests/unit/test_fast_attn.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/tests/unit/test_fast_attn.py b/tests/unit/test_fast_attn.py
index d4c310ca9..69b492405 100644
--- a/tests/unit/test_fast_attn.py
+++ b/tests/unit/test_fast_attn.py
@@ -27,11 +27,15 @@ def test_logits_and_cache(self, model_name):
         fast_logits, fast_cache = model.run_with_cache(self.prompt)
         model.cfg.use_fast_attn = False
         slow_logits, slow_cache = model.run_with_cache(self.prompt)
-
-        assert torch.allclose(fast_logits, slow_logits, rtol=5e-1, atol=5e-1), "Logits mismatch"
-
+
+        assert torch.allclose(
+            fast_logits, slow_logits, rtol=5e-1, atol=5e-1
+        ), "Logits mismatch"
+
         # Fast cache should be missing Attn Scores and Pattern Keys
         assert len(fast_cache) < len(slow_cache)
-
+
         for k, v in fast_cache.items():
-            assert torch.allclose(v, slow_cache[k], rtol=5e-1, atol=5e-1), f"Cache mismatch for {k}"
+            assert torch.allclose(
+                v, slow_cache[k], rtol=5e-1, atol=5e-1
+            ), f"Cache mismatch for {k}"

From ab586a410ff47f5bb0c1fc9fe8ac7cab2c843185 Mon Sep 17 00:00:00 2001
From: cmathw
Date: Tue, 30 Jan 2024 22:01:49 +0000
Subject: [PATCH 3/4] mark fast_attn tests skip if no gpu

---
 tests/unit/test_fast_attn.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/unit/test_fast_attn.py b/tests/unit/test_fast_attn.py
index 69b492405..5b4f0f642 100644
--- a/tests/unit/test_fast_attn.py
+++ b/tests/unit/test_fast_attn.py
@@ -20,9 +20,9 @@ def model_name(self, request):
     def model(self, model_name):
         return HookedTransformer.from_pretrained(model_name)
 
-    # tests
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device")
     def test_logits_and_cache(self, model_name):
-        model = HookedTransformer.from_pretrained(model_name)
+        model = HookedTransformer.from_pretrained(model_name).to("cuda")
         model.cfg.use_fast_attn = True
         fast_logits, fast_cache = model.run_with_cache(self.prompt)
         model.cfg.use_fast_attn = False

From 49357ae62243b7ee81e4f6623fded2e31e623e7b Mon Sep 17 00:00:00 2001
From: cmathw
Date: Tue, 30 Jan 2024 22:11:32 +0000
Subject: [PATCH 4/4] fix model fixture definition

---
 tests/unit/test_fast_attn.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/unit/test_fast_attn.py b/tests/unit/test_fast_attn.py
index 5b4f0f642..ae831fd61 100644
--- a/tests/unit/test_fast_attn.py
+++ b/tests/unit/test_fast_attn.py
@@ -21,8 +21,8 @@ def model(self, model_name):
         return HookedTransformer.from_pretrained(model_name)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device")
-    def test_logits_and_cache(self, model_name):
-        model = HookedTransformer.from_pretrained(model_name).to("cuda")
+    def test_logits_and_cache(self, model):
+        model.to("cuda")
         model.cfg.use_fast_attn = True
         fast_logits, fast_cache = model.run_with_cache(self.prompt)
         model.cfg.use_fast_attn = False
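
Usage sketch (illustrative, not part of the patches above): it mirrors the test in this series; the checkpoint, prompt, and CUDA placement are arbitrary choices, and a GPU is assumed as in the skipif marker.

import torch

from transformer_lens.HookedTransformer import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2").to("cuda")

model.cfg.use_fast_attn = True   # attention runs through F.scaled_dot_product_attention
fast_logits, fast_cache = model.run_with_cache("Hello fast attention")

model.cfg.use_fast_attn = False  # back to the fully hooked eager path
slow_logits, slow_cache = model.run_with_cache("Hello fast attention")

# The fast path skips hook_attn_scores and hook_pattern, so its cache has fewer entries,
# while the logits agree within the loose tolerance used by the test above.
assert len(fast_cache) < len(slow_cache)
assert torch.allclose(fast_logits, slow_logits, rtol=5e-1, atol=5e-1)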