From f45b365dcde1057a308d6c884f8f8715f7b5204e Mon Sep 17 00:00:00 2001 From: Ben Millwood Date: Thu, 23 May 2024 12:43:38 +0100 Subject: [PATCH] fixturise some tests --- tests/unit/test_svd_interpreter.py | 31 +++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/tests/unit/test_svd_interpreter.py b/tests/unit/test_svd_interpreter.py index face0643b..c2c57a6dc 100644 --- a/tests/unit/test_svd_interpreter.py +++ b/tests/unit/test_svd_interpreter.py @@ -7,9 +7,18 @@ MODEL = "solu-2l" VECTOR_TYPES = ["OV", "w_in", "w_out"] ATOL = 2e-4 # Absolute tolerance - how far does a float have to be before we consider it no longer equal? -model = HookedTransformer.from_pretrained(MODEL) -unfolded_model = HookedTransformer.from_pretrained(MODEL, fold_ln=False) -second_model = HookedTransformer.from_pretrained("solu-3l") + +@pytest.fixture(scope="module") +def model(): + return HookedTransformer.from_pretrained(MODEL) + +@pytest.fixture(scope="module") +def unfolded_model(): + return HookedTransformer.from_pretrained(MODEL, fold_ln=False) + +@pytest.fixture(scope="module") +def second_model(): + return HookedTransformer.from_pretrained("solu-3l") expected_OV_match = torch.Tensor( @@ -31,7 +40,7 @@ # Successes -def test_svd_interpreter(): +def test_svd_interpreter(model): svd_interpreter = SVDInterpreter(model) ov = svd_interpreter.get_singular_vectors( "OV", num_vectors=4, layer_index=0, head_index=0 @@ -54,7 +63,7 @@ def test_svd_interpreter(): assert torch.allclose(w_out.cpu(), expected_w_out_match, atol=ATOL) -def test_w_in_when_fold_ln_is_false(): +def test_w_in_when_fold_ln_is_false(unfolded_model): svd_interpreter = SVDInterpreter(unfolded_model) w_in = svd_interpreter.get_singular_vectors( "w_in", num_vectors=4, layer_index=0, head_index=0 @@ -63,7 +72,7 @@ def test_w_in_when_fold_ln_is_false(): assert torch.allclose(w_in.cpu(), expected_w_in_unfolded_match, atol=ATOL) -def test_svd_interpreter_returns_different_answers_for_different_layers(): +def test_svd_interpreter_returns_different_answers_for_different_layers(model): svd_interpreter = SVDInterpreter(model) ov = svd_interpreter.get_singular_vectors( "OV", layer_index=1, num_vectors=4, head_index=0 @@ -86,7 +95,7 @@ def test_svd_interpreter_returns_different_answers_for_different_layers(): assert not torch.allclose(w_out.cpu(), expected_w_out_match, atol=ATOL) -def test_svd_interpreter_returns_different_answers_for_different_models(): +def test_svd_interpreter_returns_different_answers_for_different_models(second_model): svd_interpreter = SVDInterpreter(second_model) ov = svd_interpreter.get_singular_vectors( "OV", layer_index=1, num_vectors=4, head_index=0 @@ -111,20 +120,20 @@ def test_svd_interpreter_returns_different_answers_for_different_models(): # Failures -def test_svd_interpreter_fails_on_invalid_vector_type(): +def test_svd_interpreter_fails_on_invalid_vector_type(model): svd_interpreter = SVDInterpreter(model) with pytest.raises(BeartypeCallHintParamViolation) as e: svd_interpreter.get_singular_vectors("test", layer_index=0, num_vectors=4, head_index=0) -def test_svd_interpreter_fails_on_not_passing_required_head_index(): +def test_svd_interpreter_fails_on_not_passing_required_head_index(model): svd_interpreter = SVDInterpreter(model) with pytest.raises(AssertionError) as e: svd_interpreter.get_singular_vectors("OV", layer_index=0, num_vectors=4) assert str(e.value) == "Head index optional only for w_in and w_out, got OV" -def test_svd_interpreter_fails_on_invalid_layer_index(): +def test_svd_interpreter_fails_on_invalid_layer_index(model): svd_interpreter = SVDInterpreter(model) for vector in VECTOR_TYPES: with pytest.raises(AssertionError) as e: @@ -132,7 +141,7 @@ def test_svd_interpreter_fails_on_invalid_layer_index(): assert str(e.value) == "Layer index must be between 0 and 1 but got 2" -def test_svd_interpreter_fails_on_invalid_head_index(): +def test_svd_interpreter_fails_on_invalid_head_index(model): # Only OV uses head index. svd_interpreter = SVDInterpreter(model) with pytest.raises(AssertionError) as e: