diff --git a/tests/conftest.py b/tests/conftest.py index c2d8155..120fb76 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,5 @@ from funlib.persistence.graphs import ( - FileGraphProvider, - MongoDbGraphProvider, - SQLiteGraphProvider, + SQLiteGraphDataBase ) import pytest @@ -10,27 +8,8 @@ from pathlib import Path -def mongo_db_available(): - client = pymongo.MongoClient(serverSelectionTimeoutMS=1000) - try: - client.admin.command("ping") - return True - except pymongo.errors.ConnectionFailure: - return False - - @pytest.fixture( params=( - pytest.param( - "files", - marks=pytest.mark.xfail(reason="FileProvider not fully implemented!"), - ), - pytest.param( - "mongo", - marks=pytest.mark.skipif( - not mongo_db_available(), reason="MongoDB not available!" - ), - ), pytest.param("sqlite"), ) ) @@ -42,30 +21,10 @@ def provider_factory(request, tmpdir): tmpdir = Path(tmpdir) - def mongo_provider_factory( - mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None - ): - return MongoDbGraphProvider( - "test_mongo_graph", mode=mode, directed=directed, total_roi=total_roi - ) - - def file_provider_factory( - mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None - ): - return FileGraphProvider( - tmpdir / "test_file_graph.db", - chunk_size=(10, 10, 10), - mode=mode, - directed=directed, - total_roi=total_roi, - # node_attrs=node_attrs, - # edge_attrs=edge_attrs, - ) - def sqlite_provider_factory( mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None ): - return SQLiteGraphProvider( + return SQLiteGraphDataBase( tmpdir / "test_sqlite_graph.db", mode=mode, directed=directed, @@ -74,11 +33,7 @@ def sqlite_provider_factory( edge_attrs=edge_attrs, ) - if request.param == "mongo": - yield mongo_provider_factory - elif request.param == "sqlite": + if request.param == "sqlite": yield sqlite_provider_factory - elif request.param == "files": - yield file_provider_factory else: raise ValueError() diff --git a/tests/test_graph.py b/tests/test_graph.py index cc0ecf2..26163e6 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -19,8 +19,8 @@ def test_graph_filtering(provider_factory): graph.add_edge(57, 23, selected=True) graph.add_edge(2, 42, selected=True) - graph.write_nodes() - graph.write_edges() + graph_writer.write_nodes(graph.nodes()) + graph_writer.write_edges(graph.nodes(), graph.edges()) graph_reader = provider_factory("r") @@ -35,7 +35,7 @@ def test_graph_filtering(provider_factory): for u, v in expected_edge_endpoints: assert (u, v) in filtered_edge_endpoints or (v, u) in filtered_edge_endpoints - filtered_subgraph = graph_reader.get_graph( + filtered_subgraph = graph_reader.read_graph( roi, nodes_filter={"selected": True}, edges_filter={"selected": True} ) nodes_with_position = [ @@ -66,8 +66,8 @@ def test_graph_filtering_complex(provider_factory): graph.add_edge(57, 23, selected=True, a=100, b=2) graph.add_edge(2, 42, selected=True, a=101, b=3) - graph.write_nodes() - graph.write_edges() + graph_provider.write_nodes(graph.nodes()) + graph_provider.write_edges(graph.nodes(), graph.edges()) graph_provider = provider_factory("r") @@ -86,7 +86,7 @@ def test_graph_filtering_complex(provider_factory): for u, v in expected_edge_endpoints: assert (u, v) in filtered_edge_endpoints or (v, u) in filtered_edge_endpoints - filtered_subgraph = graph_provider.get_graph( + filtered_subgraph = graph_provider.read_graph( roi, nodes_filter={"selected": True, "test": "test"}, edges_filter={"selected": True, "a": 100}, @@ -100,7 +100,7 @@ def test_graph_filtering_complex(provider_factory): def test_graph_read_and_update_specific_attrs(provider_factory): graph_provider = provider_factory( - "w", node_attrs=["selected", "test"], edge_attrs=["selected", "a", "b"] + "w", node_attrs=["selected", "test"], edge_attrs=["selected", "a", "b", "c"] ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] @@ -114,11 +114,10 @@ def test_graph_read_and_update_specific_attrs(provider_factory): graph.add_edge(57, 23, selected=True, a=100, b=2) graph.add_edge(2, 42, selected=True, a=101, b=3) - graph.write_nodes() - graph.write_edges() + graph_provider.write_graph(graph) graph_provider = provider_factory("r+") - limited_graph = graph_provider.get_graph( + limited_graph = graph_provider.read_graph( roi, node_attrs=["selected"], edge_attrs=["c"] ) @@ -133,12 +132,11 @@ def test_graph_read_and_update_specific_attrs(provider_factory): nx.set_edge_attributes(limited_graph, 5, "c") try: - limited_graph.update_edge_attrs(attributes=["c"]) - limited_graph.update_node_attrs(attributes=["selected"]) + graph_provider.write_attrs(limited_graph, edge_attrs=["c"], node_attrs=["selected"]) except NotImplementedError: pytest.xfail() - updated_graph = graph_provider.get_graph(roi) + updated_graph = graph_provider.read_graph(roi) for node, data in updated_graph.nodes(data=True): assert data["selected"] @@ -165,11 +163,11 @@ def test_graph_read_unbounded_roi(provider_factory): graph.add_edge(57, 23, selected=True, a=100, b=2) graph.add_edge(2, 42, selected=True, a=101, b=3) - graph.write_nodes() - graph.write_edges() + graph_provider.write_nodes(graph.nodes(), ) + graph_provider.write_edges(graph.nodes(), graph.edges(), ) graph_provider = provider_factory("r+") - limited_graph = graph_provider.get_graph( + limited_graph = graph_provider.read_graph( unbounded_roi, node_attrs=["selected"], edge_attrs=["c"] ) @@ -228,8 +226,8 @@ def test_graph_io(provider_factory): graph.add_edge(57, 23) graph.add_edge(2, 42) - graph.write_nodes() - graph.write_edges() + graph_provider.write_nodes(graph.nodes(), ) + graph_provider.write_edges(graph.nodes(), graph.edges(), ) graph_provider = provider_factory("r") compare_graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] @@ -240,7 +238,7 @@ def test_graph_io(provider_factory): edges = sorted(tuple(sorted(e)) for e in list(graph.edges())) edges.remove((2, 42)) # node 2 has no position and will not be queried - compare_edges = sorted(list(compare_graph.edges())) + compare_edges = sorted(tuple(sorted(e)) for e in list(compare_graph.edges())) assert nodes == compare_nodes assert edges == compare_edges @@ -258,12 +256,11 @@ def test_graph_fail_if_exists(provider_factory): graph.add_edge(57, 23) graph.add_edge(2, 42) - graph.write_nodes() - graph.write_edges() + graph_provider.write_graph(graph) with pytest.raises(Exception): - graph.write_nodes(fail_if_exists=True) + graph_provider.write_nodes(graph.nodes(), fail_if_exists=True) with pytest.raises(Exception): - graph.write_edges(fail_if_exists=True) + graph_provider.write_edges(graph.nodes(), graph.edges(), fail_if_exists=True) def test_graph_fail_if_not_exists(provider_factory): @@ -279,9 +276,9 @@ def test_graph_fail_if_not_exists(provider_factory): graph.add_edge(2, 42) with pytest.raises(Exception): - graph.write_nodes(fail_if_not_exists=True) + graph_provider.write_nodes(graph.nodes(), fail_if_not_exists=True) with pytest.raises(Exception): - graph.write_edges(fail_if_not_exists=True) + graph_provider.write_edges(graph.nodes(), graph.edges(), fail_if_not_exists=True) def test_graph_write_attributes(provider_factory): @@ -297,10 +294,10 @@ def test_graph_write_attributes(provider_factory): graph.add_edge(2, 42) try: - graph.write_nodes(attributes=["position", "swip"]) + graph_provider.write_nodes(graph.nodes(), attributes=["position", "swip"]) except NotImplementedError: pytest.xfail() - graph.write_edges() + graph_provider.write_edges(graph.nodes(), graph.edges(), ) graph_provider = provider_factory("r") compare_graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] @@ -334,8 +331,7 @@ def test_graph_write_roi(provider_factory): graph.add_edge(2, 42) write_roi = Roi((0, 0, 0), (6, 6, 6)) - graph.write_nodes(roi=write_roi) - graph.write_edges(roi=write_roi) + graph_provider.write_graph(graph, write_roi) graph_provider = provider_factory("r") compare_graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] @@ -348,7 +344,7 @@ def test_graph_write_roi(provider_factory): compare_nodes = sorted(list(compare_nodes)) edges = sorted(tuple(sorted(e)) for e in list(graph.edges())) edges.remove((2, 42)) # node 2 has no position and will not be queried - compare_edges = sorted(list(compare_graph.edges())) + compare_edges = sorted(tuple(sorted(e)) for e in list(compare_graph.edges())) assert nodes == compare_nodes assert edges == compare_edges @@ -365,13 +361,13 @@ def test_graph_connected_components(provider_factory): graph.add_edge(57, 23) graph.add_edge(2, 42) try: - components = graph.get_connected_components() + components = list(nx.connected_components(graph)) except NotImplementedError: pytest.xfail() assert len(components) == 2 c1, c2 = components - n1 = sorted(list(c1.nodes())) - n2 = sorted(list(c2.nodes())) + n1 = sorted(list(c1)) + n2 = sorted(list(c2)) compare_n1 = [2, 42] compare_n2 = [23, 57] @@ -399,7 +395,7 @@ def test_graph_has_edge(provider_factory): graph.add_edge(57, 23) write_roi = Roi((0, 0, 0), (6, 6, 6)) - graph.write_nodes(roi=write_roi) - graph.write_edges(roi=write_roi) + graph_provider.write_nodes(graph.nodes(), roi=write_roi) + graph_provider.write_edges(graph.nodes(), graph.edges(), roi=write_roi) assert graph_provider.has_edges(roi)