Skip to content

Commit

Permalink
testcases for community detection (#3163)
Browse files Browse the repository at this point in the history
* testcases for community detection

* ruff lint

* Run 'pre-commit run --all'

* Rename test cases to clarify that they are related to community detection

---------

Co-authored-by: Tom Aarsen <[email protected]>
  • Loading branch information
JINO-ROHIT and tomaarsen authored Jan 20, 2025
1 parent c68bf68 commit 8073374
Showing 1 changed file with 130 additions and 0 deletions.
130 changes: 130 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import numpy as np
import pytest
import sklearn
import torch

from sentence_transformers import SentenceTransformer, util
from sentence_transformers.util import community_detection


def test_normalize_embeddings() -> None:
Expand Down Expand Up @@ -145,3 +147,131 @@ def test_dot_score_cos_sim() -> None:

assert np.allclose(cosine_calculated, dot_and_cosine_expected)
assert np.allclose(dot_calculated, dot_and_cosine_expected)


def test_community_detection_two_clear_communities():
    """Two well-separated groups of vectors should be reported as two communities."""
    # Rows 0-2 lean towards the first axis, rows 3-5 towards the second axis.
    first_axis_group = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.8, 0.2, 0.0]]
    second_axis_group = [[0.1, 0.9, 0.0], [0.0, 1.0, 0.0], [0.2, 0.8, 0.0]]
    embeddings = torch.tensor(first_axis_group + second_axis_group)

    detected = community_detection(embeddings, threshold=0.8, min_community_size=2)

    # Compare order-insensitively: neither the order of the members inside a
    # community nor the order of the communities themselves is checked here.
    normalized = sorted(sorted(members) for members in detected)
    assert normalized == [[0, 1, 2], [3, 4, 5]]


def test_community_detection_no_communities_high_threshold():
    """Three mutually orthogonal vectors with a 0.99 threshold should give no communities."""
    # The 3x3 identity matrix holds exactly the three standard basis vectors.
    embeddings = torch.eye(3)

    detected = community_detection(embeddings, threshold=0.99, min_community_size=2)

    assert detected == []


def test_community_detection_all_points_in_one_community():
    """With a permissive threshold of 0.5, three nearby vectors should form one community."""
    rows = [
        [1.0, 0.0, 0.0],
        [0.9, 0.1, 0.0],
        [0.8, 0.2, 0.0],
    ]
    embeddings = torch.tensor(rows)

    detected = community_detection(embeddings, threshold=0.5, min_community_size=2)

    # Order-insensitive comparison against the single expected community.
    normalized = sorted(sorted(members) for members in detected)
    assert normalized == [[0, 1, 2]]


def test_community_detection_min_community_size_filtering():
    """Only the group that reaches ``min_community_size=3`` should be returned."""
    close_trio = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.8, 0.2, 0.0]]
    # A single vector pointing elsewhere; on its own it is too small to be kept.
    loner = [[0.1, 0.9, 0.0]]
    embeddings = torch.tensor(close_trio + loner)

    detected = community_detection(embeddings, threshold=0.8, min_community_size=3)

    # Order-insensitive comparison: only indices 0-2 are expected to survive.
    normalized = sorted(sorted(members) for members in detected)
    assert normalized == [[0, 1, 2]]


def test_community_detection_overlapping_communities():
    """Each point should end up in exactly one community, even a borderline one.

    Point 3 drifts away from the first axis but is expected to be grouped
    with points 0-2; points 4 and 5 are expected to form the second community.
    """
    rows = [
        [1.0, 0.0, 0.0],  # index 0
        [0.9, 0.1, 0.0],  # index 1
        [0.8, 0.2, 0.0],  # index 2
        [0.7, 0.3, 0.0],  # index 3 (borderline point)
        [0.1, 0.9, 0.0],  # index 4
        [0.0, 1.0, 0.0],  # index 5
    ]
    embeddings = torch.tensor(rows)

    detected = community_detection(embeddings, threshold=0.8, min_community_size=2)

    # Order-insensitive comparison of both members and communities.
    normalized = sorted(sorted(members) for members in detected)
    assert normalized == [[0, 1, 2, 3], [4, 5]]


def test_community_detection_numpy_input():
    """A NumPy array (rather than a torch tensor) should be accepted as input."""
    rows = [
        [1.0, 0.0, 0.0],
        [0.9, 0.1, 0.0],
        [0.8, 0.2, 0.0],
    ]
    embeddings = np.array(rows)

    detected = community_detection(embeddings, threshold=0.8, min_community_size=2)

    # Order-insensitive comparison against the single expected community.
    normalized = sorted(sorted(members) for members in detected)
    assert normalized == [[0, 1, 2]]


def test_community_detection_large_batch_size() -> None:
    """Run community detection on 1000 random embeddings with ``batch_size=256``.

    The number of embeddings exceeds the batch size, so the input does not fit
    into a single batch. The exact communities are not pinned; instead the
    test checks structural properties that any valid result must have.
    """
    num_embeddings = 1000
    min_community_size = 10

    # Use a seeded, local generator so a failure is reproducible and the
    # global RNG state seen by other tests is left untouched.
    generator = torch.Generator().manual_seed(42)
    embeddings = torch.rand(num_embeddings, 128, generator=generator)

    result = community_detection(
        embeddings,
        threshold=0.8,
        min_community_size=min_community_size,
        batch_size=256,
    )

    # Check that all communities meet the minimum size requirement.
    assert all(len(community) >= min_community_size for community in result)

    # Every member must be a valid row index of the input ...
    members = [idx for community in result for idx in community]
    assert all(0 <= idx < num_embeddings for idx in members)
    # ... and must belong to at most one community (overlaps are resolved).
    assert len(members) == len(set(members))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available")
def test_community_detection_gpu_support():
    """Embeddings that live on a CUDA device should be handled (skipped without a GPU)."""
    rows = [
        [1.0, 0.0, 0.0],
        [0.9, 0.1, 0.0],
        [0.8, 0.2, 0.0],
    ]
    # Build on the CPU first, then move the tensor to the GPU.
    cpu_embeddings = torch.tensor(rows)
    embeddings = cpu_embeddings.cuda()

    detected = community_detection(embeddings, threshold=0.8, min_community_size=2)

    # Order-insensitive comparison against the single expected community.
    normalized = sorted(sorted(members) for members in detected)
    assert normalized == [[0, 1, 2]]

0 comments on commit 8073374

Please sign in to comment.