Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Nov 16, 2023
1 parent 0872ca2 commit cd4ae55
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 83 deletions.
51 changes: 3 additions & 48 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from funlib.persistence.graphs import (
FileGraphProvider,
MongoDbGraphProvider,
SQLiteGraphProvider,
SQLiteGraphDataBase
)

import pytest
Expand All @@ -10,27 +8,8 @@
from pathlib import Path


def mongo_db_available():
client = pymongo.MongoClient(serverSelectionTimeoutMS=1000)
try:
client.admin.command("ping")
return True
except pymongo.errors.ConnectionFailure:
return False


@pytest.fixture(
params=(
pytest.param(
"files",
marks=pytest.mark.xfail(reason="FileProvider not fully implemented!"),
),
pytest.param(
"mongo",
marks=pytest.mark.skipif(
not mongo_db_available(), reason="MongoDB not available!"
),
),
pytest.param("sqlite"),
)
)
Expand All @@ -42,30 +21,10 @@ def provider_factory(request, tmpdir):

tmpdir = Path(tmpdir)

def mongo_provider_factory(
mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None
):
return MongoDbGraphProvider(
"test_mongo_graph", mode=mode, directed=directed, total_roi=total_roi
)

def file_provider_factory(
mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None
):
return FileGraphProvider(
tmpdir / "test_file_graph.db",
chunk_size=(10, 10, 10),
mode=mode,
directed=directed,
total_roi=total_roi,
# node_attrs=node_attrs,
# edge_attrs=edge_attrs,
)

