"""Hooked Transformer.
The Hooked Transformer is the core part of TransformerLens.
In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract
model weights, but much harder to extract activations. TransformerLens aims to simplify this task by
attaching hooks to every notable activation within the model. This enables the inspection and/or
alteration of activations in individual components like attention heads and MLP layers, facilitating
a deeper understanding of the internal workings of transformers like GPT-2.
"""
import logging
import os
from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload
import einops
import numpy as np
import torch
import torch.nn as nn
import tqdm.auto as tqdm
from fancy_einsum import einsum
from jaxtyping import Float, Int
from packaging import version
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
from typing_extensions import Literal
import transformer_lens.loading_from_pretrained as loading
import transformer_lens.utils as utils
from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.components import (
Embed,
LayerNorm,
LayerNormPre,
PosEmbed,
RMSNorm,
RMSNormPre,
TransformerBlock,
Unembed,
)
from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES
# Note - activation cache is used with run_with_cache, past_key_value_caching is used for
# generation.
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache
from transformer_lens.utilities import devices
from transformer_lens.utils import (
USE_DEFAULT_VALUE,
init_kaiming_normal_,
init_kaiming_uniform_,
init_xavier_normal_,
init_xavier_uniform_,
)
SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor
LossPerToken = Float[torch.Tensor, "batch pos-1"]
Loss = Union[SingleLoss, LossPerToken]
DTYPE_FROM_STRING = {
"float32": torch.float32,
"fp32": torch.float32,
"float16": torch.float16,
"fp16": torch.float16,
"bfloat16": torch.bfloat16,
"bf16": torch.bfloat16,
}
class Output(NamedTuple):
"""Output Named Tuple.
    Named tuple object used when we want to output both logits and loss.
"""
logits: Float[torch.Tensor, "batch pos d_vocab"]
loss: Loss
class HookedTransformer(HookedRootModule):
"""Hooked Transformer.
Implements a full Transformer using the components :doc:`here <transformer_lens.components>`,
with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation.
TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of
these via :meth:`from_pretrained`, although it can also be instantiated with randomly
initialized weights via :meth:`__init__`.
Once you've initialized the model, a common next step is to test it can do the task you're
investigating. This can be done with :func:`transformer_lens.utils.test_prompt`.
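
    Example (an illustrative sketch; assumes the ``gpt2`` weights can be downloaded from the
    Hugging Face Hub)::

        from transformer_lens import HookedTransformer

        model = HookedTransformer.from_pretrained("gpt2")
        logits, cache = model.run_with_cache("Hello, world!")
        # logits has shape [batch, pos, d_vocab]; cache holds every hooked activation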
"""
ln_final: nn.Module
def __init__(
self,
cfg: Union[HookedTransformerConfig, Dict],
tokenizer: Optional[PreTrainedTokenizerBase] = None,
move_to_device: bool = True,
default_padding_side: Literal["left", "right"] = "right",
):
"""Model initialization.
Note that if you want to load the model from pretrained weights, you should use
:meth:`from_pretrained` instead.
Args:
cfg: The config to use for the model.
tokenizer: The tokenizer to use for the model. If not provided, it is inferred from
`cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be
passed strings, and d_vocab must be explicitly set.
            move_to_device: Whether to move the model to the device specified in cfg.device.
                Must be true if `n_devices` in the config is greater than 1, since the model's
                layers will be split across multiple devices.
default_padding_side: Which side to pad on.
"""
super().__init__()
if isinstance(cfg, str):
raise ValueError(
"Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a "
"pretrained model, use HookedTransformer.from_pretrained() instead."
)
self.cfg = HookedTransformerConfig.unwrap(cfg)
if tokenizer is not None:
self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
elif self.cfg.tokenizer_name is not None:
# If we have a tokenizer name, we can load it from HuggingFace
if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES:
logging.warning(
"%s tokenizer not loaded. Please load manually.",
self.cfg.tokenizer_name,
)
else:
                # Hugging Face defaults use_fast to True
use_fast = True
# Phi model's fast tokenizer does not support adding a BOS token, use_fast
# should be False
if "phi" in self.cfg.tokenizer_name.lower():
use_fast = False
huggingface_token = os.environ.get("HF_TOKEN", None)
self.set_tokenizer(
AutoTokenizer.from_pretrained(
self.cfg.tokenizer_name,
add_bos_token=True,
trust_remote_code=self.cfg.trust_remote_code,
use_fast=use_fast,
token=huggingface_token,
),
default_padding_side=default_padding_side,
)
else:
# If no tokenizer name is provided, we assume we're training on an algorithmic task and
# will pass in tokens directly. In this case, we don't need a tokenizer.
assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided"
self.tokenizer = None
if default_padding_side != "right":
logging.warning(
"default_padding_side is explictly given but ignored because tokenizer is not set."
)
self.embed = Embed(self.cfg)
self.hook_embed = HookPoint() # [batch, pos, d_model]
if self.cfg.positional_embedding_type != "rotary":
self.pos_embed = PosEmbed(self.cfg)
            self.hook_pos_embed = HookPoint()  # [batch, pos, d_model]
if self.cfg.use_hook_tokens:
self.hook_tokens = HookPoint() # [batch, pos]
self.blocks = nn.ModuleList(
[TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)]
)
if self.cfg.normalization_type == "RMS":
self.ln_final = RMSNorm(self.cfg)
elif self.cfg.normalization_type == "RMSPre":
self.ln_final = RMSNormPre(self.cfg)
elif self.cfg.normalization_type == "LN":
if self.cfg.final_rms:
self.ln_final = RMSNorm(self.cfg)
else:
self.ln_final = LayerNorm(self.cfg)
elif self.cfg.normalization_type == "LNPre":
# We've folded in LayerNorm weights, so just need the center + scale parts
if self.cfg.final_rms:
self.ln_final = RMSNormPre(self.cfg)
else:
self.ln_final = LayerNormPre(self.cfg)
elif self.cfg.normalization_type is None:
# If it's None, don't create either layer
pass
else:
logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type)
self.unembed = Unembed(self.cfg)
if self.cfg.init_weights:
self.init_weights()
if move_to_device:
# We load the devices in a pipeline manner - the first device gets the embed and
# pos_embed layers and the first n_layers // n_devices blocks, the second gets the next
# n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks,
# the final normalization layer (if it exists) and the unembed layer
self.move_model_modules_to_device()
# Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can
# be loaded with load_sample_training_dataset
self.dataset = None
# Gives each module a parameter with its name (relative to this root module)
# Needed for HookPoints to work
self.setup()
def check_hooks_to_add(
self,
hook_point,
hook_point_name,
hook,
dir="fwd",
is_permanent=False,
prepend=False,
) -> None:
if hook_point_name.endswith("attn.hook_result"):
assert (
self.cfg.use_attn_result
), f"Cannot add hook {hook_point_name} if use_attn_result_hook is False"
if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")):
assert (
self.cfg.use_split_qkv_input
), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False"
if hook_point_name.endswith("mlp_in"):
assert (
self.cfg.use_hook_mlp_in
), f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False"
if hook_point_name.endswith("attn_in"):
assert (
self.cfg.use_attn_in
), f"Cannot add hook {hook_point_name} if use_attn_in is False"
def input_to_embed(
self,
input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Tuple[
Float[torch.Tensor, "batch pos d_model"], # residual
Optional[Int[torch.Tensor, "batch pos"]], # tokens
Optional[Float[torch.Tensor, "batch pos d_model"]], # shortformer_pos_embed
Optional[torch.Tensor], # attention_mask [batch pos]
]:
"""Convert input to first residual stream.
Args:
input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model.
prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
the BOS token to the input (only applies when input is a string). Defaults to None,
implying usage of self.cfg.default_prepend_bos which is set to True unless specified
otherwise. Pass True or False to locally override the default.
            padding_side (Literal["left", "right"], optional): Overrides
self.tokenizer.padding_side. Specifies which side to pad when tokenizing
multiple strings of different lengths.
past_kv_cache (HookedTransformerKeyValueCache, optional): If passed, we're doing caching
and attention_mask will be stored in the cache.
"""
if isinstance(input, str) or isinstance(input, list):
# If text, convert to tokens (batch_size=1)
assert (
self.tokenizer is not None
), "Must provide a tokenizer if passing a string to the model"
# This is only intended to support passing in a single string
tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
else:
tokens = input
if len(tokens.shape) == 1:
# If tokens are a rank 1 tensor, add a dummy batch dimension to avoid things breaking.
tokens = tokens[None]
if tokens.device.type != self.cfg.device:
tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg))
if (self.tokenizer and self.tokenizer.padding_side == "left") or past_kv_cache is not None:
# If the padding side is left or we are using caching, we need to compute the attention
# mask for the adjustment of absolute positional embeddings and attention masking so
# that pad tokens are not attended.
if prepend_bos is USE_DEFAULT_VALUE:
prepend_bos = self.cfg.default_prepend_bos
attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
if past_kv_cache is not None:
# past_kv_cache is not None, so we're doing caching.
# We need to extend the previous attention_mask.
# Update the past_kv_cache with the new attention_mask (unless it's frozen)
attention_mask = past_kv_cache.append_attention_mask(attention_mask)
else:
            # We separate this case out for computational efficiency.
attention_mask = None
# If we're doing caching, then we reuse keys and values from previous runs, as that's the
# only way that past activations will affect the final logits. The cache contains those so
# we don't need to recompute them. This is useful for generating text. As we have absolute
# positional encodings, to implement this we have a `pos_offset` variable, defaulting to
# zero, which says to offset which positional encodings are used (cached keys and values
# were calculated with their own positional encodings).
if past_kv_cache is None:
pos_offset = 0
else:
batch_size, ctx_length = tokens.shape
(
cached_batch_size,
cache_ctx_length,
num_heads_in_cache,
d_head_in_cache,
) = past_kv_cache[0].past_keys.shape
assert cached_batch_size == batch_size
if self.cfg.n_key_value_heads is None:
assert num_heads_in_cache == self.cfg.n_heads
else:
assert num_heads_in_cache == self.cfg.n_key_value_heads
assert d_head_in_cache == self.cfg.d_head
pos_offset = cache_ctx_length
if self.cfg.use_hook_tokens:
tokens = self.hook_tokens(tokens)
embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model]
if self.cfg.positional_embedding_type == "standard":
pos_embed = self.hook_pos_embed(
self.pos_embed(tokens, pos_offset, attention_mask)
) # [batch, pos, d_model]
residual = embed + pos_embed # [batch, pos, d_model]
shortformer_pos_embed = None
elif self.cfg.positional_embedding_type == "shortformer":
# If we're using shortformer style attention, we don't add the positional embedding to
# the residual stream. See HookedTransformerConfig for details
pos_embed = self.hook_pos_embed(
self.pos_embed(tokens, pos_offset, attention_mask)
) # [batch, pos, d_model]
residual = embed
shortformer_pos_embed = pos_embed
elif self.cfg.positional_embedding_type == "rotary":
# Rotary doesn't use positional embeddings, instead they're applied when dot producting
# keys and queries. See HookedTransformerConfig for details
residual = embed
shortformer_pos_embed = None
elif self.cfg.positional_embedding_type == "alibi":
            # ALiBi does not add positional embeddings to word embeddings; instead, it biases
            # the QK attention scores.
residual = embed
shortformer_pos_embed = None
else:
raise ValueError(
f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
)
return residual, tokens, shortformer_pos_embed, attention_mask
@overload
def forward(
self,
input,
return_type: Literal["logits"],
loss_per_token: bool = False,
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
start_at_layer: Optional[int] = None,
tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
attention_mask: Optional[torch.Tensor] = None, # [batch pos]
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Loss:
...
@overload
def forward(
self,
input,
return_type: Literal["loss"],
loss_per_token: bool = False,
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
start_at_layer: Optional[int] = None,
tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
attention_mask: Optional[torch.Tensor] = None, # [batch pos]
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Loss:
...
@overload
def forward(
self,
input,
return_type: Literal["both"],
loss_per_token: bool = False,
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
start_at_layer: Optional[int] = None,
tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
attention_mask: Optional[torch.Tensor] = None, # [batch pos]
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]:
...
@overload
def forward(
self,
input,
return_type: Literal[None],
loss_per_token: bool = False,
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
start_at_layer: Optional[int] = None,
tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
attention_mask: Optional[torch.Tensor] = None, # [batch pos]
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> None:
...
def forward(
self,
input: Union[
str,
List[str],
Int[torch.Tensor, "batch pos"],
Float[torch.Tensor, "batch pos d_model"],
],
return_type: Optional[str] = "logits",
loss_per_token: bool = False,
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
start_at_layer: Optional[int] = None,
tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
attention_mask: Optional[torch.Tensor] = None, # [batch pos]
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Union[
None,
Float[torch.Tensor, "batch pos d_vocab"],
Loss,
Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
]:
"""Forward Pass.
        Input is either a batch of tokens ([batch, pos]) or a text string; a string is
        automatically tokenized to a batch of a single element. The prepend_bos flag only applies
        when inputting a text string.
Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style
language models - if you want a custom loss function, the recommended behaviour is returning
the logits and then applying your custom loss function.
Args:
return_type Optional[str]: The type of output to return. Can be one of: None (return
nothing, don't calculate logits), 'logits' (return logits), 'loss' (return
cross-entropy loss), 'both' (return logits and loss).
loss_per_token bool: Whether to return the (next token prediction) loss per token (True)
or average (False). Average loss is a scalar (averaged over position *and* batch),
per-token loss is a tensor ([batch, position-1]) - position-1 because we're
predicting the next token, and there's no specified next token for the final token.
Defaults to False.
prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend
the BOS token to the input (only applies when input is a string). Defaults to None,
implying usage of self.cfg.default_prepend_bos which is set to True unless specified
otherwise. (Even for models not explicitly trained with a prepended BOS token, heads
often use the first position as a resting position and accordingly lose information
from the first token, so this empirically seems to give better results.) Pass True
or False to locally override the default.
padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side.
Specifies which side to pad on when tokenizing multiple strings of different
lengths.
start_at_layer Optional[int]: If not None, start the forward pass at the specified
layer. Requires input to be the residual stream before the specified layer with
shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding
then runs the rest of the model. Supports negative indexing. start_at_layer = -1
only runs the final block and the unembedding. Defaults to None (run the full
model).
tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if
start_at_layer is not None and return type is "loss" or "both".
shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional
embedding for shortformer models. Only use if start_at_layer is not None and
self.cfg.positional_embedding_type == "shortformer".
attention_mask: Optional[torch.Tensor]: The attention mask for padded tokens. Only use
if start_at_layer is not None and (self.tokenizer.padding_side == "left" or
past_kv_cache is not None).
stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer.
Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer =
1 will run the embedding layer and the first transformer block, etc. Supports
negative indexing. Useful for analysis of intermediate layers, eg finding neuron
activations in layer 3 of a 24 layer model. Defaults to None (run the full model).
If not None, we return the last residual stream computed.
past_kv_cache Optional[HookedTransformerKeyValueCache]: If not None, keys and values
will be stored for every attention head (unless the cache is frozen). If there are
keys and values already in the cache, these will be prepended to the keys and values
for the new input, so that the new tokens can pay attention to previous tokens. This
is useful for generating text, because we don't need to repeat computation for
tokens that have already been through the model. Also caches attention_mask so
previous tokens are masked correctly (unless frozen). Padding should be ignored in
all cases, so it's okay to eg. pass in left padded tokens twice in a row.
Warning: Don't accidentally prepend_bos to the second half of a prompt.
Defaults to None (don't use caching).
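
        Example (an illustrative sketch; assumes ``model`` was created via
        :meth:`from_pretrained`)::

            loss = model("The quick brown fox", return_type="loss")
            # stop_at_layer is exclusive: this returns the residual stream after blocks 0 and 1
            resid = model("The quick brown fox", stop_at_layer=2)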
"""
with utils.LocallyOverridenDefaults(
self, prepend_bos=prepend_bos, padding_side=padding_side
):
if start_at_layer is None:
(
residual,
tokens,
shortformer_pos_embed,
attention_mask,
) = self.input_to_embed(
input,
prepend_bos=prepend_bos,
padding_side=padding_side,
past_kv_cache=past_kv_cache,
)
else:
assert type(input) == torch.Tensor
residual = input
if start_at_layer is None:
start_at_layer = 0
# If we explicitly want to start or stop at a layer, we only iterate through the blocks
# between those indices. Note that start_at_layer is inclusive and stop_at_layer is
# exclusive.
# Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed.
# Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer
blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks))
for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]: # type: ignore
# Note that each block includes skip connections, so we don't need
# residual + block(residual)
# If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU
residual = residual.to(devices.get_device_for_block_index(i, self.cfg))
if shortformer_pos_embed is not None:
shortformer_pos_embed = shortformer_pos_embed.to(
devices.get_device_for_block_index(i, self.cfg)
)
residual = block(
residual,
# Cache contains a list of HookedTransformerKeyValueCache objects, one for each
# block
past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None,
shortformer_pos_embed=shortformer_pos_embed,
attention_mask=attention_mask,
) # [batch, pos, d_model]
if stop_at_layer is not None:
# When we stop at an early layer, we end here rather than doing further computation
return residual
if self.cfg.normalization_type is not None:
residual = self.ln_final(residual) # [batch, pos, d_model]
if return_type is None:
return None
else:
logits = self.unembed(residual) # [batch, pos, d_vocab]
if return_type == "logits":
return logits
else:
assert (
tokens is not None
), "tokens must be passed in if return_type is 'loss' or 'both'"
loss = self.loss_fn(logits, tokens, per_token=loss_per_token)
if return_type == "loss":
return loss
elif return_type == "both":
return Output(logits, loss)
else:
logging.warning(f"Invalid return_type passed in: {return_type}")
return None
def loss_fn(
self,
logits: Float[torch.Tensor, "batch pos d_vocab"],
tokens: Int[torch.Tensor, "batch pos"],
per_token: bool = False,
):
"""Wrapper around `utils.lm_cross_entropy_loss`.
Used in forward() with return_type=="loss" or "both".
"""
if tokens.device != logits.device:
tokens = tokens.to(logits.device)
return utils.lm_cross_entropy_loss(logits, tokens, per_token)
@overload
def run_with_cache(
self, *model_args, return_cache_object: Literal[True] = True, **kwargs
) -> Tuple[Output, ActivationCache]:
...
@overload
def run_with_cache(
self, *model_args, return_cache_object: Literal[False], **kwargs
) -> Tuple[Output, Dict[str, torch.Tensor]]:
...
def run_with_cache(
self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs
) -> Tuple[
Union[
None,
Float[torch.Tensor, "batch pos d_vocab"],
Loss,
Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
],
Union[ActivationCache, Dict[str, torch.Tensor]],
]:
"""Wrapper around `run_with_cache` in HookedRootModule.
If return_cache_object is True, this will return an ActivationCache object, with a bunch of
useful HookedTransformer specific methods, otherwise it will return a dictionary of
activations as in HookedRootModule.
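
        Example (an illustrative sketch; assumes ``model`` was created via
        :meth:`from_pretrained`)::

            logits, cache = model.run_with_cache("Hello, world!")
            attn_pattern = cache["pattern", 0]  # layer 0 attention pattern, via ActivationCache indexing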
"""
out, cache_dict = super().run_with_cache(
*model_args, remove_batch_dim=remove_batch_dim, **kwargs
)
if return_cache_object:
cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
return out, cache
else:
return out, cache_dict
def set_tokenizer(
self,
tokenizer,
default_padding_side="right",
):
"""Set the tokenizer to use for this model.
Args:
tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer.
default_padding_side (str): "right" or "left", which side to pad on.
"""
assert isinstance(
tokenizer, PreTrainedTokenizerBase
), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"
assert default_padding_side in [
"right",
"left",
], f"padding_side must be 'right' or 'left', got {default_padding_side}"
# Use a tokenizer that is initialized with add_bos_token=True as the default tokenizer.
# Such a tokenizer should be set as the default tokenizer because the tokenization of some
# tokenizers like LlamaTokenizer are different when bos token is automatically/manually
# prepended, and add_bos_token cannot be dynamically controlled after initialization
# (https://github.com/huggingface/transformers/issues/25886).
tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
self.tokenizer = tokenizer_with_bos
assert self.tokenizer is not None # keep mypy happy
self.tokenizer.padding_side = default_padding_side
        # Some tokenizers don't automatically prepend the BOS token even when they are initialized
# with add_bos_token=True. Therefore, we need this information to dynamically control prepend_bos.
self.cfg.tokenizer_prepends_bos = len(self.tokenizer.encode("")) > 0
if self.tokenizer.eos_token is None:
self.tokenizer.eos_token = "<|endoftext|>"
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
if self.tokenizer.bos_token is None:
self.tokenizer.bos_token = self.tokenizer.eos_token
# Infer vocab size from tokenizer
if self.cfg.d_vocab == -1:
self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
if self.cfg.d_vocab_out == -1:
self.cfg.d_vocab_out = self.cfg.d_vocab
def to_tokens(
self,
input: Union[str, List[str]],
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
move_to_device: bool = True,
truncate: bool = True,
) -> Int[torch.Tensor, "batch pos"]:
"""Converts a string to a tensor of tokens.
If prepend_bos is True, prepends the BOS token to the input - this is recommended when
creating a sequence of tokens to be input to a model.
Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when
inputting a prompt to the model as the first token is often treated weirdly, but should only
be done at the START of the prompt. Make sure to turn it off if you're looking at the
tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS
token, others (OPT and my models) were)
Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether
the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not
careful!
Args:
input (Union[str, List[str]]): The input to tokenize.
prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
the BOS token to the input (only applies when input is a string). Defaults to None,
implying usage of self.cfg.default_prepend_bos which is set to True unless specified
otherwise. Pass True or False to locally override the default.
padding_side (Union[Literal["left", "right"], None], optional): Overrides
self.tokenizer.padding_side. Specifies which side to pad when tokenizing
multiple strings of different lengths.
            move_to_device (bool): Whether to move the output tensor of tokens to the device the
                model lives on. Defaults to True.
            truncate (bool): If the output tokens are too long, whether to truncate the output
                tokens to the model's max context window. Does nothing for shorter inputs.
                Defaults to True.
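
        Example (an illustrative sketch; exact token ids depend on the tokenizer)::

            tokens = model.to_tokens("Hello, world!")
            # tokens has shape [1, pos]; a BOS token is prepended by default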
"""
with utils.LocallyOverridenDefaults(
self, prepend_bos=prepend_bos, padding_side=padding_side
):
assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
assert (
self.cfg.tokenizer_prepends_bos is not None
), "Set the tokenizer for the model by calling set_tokenizer"
if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos:
# We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually
input = utils.get_input_with_manually_prepended_bos(self.tokenizer, input)
tokens = self.tokenizer(
input,
return_tensors="pt",
padding=True,
truncation=truncate,
max_length=self.cfg.n_ctx if truncate else None,
)["input_ids"]
if not self.cfg.default_prepend_bos and self.cfg.tokenizer_prepends_bos:
# We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)
if move_to_device:
tokens = tokens.to(self.cfg.device)
return tokens
def to_string(
self,
tokens: Union[
List[int],
Int[torch.Tensor, ""],
Int[torch.Tensor, "batch pos"],
Int[torch.Tensor, "pos"],
np.ndarray,
List[Int[torch.Tensor, "pos"]],
],
) -> Union[str, List[str]]:
"""Tokens to String(s).
Converts a tensor of tokens to a string (if rank 1) or a list of strings (if rank 2).
Accepts lists of tokens and numpy arrays as inputs too (and converts to tensors internally)
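
        Example (an illustrative sketch; round-tripping is exact for most, though not all,
        tokenizers)::

            text = model.to_string(model.to_tokens("Hello, world!", prepend_bos=False)[0])
            # text == "Hello, world!" for a GPT-2 style tokenizer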
"""
assert self.tokenizer is not None, "Cannot use to_string without a tokenizer"
if not isinstance(tokens, torch.Tensor):
# We allow lists to be input
tokens = torch.tensor(tokens)
# I'm not sure what exactly clean_up_tokenization_spaces does, but if
# it's set, then tokenization is no longer invertible, and some tokens
# with a bunch of whitespace get collapsed together
if len(tokens.shape) == 2:
return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
elif len(tokens.shape) <= 1:
return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False)
else:
raise ValueError(f"Invalid shape passed in: {tokens.shape}")
def to_str_tokens(
self,
input: Union[
str,
Int[torch.Tensor, "pos"],
Int[torch.Tensor, "1 pos"],
Int[np.ndarray, "pos"],
Int[np.ndarray, "1 pos"],
list,
],
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
) -> Union[List[str], List[List[str]]]:
"""Map text, a list of text or tokens to a list of tokens as strings.
Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when
inputting a prompt to the model as the first token is often treated weirdly, but should only
be done at the START of the prompt. If prepend_bos=None is passed, it implies the usage of
self.cfg.default_prepend_bos which is set to True unless specified otherwise. Therefore,
make sure to locally turn it off by passing prepend_bos=False if you're looking at the
tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS
token, others (OPT and my models) were)
Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether
the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not
careful!
Gotcha3: If passing a string that exceeds the model's context length (model.cfg.n_ctx), it
will be truncated.
Args:
input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of
tokens. If tokens, should be a tensor of shape [pos] or [1, pos].
prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
the BOS token to the input (only applies when input is a string). Defaults to None,
implying usage of self.cfg.default_prepend_bos which is set to True unless specified
otherwise. Pass True or False to locally override the default.
padding_side (Union[Literal["left", "right"], None], optional): Overrides
self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
strings of different lengths.
Returns:
str_tokens: List of individual tokens as strings
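
        Example (an illustrative sketch; the exact split depends on the tokenizer)::

            model.to_str_tokens("Hello world", prepend_bos=False)
            # e.g. ['Hello', ' world'] for a GPT-2 style BPE tokenizer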
"""
with utils.LocallyOverridenDefaults(
self, prepend_bos=prepend_bos, padding_side=padding_side
):
assert self.tokenizer is not None # keep mypy happy
tokens: Union[np.ndarray, torch.Tensor]
if isinstance(input, list):
return list(
map(
lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side),
input,
)
) # type: ignore
elif isinstance(input, str):
tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[
0
]
# Gemma tokenizer expects a batch dimension
if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1:
tokens = tokens.unsqueeze(1)
elif isinstance(input, torch.Tensor):
tokens = input
tokens = tokens.squeeze() # Get rid of a trivial batch dimension
if tokens.dim() == 0:
# Don't pass dimensionless tensor
tokens = tokens.unsqueeze(0)
assert (
tokens.dim() == 1
), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
elif isinstance(input, np.ndarray):
tokens = input
tokens = tokens.squeeze() # Get rid of a trivial batch dimension
if tokens.ndim == 0:
# Don't pass dimensionless tensor
tokens = np.expand_dims(tokens, axis=0)
assert (
tokens.ndim == 1
), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
else:
raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
str_tokens = self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
return str_tokens
def to_single_token(self, string):
"""Map a string that makes up a single token to the id for that token.
Raises an error for strings that are not a single token! If uncertain use to_tokens.
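
        Example (an illustrative sketch)::

            token_id = model.to_single_token(" world")  # raises if " world" spans multiple tokens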
"""
# We use the to_tokens method, do not append a BOS token
token = self.to_tokens(string, prepend_bos=False).squeeze()
# If token shape is non-empty, raise error
assert not token.shape, f"Input string: {string} is not a single token!"
return token.item()
def to_single_str_token(self, int_token: int) -> str:
# Gives the single token corresponding to an int in string form
assert isinstance(int_token, int)
token = self.to_str_tokens(torch.tensor([int_token]))
assert len(token) == 1
return cast(str, token[0])
def get_token_position(
self,
single_token: Union[str, int],
input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]],
mode="first",
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
):
"""Get the position of a single_token in a string or sequence of tokens.
Raises an error if the token is not present.
Gotcha: If you're inputting a string, it'll automatically be tokenized. Be careful about the
setting for prepend_bos! When a string is input to the model, a BOS (beginning of sequence)
token is prepended by default when the string is tokenized because
self.cfg.default_prepend_bos is set to True unless specified otherwise. But this should only
be done at the START of the input, not when inputting part of the prompt. If you're getting
weird off-by-one errors, check carefully for what the setting should be!
Args:
single_token (Union[str, int]): The token to search for. Can
be a token index, or a string (but the string must correspond to a single token).
input (Union[str, torch.Tensor]): The sequence to
search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens
with a dummy batch dimension.
mode (str, optional): If there are multiple matches, which match to return. Supports
"first" or "last". Defaults to "first".
prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
the BOS token to the input (only applies when input is a string). Defaults to None,
implying usage of self.cfg.default_prepend_bos which is set to True unless specified
otherwise. Pass True or False to locally override the default.
padding_side (Union[Literal["left", "right"], None], optional): Overrides
self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
strings of different lengths.
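
        Example (an illustrative sketch; the exact position depends on the tokenizer and on
        prepend_bos)::

            pos = model.get_token_position(" world", "Hello world")
            # with prepend_bos=True this is typically 2: [BOS, 'Hello', ' world']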
"""
if isinstance(input, str):
# If the input is a string, convert to tensor
tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
else:
tokens = input
if len(tokens.shape) == 2:
# If the tokens have shape [1, seq_len], flatten to [seq_len]
assert (
tokens.shape[0] == 1
), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}"
tokens = tokens[0]
if isinstance(single_token, str):
# If the single token is a string, convert to an integer
single_token = self.to_single_token(single_token)
elif isinstance(single_token, torch.Tensor):
single_token = single_token.item()
indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token]
assert len(indices) > 0, "The token does not occur in the prompt"
if mode == "first":
return indices[0].item()
elif mode == "last":
return indices[-1].item()
else:
raise ValueError(f"mode must be 'first' or 'last', not {mode}")
def tokens_to_residual_directions(
self,
tokens: Union[
str,
int,
Int[torch.Tensor, ""],
Int[torch.Tensor, "pos"],
Int[torch.Tensor, "batch pos"],
],
) -> Union[
Float[torch.Tensor, "d_model"],
Float[torch.Tensor, "pos d_model"],
Float[torch.Tensor, "batch pos d_model"],
]:
"""Map tokens to a tensor with the unembedding vector for those tokens.
        I.e. the vector in the residual stream that we dot with to get the logit for that token.
WARNING: If you use this without folding in LayerNorm, the results will be misleading and
may be incorrect, as the LN weights change the unembed map. This is done automatically with
the fold_ln flag on from_pretrained
WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual
stream for each output token on any given input token position.
ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions.
Args:
tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single
element tensor, an integer, or string. If string, will be mapped to a single token
using to_single_token, and an error raised if it's multiple tokens. The method also
works for a batch of input tokens.
Returns:
residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of
                [d_model] tensors.
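
        Example (an illustrative sketch; assumes LayerNorm was folded in, e.g. via
        ``from_pretrained(..., fold_ln=True)``)::

            logit_dir = model.tokens_to_residual_directions(" world")  # shape [d_model]
            # `answer_tokens` stands in for any tensor of token ids, e.g. shape [batch, 2]
            answer_dirs = model.tokens_to_residual_directions(answer_tokens)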
"""
if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
# If the tokens are a tensor, and have more than one element, assume they are a batch of
# tokens.
residual_directions = self.W_U[:, tokens]
residual_directions = einops.rearrange(
residual_directions, "d_model ... -> ... d_model"
)
return residual_directions
else:
# Otherwise there is a single token
if isinstance(tokens, str):
token = self.to_single_token(tokens)
elif isinstance(tokens, int):
token = tokens
elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1:
token = tokens.item()