Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster flag complex #355

Merged
merged 12 commits into from
May 10, 2023
48 changes: 0 additions & 48 deletions tests/generators/test_nonuniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,51 +67,3 @@ def test_random_hypergraph():
H4 = xgi.random_hypergraph(10, [0.1], order=2, seed=1)
assert H4.num_nodes == 10
assert xgi.unique_edge_sizes(H4) == [3]


def test_random_simplicial_complex():
# seed
S1 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=1)
S2 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=2)
S3 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=2)

assert S1._edge != S2._edge
assert S2._edge == S3._edge

# wrong input
with pytest.raises(ValueError):
S1 = xgi.random_simplicial_complex(10, [1, 1.1])
with pytest.raises(ValueError):
S1 = xgi.random_simplicial_complex(10, [1, -2])


def test_random_flag_complex():
# seed
S1 = xgi.random_flag_complex(10, 0.1, seed=1)
S2 = xgi.random_flag_complex(10, 0.1, seed=2)
S3 = xgi.random_flag_complex(10, 0.1, seed=2)

assert S1._edge != S2._edge
assert S2._edge == S3._edge

# wrong input
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex(10, 1.1)
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex(10, -2)


def test_random_flag_complex_d2():
# seed
S1 = xgi.random_flag_complex_d2(10, 0.1, seed=1)
S2 = xgi.random_flag_complex_d2(10, 0.1, seed=2)
S3 = xgi.random_flag_complex_d2(10, 0.1, seed=2)

assert S1._edge != S2._edge
assert S2._edge == S3._edge

# wrong input
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex_d2(10, 1.1)
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex_d2(10, -2)
116 changes: 116 additions & 0 deletions tests/generators/test_simplicial_complexes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import networkx as nx
import pytest

import xgi
from xgi.exception import XGIError
Expand All @@ -21,6 +22,7 @@ def test_flag_complex():

assert S.edges.members() == simplices_3

# ps
S1 = xgi.flag_complex(G, ps=[1], seed=42)
S2 = xgi.flag_complex(G, ps=[0.5], seed=42)
S3 = xgi.flag_complex(G, ps=[0], seed=42)
Expand All @@ -29,6 +31,7 @@ def test_flag_complex():
assert S2.edges.members() == simplices_2
assert S3.edges.members() == simplices_2

# complete graph
G1 = nx.complete_graph(4)
S4 = xgi.flag_complex(G1)
S5 = xgi.flag_complex(G1, ps=[1])
Expand All @@ -44,3 +47,116 @@ def test_flag_complex_d2():
S2 = xgi.flag_complex_d2(G)

assert set(S.edges.members()) == set(S2.edges.members())


def test_random_simplicial_complex():
# seed
S1 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=1)
S2 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=2)
S3 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=2)

assert S1._edge != S2._edge
assert S2._edge == S3._edge

# wrong input
with pytest.raises(ValueError):
S1 = xgi.random_simplicial_complex(10, [1, 1.1])
with pytest.raises(ValueError):
S1 = xgi.random_simplicial_complex(10, [1, -2])


def test_random_flag_complex():

S = xgi.random_flag_complex(10, 0.4, seed=2)
simplices = {
frozenset({0, 4}),
frozenset({0, 7}),
frozenset({1, 8}),
frozenset({2, 5}),
frozenset({2, 9}),
frozenset({3, 5}),
frozenset({3, 6}),
frozenset({3, 7}),
frozenset({3, 8}),
frozenset({4, 5}),
frozenset({4, 7}),
frozenset({4, 8}),
frozenset({6, 7}),
frozenset({6, 8}),
frozenset({7, 8}),
frozenset({0, 4, 7}),
frozenset({3, 6, 7}),
frozenset({3, 6, 8}),
frozenset({3, 7, 8}),
frozenset({4, 7, 8}),
frozenset({6, 7, 8}),
}

assert set(S.edges.members()) == simplices

# max_order
S = xgi.random_flag_complex(10, 0.4, seed=2, max_order=3)
assert set(S.edges.members()) == simplices.union({frozenset({3, 6, 7, 8})})

# seed
S1 = xgi.random_flag_complex(10, 0.1, seed=1)
S2 = xgi.random_flag_complex(10, 0.1, seed=2)
S3 = xgi.random_flag_complex(10, 0.1, seed=2)

assert S1._edge != S2._edge
assert S2._edge == S3._edge