def sqlite_provider_factory(
mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None
):
return SQLiteGraphProvider(
return SQLiteGraphDataBase(
tmpdir / "test_sqlite_graph.db",
mode=mode,
directed=directed,
Expand All @@ -74,11 +33,7 @@ def sqlite_provider_factory(
edge_attrs=edge_attrs,
)

if request.param == "mongo":
yield mongo_provider_factory
elif request.param == "sqlite":
if request.param == "sqlite":
yield sqlite_provider_factory
elif request.param == "files":
yield file_provider_factory
else:
raise ValueError()
66 changes: 31 additions & 35 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def test_graph_filtering(provider_factory):
graph.add_edge(57, 23, selected=True)
graph.add_edge(2, 42, selected=True)

graph.write_nodes()
graph.write_edges()
graph_writer.write_nodes(graph.nodes())
graph_writer.write_edges(graph.nodes(), graph.edges())

graph_reader = provider_factory("r")

Expand All @@ -35,7 +35,7 @@ def test_graph_filtering(provider_factory):
for u, v in expected_edge_endpoints:
assert (u, v) in filtered_edge_endpoints or (v, u) in filtered_edge_endpoints

filtered_subgraph = graph_reader.get_graph(
filtered_subgraph = graph_reader.read_graph(
roi, nodes_filter={"selected": True}, edges_filter={"selected": True}
)
nodes_with_position = [
Expand Down Expand Up @@ -66,8 +66,8 @@ def test_graph_filtering_complex(provider_factory):
graph.add_edge(57, 23, selected=True, a=100, b=2)
graph.add_edge(2, 42, selected=True, a=101, b=3)

graph.write_nodes()
graph.write_edges()
graph_provider.write_nodes(graph.nodes())
graph_provider.write_edges(graph.nodes(), graph.edges())

graph_provider = provider_factory("r")

Expand All @@ -86,7 +86,7 @@ def test_graph_filtering_complex(provider_factory):
for u, v in expected_edge_endpoints:
assert (u, v) in filtered_edge_endpoints or (v, u) in filtered_edge_endpoints

filtered_subgraph = graph_provider.get_graph(
filtered_subgraph = graph_provider.read_graph(
roi,
nodes_filter={"selected": True, "test": "test"},
edges_filter={"selected": True, "a": 100},
Expand All @@ -100,7 +100,7 @@ def test_graph_filtering_complex(provider_factory):

def test_graph_read_and_update_specific_attrs(provider_factory):
graph_provider = provider_factory(
"w", node_attrs=["selected", "test"], edge_attrs=["selected", "a", "b"]
"w", node_attrs=["selected", "test"], edge_attrs=["selected", "a", "b", "c"]
)
roi = Roi((0, 0, 0), (10, 10, 10))
graph = graph_provider[roi]
Expand All @@ -114,11 +114,10 @@ def test_graph_read_and_update_specific_attrs(provider_factory):
graph.add_edge(57, 23, selected=True, a=100, b=2)
graph.add_edge(2, 42, selected=True, a=101, b=3)

graph.write_nodes()
graph.write_edges()
graph_provider.write_graph(graph)

graph_provider = provider_factory("r+")
limited_graph = graph_provider.get_graph(
limited_graph = graph_provider.read_graph(
roi, node_attrs=["selected"], edge_attrs=["c"]
)

Expand All @@ -133,12 +132,11 @@ def test_graph_read_and_update_specific_attrs(provider_factory):
nx.set_edge_attributes(limited_graph, 5, "c")

try:
limited_graph.update_edge_attrs(attributes=["c"])
limited_graph.update_node_attrs(attributes=["selected"])
graph_provider.write_attrs(limited_graph, edge_attrs=["c"], node_attrs=["selected"])
except NotImplementedError:
pytest.xfail()

updated_graph = graph_provider.get_graph(roi)
updated_graph = graph_provider.read_graph(roi)

for node, data in updated_graph.nodes(data=True):
assert data["selected"]
Expand All @@ -165,11 +163,11 @@ def test_graph_read_unbounded_roi(provider_factory):
graph.add_edge(57, 23, selected=True, a=100, b=2)
graph.add_edge(2, 42, selected=True, a=101, b=3)

graph.write_nodes()
graph.write_edges()
graph_provider.write_nodes(graph.nodes(), )
graph_provider.write_edges(graph.nodes(), graph.edges(), )

graph_provider = provider_factory("r+")
limited_graph = graph_provider.get_graph(
limited_graph = graph_provider.read_graph(
unbounded_roi, node_attrs=["selected"], edge_attrs=["c"]
)

Expand Down Expand Up @@ -228,8 +226,8 @@ def test_graph_io(provider_factory):
graph.add_edge(57, 23)
graph.add_edge(2, 42)

graph.write_nodes()
graph.write_edges()
graph_provider.write_nodes(graph.nodes(), )
graph_provider.write_edges(graph.nodes(), graph.edges(), )

graph_provider = provider_factory("r")
compare_graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))]
Expand All @@ -240,7 +238,7 @@ def test_graph_io(provider_factory):

edges = sorted(tuple(sorted(e)) for e in list(graph.edges()))
edges.remove((2, 42)) # node 2 has no position and will not be queried
compare_edges = sorted(list(compare_graph.edges()))
compare_edges = sorted(tuple(sorted(e)) for e in list(compare_graph.edges()))

assert nodes == compare_nodes
assert edges == compare_edges
Expand All @@ -258,12 +256,11 @@ def test_graph_fail_if_exists(provider_factory):
graph.add_edge(57, 23)
graph.add_edge(2, 42)

graph.write_nodes()
graph.write_edges()
graph_provider.write_graph(graph)
with pytest.raises(Exception):
graph.write_nodes(fail_if_exists=True)
graph_provider.write_nodes(graph.nodes(), fail_if_exists=True)
with pytest.raises(Exception):
graph.write_edges(fail_if_exists=True)
graph_provider.write_edges(graph.nodes(), graph.edges(), fail_if_exists=True)


def test_graph_fail_if_not_exists(provider_factory):
Expand All @@ -279,9 +276,9 @@ def test_graph_fail_if_not_exists(provider_factory):
graph.add_edge(2, 42)

with pytest.raises(Exception):
graph.write_nodes(fail_if_not_exists=True)
graph_provider.write_nodes(graph.nodes(), fail_if_not_exists=True)
with pytest.raises(Exception):
graph.write_edges(fail_if_not_exists=True)
graph_provider.write_edges(graph.nodes(), graph.edges(), fail_if_not_exists=True)


def test_graph_write_attributes(provider_factory):
Expand All @@ -297,10 +294,10 @@ def test_graph_write_attributes(provider_factory):
graph.add_edge(2, 42)

try:
graph.write_nodes(attributes=["position", "swip"])
graph_provider.write_nodes(graph.nodes(), attributes=["position", "swip"])
except NotImplementedError:
pytest.xfail()
graph.write_edges()
graph_provider.write_edges(graph.nodes(), graph.edges(), )

graph_provider = provider_factory("r")
compare_graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))]
Expand Down Expand Up @@ -334,8 +331,7 @@ def test_graph_write_roi(provider_factory):
graph.add_edge(2, 42)

write_roi = Roi((0, 0, 0), (6, 6, 6))
graph.write_nodes(roi=write_roi)
graph.write_edges(roi=write_roi)
graph_provider.write_graph(graph, write_roi)

graph_provider = provider_factory("r")
compare_graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))]
Expand All @@ -348,7 +344,7 @@ def test_graph_write_roi(provider_factory):
compare_nodes = sorted(list(compare_nodes))
edges = sorted(tuple(sorted(e)) for e in list(graph.edges()))
edges.remove((2, 42)) # node 2 has no position and will not be queried
compare_edges = sorted(list(compare_graph.edges()))
compare_edges = sorted(tuple(sorted(e)) for e in list(compare_graph.edges()))

assert nodes == compare_nodes
assert edges == compare_edges
Expand All @@ -365,13 +361,13 @@ def test_graph_connected_components(provider_factory):
graph.add_edge(57, 23)
graph.add_edge(2, 42)
try:
components = graph.get_connected_components()
components = list(nx.connected_components(graph))
except NotImplementedError:
pytest.xfail()
assert len(components) == 2
c1, c2 = components
n1 = sorted(list(c1.nodes()))
n2 = sorted(list(c2.nodes()))
n1 = sorted(list(c1))
n2 = sorted(list(c2))

compare_n1 = [2, 42]
compare_n2 = [23, 57]
Expand Down Expand Up @@ -399,7 +395,7 @@ def test_graph_has_edge(provider_factory):
graph.add_edge(57, 23)

write_roi = Roi((0, 0, 0), (6, 6, 6))
graph.write_nodes(roi=write_roi)
graph.write_edges(roi=write_roi)
graph_provider.write_nodes(graph.nodes(), roi=write_roi)
graph_provider.write_edges(graph.nodes(), graph.edges(), roi=write_roi)

assert graph_provider.has_edges(roi)

0 comments on commit cd4ae55

Please sign in to comment.