Feature/spectral clustering #665

Open · wants to merge 12 commits into base: main
117 changes: 117 additions & 0 deletions tests/communities/test_spectral.py
@@ -0,0 +1,117 @@
import numpy as np

import pytest

import xgi
from xgi.exception import XGIError


class TestKMeans:
def test_k_is_1(self):
X = np.random.random((3, 3))
clusters = xgi.communities.spectral._kmeans(X, 1)

assert len(clusters) == 3
        assert all(v == 1 for v in clusters.values())
        assert all(isinstance(v, int) for v in clusters.values())

def test_perfectly_separable_low_dimensions(self):
X = np.zeros((10, 10))
X[:5, :] = np.random.random((5, 10))
X[5:10, :] = 37 + np.random.random((5, 10))

clusters = xgi.communities.spectral._kmeans(X, 2, seed=2)
assert len(clusters) == 10

c1 = list(filter(lambda node: clusters[node] == 0, clusters.keys()))
c2 = list(filter(lambda node: clusters[node] == 1, clusters.keys()))
assert len(c1) == 5
assert len(c2) == 5
assert (set(c1) == {0, 1, 2, 3, 4} and set(c2) == {5, 6, 7, 8, 9}) or (
set(c2) == {0, 1, 2, 3, 4} and set(c1) == {5, 6, 7, 8, 9}
)

def test_perfectly_separable_high_dimensions(self):
X = np.zeros((10, 100))
X[:5, :] = np.random.random((5, 100))
X[5:10, :] = 37 + np.random.random((5, 100))

clusters = xgi.communities.spectral._kmeans(X, 2, seed=2)
assert len(clusters) == 10

c1 = list(filter(lambda node: clusters[node] == 0, clusters.keys()))
c2 = list(filter(lambda node: clusters[node] == 1, clusters.keys()))
assert len(c1) == 5
assert len(c2) == 5
assert (set(c1) == {0, 1, 2, 3, 4} and set(c2) == {5, 6, 7, 8, 9}) or (
set(c2) == {0, 1, 2, 3, 4} and set(c1) == {5, 6, 7, 8, 9}
)


class TestSpectralClustering:
def test_errors_num_clusters(self):
H = xgi.complete_hypergraph(5, order=2)

with pytest.raises(XGIError):
xgi.spectral_clustering(H, 6)

def test_perfectly_separable_low_dimensions(self):
H = xgi.Hypergraph(
[
[1, 2],
[2, 3],
[3, 4],
[4, 5],
[1, 3],
[2, 4],
[1, 5],
[6, 7],
[7, 8],
[8, 9],
[9, 10],
[6, 8],
[7, 9],
[6, 10],
]
)

clusters = xgi.communities.spectral.spectral_clustering(H, 2)
assert len(clusters) == 10

c1 = list(filter(lambda node: clusters[node] == 0, clusters.keys()))
c2 = list(filter(lambda node: clusters[node] == 1, clusters.keys()))
assert len(c1) == 5
assert len(c2) == 5
assert (set(c1) == {1, 2, 3, 4, 5} and set(c2) == {6, 7, 8, 9, 10}) or (
set(c2) == {1, 2, 3, 4, 5} and set(c1) == {6, 7, 8, 9, 10}
)

def test_strongly_separable_low_dimensions(self):
H = xgi.Hypergraph(
[
[1, 2, 3],
[4, 5],
[1, 3],
[2, 4],
[1, 5],
[4, 9],
[6, 7, 8],
[7, 8],
[8, 9],
[9, 10],
[6, 8],
[7, 9],
[6, 10],
]
)

clusters = xgi.communities.spectral.spectral_clustering(H, 2)
assert len(clusters) == 10

# Some nodes obviously in same cluster
assert clusters[1] == clusters[2]
assert clusters[2] == clusters[3]

# Some nodes obviously not
assert clusters[1] != clusters[8]
assert clusters[2] != clusters[7]
2 changes: 2 additions & 0 deletions xgi/__init__.py
@@ -2,6 +2,7 @@
utils,
core,
algorithms,
communities,
convert,
drawing,
dynamics,
@@ -13,6 +14,7 @@
from .utils import *
from .core import *
from .algorithms import *
from .communities import *
from .convert import *
from .drawing import *
from .dynamics import *
2 changes: 2 additions & 0 deletions xgi/communities/__init__.py
@@ -0,0 +1,2 @@
from . import spectral
from .spectral import *
115 changes: 115 additions & 0 deletions xgi/communities/spectral.py
@@ -0,0 +1,115 @@
import numpy as np
from scipy.sparse.linalg import eigsh

from ..core import Hypergraph
from ..linalg.laplacian_matrix import normalized_hypergraph_laplacian

from ..exception import XGIError

__all__ = [
"spectral_clustering",
]

MAX_ITERATIONS = 10_000
Collaborator comment: Unless this parameter is planned to be used in other functions, I'd define it in _kmeans(), the only function using it? Maybe even as a parameter of the function?
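A minimal sketch of the suggested alternative (hypothetical signature, not code from this PR), moving the constant into _kmeans as a keyword argument:

def _kmeans(X, k, seed=37, max_iterations=10_000):
    """Same body as shown further down, with MAX_ITERATIONS replaced by max_iterations."""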



def spectral_clustering(H, k=None):
"""Cluster into k-many groups using spectral techniques.

Compute a spectral clustering according to the heuristic suggested in [1].

Parameters
----------
H : Hypergraph
Hypergraph
k : int, optional
Number of clusters to find. If unspecified, computes spectral gap.
Collaborator comment (@maximelucas, Feb 17, 2025): How would this work? The spectral gap is not an integer in general.

Contributor Author reply: Ah, sorry, I meant the location of the spectral gap. I will expand on this later this week with updates to the PR.
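For context, one common reading of "the location of the spectral gap" is to choose k at the largest jump between consecutive Laplacian eigenvalues. A hypothetical sketch of that heuristic (the helper name and the k_max bound are assumptions, not part of this PR):

import numpy as np
from scipy.sparse.linalg import eigsh


def _choose_k_by_eigengap(L, k_max=10):
    # Smallest k_max eigenvalues of the sparse Laplacian L, in ascending order
    evals = np.sort(eigsh(L, k=k_max, which="SA", return_eigenvectors=False))
    # Choose k at the largest gap between consecutive eigenvalues
    gaps = np.diff(evals)
    return int(np.argmax(gaps)) + 1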


Returns
-------
dict
A dictionary mapping node ids to their clusters. Clusters begin at 0.

Raises
------
XGIError
If more groups are specified than nodes in the hypergraph.


References
----------
.. [1] Zhou, D., Huang, J., & Schölkopf, B. (2006).
Learning with Hypergraphs: Clustering, Classification, and Embedding
Advances in Neural Information Processing Systems.

"""
if k is None:
raise NotImplementedError(
"Choosing a number of clusters organically is currently unsupported. Please specify an integer value for paramater 'k'!"
)
Collaborator comment (on lines +46 to +49): I'd say either we implement this k=None default option, or we remove the default value and these lines (until it's maybe implemented).

else:
if k > H.num_nodes:
raise XGIError(
"The number of desired clusters cannot exceed the number of nodes!"
)

        # Compute normalized Laplacian and its spectrum
L, rowdict = normalized_hypergraph_laplacian(H, index=True)
evals, eigs = eigsh(L, k=k, which="SA")

# Form metric space representation
X = np.array(eigs)

# Apply k-means clustering
_clusters = _kmeans(X, k)

# Remap to node ids
clusters = {rowdict[id]: cluster for id, cluster in _clusters.items()}

return clusters


def _kmeans(X, k, seed=37):
Collaborator comment: Add docstring. Also, is there a reason to specify the seed by default instead of seed=None like in other functions?
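A possible shape for this suggestion (hypothetical docstring wording, with seed=None as in other xgi functions; not part of this diff):

def _kmeans(X, k, seed=None):
    """Cluster the rows of X into k groups with a basic k-means loop.

    Parameters
    ----------
    X : numpy.ndarray of shape (n_nodes, n_features)
        Spectral embedding of the nodes, one row per node.
    k : int
        Number of clusters.
    seed : int, optional
        Seed for the random number generator; by default None.

    Returns
    -------
    dict
        Maps row indices of X to integer cluster labels.
    """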

rng = np.random.default_rng(seed=seed)

# Handle edge cases
if k == 1:
return {node_idx: 1 for node_idx in range(X.shape[0])}

# Initialize stopping criterion
num_cluster_changes = np.inf
num_iterations = 0

# Instantiate random centers
bounds_inf = X.min(axis=0)
bounds_sup = X.max(axis=0)
width = bounds_sup - bounds_inf

    # Sample initial centroids uniformly inside the data's bounding box
    centroids = bounds_inf + width * rng.random((k, X.shape[1]))

# Instantiate random clusters
previous_clusters = {node: rng.integers(0, k) for node in range(X.shape[0])}

# Iterate main kmeans computation
while (num_cluster_changes > 0) and (num_iterations < MAX_ITERATIONS):
# Find nearest centroid to each point
next_clusters = dict()
for node, vector in enumerate(X):
distances = list(
map(lambda centroid: np.linalg.norm(vector - centroid), centroids)
)
closest_centroid = np.argmin(distances)
next_clusters[node] = closest_centroid

        # Recompute each centroid as the mean of the points currently assigned to it
        for cluster_idx in range(k):
            members = [node for node, label in next_clusters.items() if label == cluster_idx]
            if members:
                centroids[cluster_idx] = X[members].mean(axis=0)

        # Update convergence condition
        num_cluster_changes = sum(
            next_clusters[node] != previous_clusters[node] for node in range(X.shape[0])
        )
        previous_clusters = next_clusters
        num_iterations += 1

return next_clusters
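To close, a minimal usage sketch of the API this PR exposes as xgi.spectral_clustering (the hypergraph and the printed output are illustrative only):

import xgi

# Two loosely connected groups of nodes, similar to the test cases above
H = xgi.Hypergraph(
    [[1, 2], [2, 3], [1, 3], [3, 4], [4, 5], [5, 6], [4, 6]]
)

# Ask for two clusters; the result maps node ids to cluster labels
clusters = xgi.spectral_clustering(H, 2)
print(clusters)  # e.g. {1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1} (label order may be swapped)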