# wrong input
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex(10, 1.1)
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex(10, -2)


def test_random_flag_complex_d2():

S = xgi.random_flag_complex_d2(10, 0.4, seed=2)
simplices = {
frozenset({0, 4}),
frozenset({0, 7}),
frozenset({1, 8}),
frozenset({2, 5}),
frozenset({2, 9}),
frozenset({3, 5}),
frozenset({3, 6}),
frozenset({3, 7}),
frozenset({3, 8}),
frozenset({4, 5}),
frozenset({4, 7}),
frozenset({4, 8}),
frozenset({6, 7}),
frozenset({6, 8}),
frozenset({7, 8}),
frozenset({0, 4, 7}),
frozenset({3, 6, 7}),
frozenset({3, 6, 8}),
frozenset({3, 7, 8}),
frozenset({4, 7, 8}),
frozenset({6, 7, 8}),
}

assert set(S.edges.members()) == simplices

# consistency with other function
S = xgi.random_flag_complex(10, 0.4, seed=3, max_order=2)
S0 = xgi.random_flag_complex_d2(10, 0.4, seed=3)
assert set(S.edges.members()) == set(S0.edges.members())

# seed
S1 = xgi.random_flag_complex_d2(10, 0.1, seed=1)
S2 = xgi.random_flag_complex_d2(10, 0.1, seed=2)
S3 = xgi.random_flag_complex_d2(10, 0.1, seed=2)

assert S1._edge != S2._edge
assert S2._edge == S3._edge

# wrong input
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex_d2(10, 1.1)
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex_d2(10, -2)
1 change: 1 addition & 0 deletions xgi/generators/classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def complete_hypergraph(N, order=None, max_order=None, include_singletons=False)
elif max_order is not None:
start = 1 if include_singletons else 2
end = max_order + 1
assert end >= start # can be equal because adding +1 to end below

s = list(nodes)
edges = chain.from_iterable(combinations(s, r) for r in range(start, end + 1))
Expand Down
55 changes: 42 additions & 13 deletions xgi/generators/simplicial_complexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,26 +130,21 @@ def flag_complex(G, max_order=2, ps=None, seed=None):
random.seed(seed)

nodes = G.nodes()
N = len(nodes)
edges = G.edges()

# compute all maximal cliques to fill
max_cliques = list(nx.find_cliques(G))
cliques_to_add = _cliques_to_fill(G, max_order)

S = SimplicialComplex()
S.add_nodes_from(nodes)
S.add_simplices_from(edges)
if not ps: # promote all cliques
S.add_simplices_from(max_cliques, max_order=max_order)
S.add_simplices_from(cliques_to_add, max_order=max_order)
return S

if max_order: # compute subfaces of order max_order (allowed max cliques)
max_cliques_to_add = subfaces(max_cliques, order=max_order)
else:
max_cliques_to_add = max_cliques

# store max cliques per order
cliques_d = defaultdict(list)
for x in max_cliques_to_add:
for x in cliques_to_add:
cliques_d[len(x)].append(x)

# promote cliques with a given probability
Expand Down Expand Up @@ -277,13 +272,47 @@ def random_flag_complex(N, p, max_order=2, seed=None):
G = nx.fast_gnp_random_graph(N, p, seed=seed)

nodes = G.nodes()
edges = list(G.edges())

# compute all triangles to fill
max_cliques = list(nx.find_cliques(G))
cliques = _cliques_to_fill(G, max_order)

S = SimplicialComplex()
S.add_nodes_from(nodes)
S.add_simplices_from(max_cliques, max_order=max_order)
S.add_simplices_from(cliques, max_order=max_order)

return S


def _cliques_to_fill(G, max_order):
"""Return cliques to fill for flag complexes,
to be passed to `add_simplices_from`.

This function was written to speedup flag_complex functions
by avoiding adding redundant faces.

Parameters
----------
G : networkx Graph
Graph to consider
max_order: int or None
If None, return maximal cliques. If int, return all cliques
up to max_order.

Returns
-------
cliques : list
List of cliques

"""
if max_order is None:
cliques = list(nx.find_cliques(G)) # max cliques
else: # avoid adding many unnecessary redundant cliques
cliques = []
for clique in nx.enumerate_all_cliques(G): # sorted by size
if len(clique) == 1:
continue # don't add singletons
if len(clique) <= max_order + 1:
cliques.append(clique)
else:
break # dont go over whole list if not necessary

return cliques