From 83a9de65a98ccf62ac01dec365fee2e26c5ac104 Mon Sep 17 00:00:00 2001
From: Benedikt Horn <120414378+BeneHTWG@users.noreply.github.com>
Date: Fri, 10 Jan 2025 15:38:46 +0100
Subject: [PATCH] precommit check fails without none type check in embeddings
 assignment (#220)

---
 src/htwgnlp/embeddings.py        |  8 ++++++++
 tests/htwgnlp/test_embeddings.py | 21 +++++++++++++++++----
 2 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/src/htwgnlp/embeddings.py b/src/htwgnlp/embeddings.py
index a00ae68..22e1350 100644
--- a/src/htwgnlp/embeddings.py
+++ b/src/htwgnlp/embeddings.py
@@ -35,6 +35,9 @@ def __init__(self) -> None:
     def embedding_values(self) -> np.ndarray:
         """Returns the embedding values.
 
+        Raises:
+            ValueError: if the embeddings have not been loaded yet
+
         Returns:
             np.ndarray: the embedding values as a numpy array of shape (n, d), where n is the vocabulary size and d is the number of dimensions
         """
@@ -77,6 +80,9 @@ def get_embeddings(self, word: str) -> np.ndarray | None:
         Args:
             word (str): the word to get the embedding vector for
 
+        Raises:
+            ValueError: if the embeddings have not been loaded yet
+
         Returns:
             np.ndarray | None: the embedding vector for the given word in the form of a numpy array of shape (d,), where d is the number of dimensions, or None if the word is not in the vocabulary
         """
@@ -125,6 +131,7 @@ def get_most_similar_words(
             metric (Literal["euclidean", "cosine"], optional): the metric to use for computing the similarity. Defaults to "euclidean".
 
         Raises:
+            ValueError: if the embeddings have not been loaded yet
             ValueError: if the metric is not "euclidean" or "cosine"
             AssertionError: if the word is not in the vocabulary
         Returns:
@@ -146,6 +153,7 @@ def find_closest_word(
             metric (Literal["euclidean", "cosine"], optional): the metric to use for computing the similarity. Defaults to "euclidean".
 
         Raises:
+            ValueError: if the embeddings have not been loaded yet
             ValueError: if the metric is not "euclidean" or "cosine"
 
         Returns:
diff --git a/tests/htwgnlp/test_embeddings.py b/tests/htwgnlp/test_embeddings.py
index 0a310c6..c61fce8 100644
--- a/tests/htwgnlp/test_embeddings.py
+++ b/tests/htwgnlp/test_embeddings.py
@@ -15,6 +15,11 @@ def embeddings():
     return WordEmbeddings()
 
 
+@pytest.fixture
+def non_loaded_embeddings():
+    return WordEmbeddings()
+
+
 @pytest.fixture
 def loaded_embeddings(embeddings):
     embeddings._load_raw_embeddings("notebooks/data/embeddings.pkl")
@@ -47,12 +52,16 @@ def test_load_embeddings_to_dataframe(loaded_embeddings):
     assert loaded_embeddings._embeddings_df.shape == (243, 300)
 
 
-def test_embedding_values(loaded_embeddings):
+def test_embedding_values(loaded_embeddings, non_loaded_embeddings):
+    with pytest.raises(ValueError):
+        non_loaded_embeddings.embedding_values
     assert isinstance(loaded_embeddings.embedding_values, np.ndarray)
     assert loaded_embeddings.embedding_values.shape == (243, 300)
 
 
-def test_get_embeddings(loaded_embeddings):
+def test_get_embeddings(loaded_embeddings, non_loaded_embeddings):
+    with pytest.raises(ValueError):
+        non_loaded_embeddings.embedding_values
     assert isinstance(loaded_embeddings.get_embeddings("happy"), np.ndarray)
     assert loaded_embeddings.get_embeddings("happy").shape == (300,)
     assert loaded_embeddings.get_embeddings("non_existent_word") is None
@@ -104,13 +113,17 @@ def test_cosine_similarity(loaded_embeddings, test_vector):
     )
 
 
-def test_find_closest_word(loaded_embeddings, test_vector):
+def test_find_closest_word(loaded_embeddings, test_vector, non_loaded_embeddings):
+    with pytest.raises(ValueError):
+        non_loaded_embeddings.embedding_values
     for metric in ["euclidean", "cosine"]:
         assert isinstance(loaded_embeddings.find_closest_word(test_vector, metric), str)
         assert loaded_embeddings.find_closest_word(test_vector, metric) == "Bahamas"
 
 
-def test_get_most_similar_words(loaded_embeddings):
+def test_get_most_similar_words(loaded_embeddings, non_loaded_embeddings):
+    with pytest.raises(ValueError):
+        non_loaded_embeddings.embedding_values
     assert isinstance(loaded_embeddings.get_most_similar_words("Germany"), list)
     assert loaded_embeddings.get_most_similar_words("Germany") == [
         "Austria",