Skip to content

Commit

Permalink
fixturise some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bmillwood committed May 23, 2024
1 parent c699ced commit f45b365
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions tests/unit/test_svd_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,18 @@
MODEL = "solu-2l"
VECTOR_TYPES = ["OV", "w_in", "w_out"]
ATOL = 2e-4 # Absolute tolerance - how far does a float have to be before we consider it no longer equal?
model = HookedTransformer.from_pretrained(MODEL)
unfolded_model = HookedTransformer.from_pretrained(MODEL, fold_ln=False)
second_model = HookedTransformer.from_pretrained("solu-3l")

@pytest.fixture(scope="module")
def model():
return HookedTransformer.from_pretrained(MODEL)

@pytest.fixture(scope="module")
def unfolded_model():
return HookedTransformer.from_pretrained(MODEL, fold_ln=False)

@pytest.fixture(scope="module")
def second_model():
return HookedTransformer.from_pretrained("solu-3l")


expected_OV_match = torch.Tensor(
Expand All @@ -31,7 +40,7 @@
# Successes


def test_svd_interpreter():
def test_svd_interpreter(model):
svd_interpreter = SVDInterpreter(model)
ov = svd_interpreter.get_singular_vectors(
"OV", num_vectors=4, layer_index=0, head_index=0
Expand All @@ -54,7 +63,7 @@ def test_svd_interpreter():
assert torch.allclose(w_out.cpu(), expected_w_out_match, atol=ATOL)


def test_w_in_when_fold_ln_is_false():
def test_w_in_when_fold_ln_is_false(unfolded_model):
svd_interpreter = SVDInterpreter(unfolded_model)
w_in = svd_interpreter.get_singular_vectors(
"w_in", num_vectors=4, layer_index=0, head_index=0
Expand All @@ -63,7 +72,7 @@ def test_w_in_when_fold_ln_is_false():
assert torch.allclose(w_in.cpu(), expected_w_in_unfolded_match, atol=ATOL)


def test_svd_interpreter_returns_different_answers_for_different_layers():
def test_svd_interpreter_returns_different_answers_for_different_layers(model):
svd_interpreter = SVDInterpreter(model)
ov = svd_interpreter.get_singular_vectors(
"OV", layer_index=1, num_vectors=4, head_index=0
Expand All @@ -86,7 +95,7 @@ def test_svd_interpreter_returns_different_answers_for_different_layers():
assert not torch.allclose(w_out.cpu(), expected_w_out_match, atol=ATOL)


def test_svd_interpreter_returns_different_answers_for_different_models():
def test_svd_interpreter_returns_different_answers_for_different_models(second_model):
svd_interpreter = SVDInterpreter(second_model)
ov = svd_interpreter.get_singular_vectors(
"OV", layer_index=1, num_vectors=4, head_index=0
Expand All @@ -111,28 +120,28 @@ def test_svd_interpreter_returns_different_answers_for_different_models():
# Failures


def test_svd_interpreter_fails_on_invalid_vector_type():
def test_svd_interpreter_fails_on_invalid_vector_type(model):
svd_interpreter = SVDInterpreter(model)
with pytest.raises(BeartypeCallHintParamViolation) as e:
svd_interpreter.get_singular_vectors("test", layer_index=0, num_vectors=4, head_index=0)


def test_svd_interpreter_fails_on_not_passing_required_head_index():
def test_svd_interpreter_fails_on_not_passing_required_head_index(model):
svd_interpreter = SVDInterpreter(model)
with pytest.raises(AssertionError) as e:
svd_interpreter.get_singular_vectors("OV", layer_index=0, num_vectors=4)
assert str(e.value) == "Head index optional only for w_in and w_out, got OV"


def test_svd_interpreter_fails_on_invalid_layer_index():
def test_svd_interpreter_fails_on_invalid_layer_index(model):
svd_interpreter = SVDInterpreter(model)
for vector in VECTOR_TYPES:
with pytest.raises(AssertionError) as e:
svd_interpreter.get_singular_vectors(vector, layer_index=2, num_vectors=4, head_index=0)
assert str(e.value) == "Layer index must be between 0 and 1 but got 2"


def test_svd_interpreter_fails_on_invalid_head_index():
def test_svd_interpreter_fails_on_invalid_head_index(model):
# Only OV uses head index.
svd_interpreter = SVDInterpreter(model)
with pytest.raises(AssertionError) as e:
Expand Down

0 comments on commit f45b365

Please sign in to comment.