Skip to content

Commit

Permalink
Merge with EleutherELK
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexTMallen committed Nov 10, 2023
1 parent c9a8e19 commit 611c93c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
11 changes: 2 additions & 9 deletions elk/training/sweep.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from dataclasses import InitVar, dataclass, replace, field
from dataclasses import InitVar, dataclass, field, replace

import numpy as np
import torch
from datasets import get_dataset_config_info
from transformers import AutoConfig

from ..evaluation import Eval
from ..extraction import Extract
from ..files import memorably_named_dir, sweeps_dir
from ..plotting.visualize import visualize_sweep
from ..training.eigen_reporter import EigenFitterConfig
Expand Down Expand Up @@ -53,13 +52,7 @@ class Sweep:
name: str | None = None

# A bit of a hack to add all the command line arguments from Elicit
run_template: Elicit = field(default_factory=lambda: Elicit(
data=Extract(
model="<placeholder>",
datasets=("<placeholder>",),
)
)
)
run_template: Elicit = field(default_factory=Elicit.default)

def __post_init__(self, add_pooled: bool):
if not self.datasets:
Expand Down
10 changes: 10 additions & 0 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from simple_parsing import subgroups
from simple_parsing.helpers.serialization import save

from ..extraction import Extract
from ..metrics import evaluate_preds, to_one_hot
from ..run import Run
from ..training.supervised import train_supervised
Expand All @@ -34,6 +35,15 @@ class Elicit(Run):
cross-validation. Defaults to "single", which means to train a single classifier
on the training data. "cv" means to use cross-validation."""

@staticmethod
def default():
return Elicit(
data=Extract(
model="<placeholder>",
datasets=("<placeholder>",),
)
)

def create_models_dir(self, out_dir: Path):
lr_dir = None
lr_dir = out_dir / "lr_models"
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ license = { text = "MIT License" }
dependencies = [
# Allows us to use device_map in from_pretrained. Also needed for 8bit
"accelerate",
# Already a dependency of datasets, but newer versions introduce breaking changes
"pyarrow==12.0.0",
# For pseudolabel and prompt normalization. We're picky about the version because
# the package isn't guaranteed to be stable yet.
"concept-erasure==0.1.0",
Expand Down

0 comments on commit 611c93c

Please sign in to comment.