diff --git a/ccs/__init__.py b/ccs/__init__.py
index ce69da0..f9007c4 100644
--- a/ccs/__init__.py
+++ b/ccs/__init__.py
@@ -1,5 +1,7 @@
 from .extraction import Extract, extract_hiddens
 from .training import EigenFitter, EigenFitterConfig
+from .training.train import Elicit
+from .evaluation import Eval
 from .truncated_eigh import truncated_eigh
 
 __all__ = [
@@ -7,5 +9,7 @@
     "EigenFitterConfig",
     "extract_hiddens",
     "Extract",
+    "Elicit",
+    "Eval",
     "truncated_eigh",
 ]
diff --git a/ccs/extraction/extraction.py b/ccs/extraction/extraction.py
index f6e0b4b..741d102 100644
--- a/ccs/extraction/extraction.py
+++ b/ccs/extraction/extraction.py
@@ -137,8 +137,9 @@ def __post_init__(self, layer_stride: int):
             config = assert_type(
                 PretrainedConfig, AutoConfig.from_pretrained(self.model)
             )
-            layer_range = range(1, config.num_hidden_layers, layer_stride)
-            self.layers = tuple(layer_range)
+            # Note that we always include 0 which is the embedding layer
+            layer_range = range(1, config.num_hidden_layers + 1, layer_stride)
+            self.layers = (0,) + tuple(layer_range)
 
     def explode(self) -> list["Extract"]:
         """Explode this config into a list of configs, one for each layer."""
@@ -198,7 +199,7 @@ def extract_hiddens(
         seed=cfg.seed,
     )
 
-    layer_indices = cfg.layers or tuple(range(1, model.config.num_hidden_layers))
+    layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers + 1))
 
     global_max_examples = cfg.max_examples[0 if split_type == "train" else 1]
 
@@ -263,14 +264,17 @@ def extract_hiddens(
                 if is_enc_dec:
                     answer = labels = assert_type(Tensor, encoding.labels)
                 else:
-                    encoding2 = tokenizer(
-                        choice["answer"],
-                        # Don't include [CLS] and [SEP] in the answer
-                        add_special_tokens=False,
-                        return_tensors="pt",
-                    ).to(device)
-
-                    answer = assert_type(Tensor, encoding2.input_ids)
+                    a_id = tokenizer.encode(" " + choice["answer"], add_special_tokens=False)
+
+                    # the Llama tokenizer splits off leading spaces
+                    if tokenizer.decode(a_id[0]).strip() == "":
+                        a_id_without_space = tokenizer.encode(
+                            choice["answer"], add_special_tokens=False
+                        )
+                        assert a_id_without_space == a_id[1:]
+                        a_id = a_id_without_space
+
+                    answer = torch.tensor([a_id], device=device)
                     labels = (
                         # -100 is the mask token
                         torch.cat([torch.full_like(ids, -100), answer], dim=-1)
@@ -293,13 +297,13 @@ def extract_hiddens(
 
                 # Compute the log probability of the answer tokens if available
                 if has_lm_preds:
-                    logprob = -assert_type(Tensor, outputs.loss)
+                    logprob = -assert_type(Tensor, outputs.loss).to(torch.float32)
                     # Convert logprob to logodds to be consistent with reporters
                     # Because we went through logprobs, logodds corresponding to
                     # probs near 1 will be somewhat imprecise
                     # log(p/(1-p)) = log(p) - log(1-p) = logp - log(1 - exp(logp))
                     lm_log_odds[i, j] = logprob - torch.log1p(-logprob.exp())
 
                 hiddens = (
                     outputs.get("decoder_hidden_states") or outputs["hidden_states"]
                 )
@@ -339,7 +343,7 @@ def extract_hiddens(
             **hidden_dict,
         )
         if has_lm_preds:
-            out_record["lm_log_odds"] = lm_log_odds.log_softmax(dim=-1)
+            out_record["lm_log_odds"] = lm_log_odds
 
         assert out_record["variant_ids"] == sorted(out_record["variant_ids"])
         num_yielded += 1
@@ -377,7 +381,7 @@ def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]:
     if num_dropped:
         print(f"Dropping {num_dropped} non-multiple choice templates")
 
-    layer_indices = cfg.layers or tuple(range(1, model_cfg.num_hidden_layers))
+    layer_indices = cfg.layers or tuple(range(model_cfg.num_hidden_layers + 1))
     layer_cols = {
         f"hidden_{layer}": Array3D(
             dtype="int16",