From 8073374926457516047fe62f65490eb27d7cc565 Mon Sep 17 00:00:00 2001
From: JINO ROHIT
Date: Mon, 20 Jan 2025 18:10:59 +0530
Subject: [PATCH] testcases for community detection (#3163)

* testcases for community detection

* ruff lint

* Run 'pre-commit run --all'

* Rename test cases to clarify that they're comm. detect. related

---------

Co-authored-by: Tom Aarsen
---
 tests/test_util.py | 130 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 130 insertions(+)

diff --git a/tests/test_util.py b/tests/test_util.py
index 71d194e54..0cecce2bb 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 import numpy as np
+import pytest
 import sklearn
 import torch
 
 from sentence_transformers import SentenceTransformer, util
+from sentence_transformers.util import community_detection
 
 
 def test_normalize_embeddings() -> None:
@@ -145,3 +147,131 @@ def test_dot_score_cos_sim() -> None:
 
     assert np.allclose(cosine_calculated, dot_and_cosine_expected)
     assert np.allclose(dot_calculated, dot_and_cosine_expected)
+
+
+def test_community_detection_two_clear_communities() -> None:
+    """Test case with two clear communities."""
+    embeddings = torch.tensor(
+        [
+            [1.0, 0.0, 0.0],  # Point 0
+            [0.9, 0.1, 0.0],  # Point 1
+            [0.8, 0.2, 0.0],  # Point 2
+            [0.1, 0.9, 0.0],  # Point 3
+            [0.0, 1.0, 0.0],  # Point 4
+            [0.2, 0.8, 0.0],  # Point 5
+        ]
+    )
+    expected = [
+        [0, 1, 2],  # Community 1
+        [3, 4, 5],  # Community 2
+    ]
+    result = community_detection(embeddings, threshold=0.8, min_community_size=2)
+    assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected])
+
+
+def test_community_detection_no_communities_high_threshold() -> None:
+    """Test case where no communities are found due to a high threshold."""
+    embeddings = torch.tensor(
+        [
+            [1.0, 0.0, 0.0],
+            [0.0, 1.0, 0.0],
+            [0.0, 0.0, 1.0],
+        ]
+    )
+    expected = []
+    result = community_detection(embeddings, threshold=0.99, min_community_size=2)
+    assert result == expected
+
+
+def test_community_detection_all_points_in_one_community() -> None:
+    """Test case where all points form a single community due to a low threshold."""
+    embeddings = torch.tensor(
+        [
+            [1.0, 0.0, 0.0],
+            [0.9, 0.1, 0.0],
+            [0.8, 0.2, 0.0],
+        ]
+    )
+    expected = [
+        [0, 1, 2],  # Single community
+    ]
+    result = community_detection(embeddings, threshold=0.5, min_community_size=2)
+    assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected])
+
+
+def test_community_detection_min_community_size_filtering() -> None:
+    """Test case where communities are filtered based on minimum size."""
+    embeddings = torch.tensor(
+        [
+            [1.0, 0.0, 0.0],
+            [0.9, 0.1, 0.0],
+            [0.8, 0.2, 0.0],
+            [0.1, 0.9, 0.0],
+        ]
+    )
+    expected = [
+        [0, 1, 2],  # Only one community meets the min size requirement
+    ]
+    result = community_detection(embeddings, threshold=0.8, min_community_size=3)
+    assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected])
+
+
+def test_community_detection_overlapping_communities() -> None:
+    """Test case with overlapping communities (resolved by the function)."""
+    embeddings = torch.tensor(
+        [
+            [1.0, 0.0, 0.0],  # Point 0
+            [0.9, 0.1, 0.0],  # Point 1
+            [0.8, 0.2, 0.0],  # Point 2
+            [0.7, 0.3, 0.0],  # Point 3 (cosine similarity >= 0.8 to points 0-2 only, < 0.5 to points 4-5)
+            [0.1, 0.9, 0.0],  # Point 4
+            [0.0, 1.0, 0.0],  # Point 5
+        ]
+    )
+    expected = [
+        [0, 1, 2, 3],  # Community 1 (includes point 3)
+        [4, 5],  # Community 2
+    ]
+    result = community_detection(embeddings, threshold=0.8, min_community_size=2)
+    assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected])
+
+
+def test_community_detection_numpy_input() -> None:
+    """Test case where input is a numpy array instead of a torch tensor."""
+    embeddings = np.array(
+        [
+            [1.0, 0.0, 0.0],
+            [0.9, 0.1, 0.0],
+            [0.8, 0.2, 0.0],
+        ]
+    )
+    expected = [
+        [0, 1, 2],  # Single community
+    ]
+    result = community_detection(embeddings, threshold=0.8, min_community_size=2)
+    assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected])
+
+
+def test_community_detection_large_batch_size() -> None:
+    """Test that batched detection on seeded random embeddings only returns communities of the minimum size."""
+    embeddings = torch.rand(1000, 128, generator=torch.Generator().manual_seed(0))  # Seeded for reproducibility
+    result = community_detection(embeddings, threshold=0.8, min_community_size=10, batch_size=256)
+    # Check that all communities meet the minimum size requirement
+    assert all(len(community) >= 10 for community in result)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available")
+def test_community_detection_gpu_support() -> None:
+    """Test case for GPU support (if available)."""
+    embeddings = torch.tensor(
+        [
+            [1.0, 0.0, 0.0],
+            [0.9, 0.1, 0.0],
+            [0.8, 0.2, 0.0],
+        ]
+    ).cuda()
+    expected = [
+        [0, 1, 2],  # Single community
+    ]
+    result = community_detection(embeddings, threshold=0.8, min_community_size=2)
+    assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected])