diff --git a/tests/test_graph.py b/tests/test_graph.py index 7c536b6..1cff02e 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -5,9 +5,9 @@ def test_graph_filtering(provider_factory): - graph_provider = provider_factory("w") + graph_writer = provider_factory("w") roi = Roi((0, 0, 0), (10, 10, 10)) - graph = graph_provider[roi] + graph = graph_writer[roi] graph.add_node(2, position=(2, 2, 2), selected=True) graph.add_node(42, position=(1, 1, 1), selected=False) @@ -20,26 +20,29 @@ def test_graph_filtering(provider_factory): graph.write_nodes() graph.write_edges() - graph_provider = provider_factory("r") + graph_reader = provider_factory("r") - filtered_nodes = graph_provider.read_nodes(roi, attr_filter={"selected": True}) + filtered_nodes = graph_reader.read_nodes(roi, attr_filter={"selected": True}) filtered_node_ids = [node["id"] for node in filtered_nodes] expected_node_ids = [2, 23, 57] assert expected_node_ids == filtered_node_ids - filtered_edges = graph_provider.read_edges(roi, attr_filter={"selected": True}) + filtered_edges = graph_reader.read_edges(roi, attr_filter={"selected": True}) filtered_edge_endpoints = [(edge["u"], edge["v"]) for edge in filtered_edges] expected_edge_endpoints = [(57, 23), (2, 42)] - assert expected_edge_endpoints == filtered_edge_endpoints + for u, v in expected_edge_endpoints: + assert (u,v) in filtered_edge_endpoints or (v,u) in filtered_edge_endpoints - filtered_subgraph = graph_provider.get_graph( + filtered_subgraph = graph_reader.get_graph( roi, nodes_filter={"selected": True}, edges_filter={"selected": True} ) nodes_with_position = [ node for node, data in filtered_subgraph.nodes(data=True) if "position" in data ] assert expected_node_ids == nodes_with_position - assert expected_edge_endpoints == filtered_subgraph.edges() + assert len(filtered_subgraph.edges()) == len(expected_edge_endpoints) + for u, v in expected_edge_endpoints: + assert (u,v) in filtered_subgraph.edges() or (v,u) in filtered_subgraph.edges() def test_graph_filtering_complex(provider_factory): @@ -73,7 +76,8 @@ def test_graph_filtering_complex(provider_factory): ) filtered_edge_endpoints = [(edge["u"], edge["v"]) for edge in filtered_edges] expected_edge_endpoints = [(57, 23)] - assert expected_edge_endpoints == filtered_edge_endpoints + for u, v in expected_edge_endpoints: + assert (u,v) in filtered_edge_endpoints or (v,u) in filtered_edge_endpoints filtered_subgraph = graph_provider.get_graph( roi, @@ -84,7 +88,9 @@ def test_graph_filtering_complex(provider_factory): node for node, data in filtered_subgraph.nodes(data=True) if "position" in data ] assert expected_node_ids == nodes_with_position - assert expected_edge_endpoints == filtered_subgraph.edges() + assert len(filtered_subgraph.edges()) == len(expected_edge_endpoints) + for u, v in expected_edge_endpoints: + assert (u,v) in filtered_subgraph.edges() or (v,u) in filtered_subgraph.edges() def test_graph_read_and_update_specific_attrs(provider_factory):