Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed May 11, 2023
1 parent 96bd8a7 commit 82f8b62
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@


def test_graph_filtering(provider_factory):
graph_provider = provider_factory("w")
graph_writer = provider_factory("w")
roi = Roi((0, 0, 0), (10, 10, 10))
graph = graph_provider[roi]
graph = graph_writer[roi]

graph.add_node(2, position=(2, 2, 2), selected=True)
graph.add_node(42, position=(1, 1, 1), selected=False)
Expand All @@ -20,26 +20,29 @@ def test_graph_filtering(provider_factory):
graph.write_nodes()
graph.write_edges()

graph_provider = provider_factory("r")
graph_reader = provider_factory("r")

filtered_nodes = graph_provider.read_nodes(roi, attr_filter={"selected": True})
filtered_nodes = graph_reader.read_nodes(roi, attr_filter={"selected": True})
filtered_node_ids = [node["id"] for node in filtered_nodes]
expected_node_ids = [2, 23, 57]
assert expected_node_ids == filtered_node_ids

filtered_edges = graph_provider.read_edges(roi, attr_filter={"selected": True})
filtered_edges = graph_reader.read_edges(roi, attr_filter={"selected": True})
filtered_edge_endpoints = [(edge["u"], edge["v"]) for edge in filtered_edges]
expected_edge_endpoints = [(57, 23), (2, 42)]
assert expected_edge_endpoints == filtered_edge_endpoints
for u, v in expected_edge_endpoints:
assert (u,v) in filtered_edge_endpoints or (v,u) in filtered_edge_endpoints

filtered_subgraph = graph_provider.get_graph(
filtered_subgraph = graph_reader.get_graph(
roi, nodes_filter={"selected": True}, edges_filter={"selected": True}
)
nodes_with_position = [
node for node, data in filtered_subgraph.nodes(data=True) if "position" in data
]
assert expected_node_ids == nodes_with_position
assert expected_edge_endpoints == filtered_subgraph.edges()
assert len(filtered_subgraph.edges()) == len(expected_edge_endpoints)
for u, v in expected_edge_endpoints:
assert (u,v) in filtered_subgraph.edges() or (v,u) in filtered_subgraph.edges()


def test_graph_filtering_complex(provider_factory):
Expand Down Expand Up @@ -73,7 +76,8 @@ def test_graph_filtering_complex(provider_factory):
)
filtered_edge_endpoints = [(edge["u"], edge["v"]) for edge in filtered_edges]
expected_edge_endpoints = [(57, 23)]
assert expected_edge_endpoints == filtered_edge_endpoints
for u, v in expected_edge_endpoints:
assert (u,v) in filtered_edge_endpoints or (v,u) in filtered_edge_endpoints

filtered_subgraph = graph_provider.get_graph(
roi,
Expand All @@ -84,7 +88,9 @@ def test_graph_filtering_complex(provider_factory):
node for node, data in filtered_subgraph.nodes(data=True) if "position" in data
]
assert expected_node_ids == nodes_with_position
assert expected_edge_endpoints == filtered_subgraph.edges()
assert len(filtered_subgraph.edges()) == len(expected_edge_endpoints)
for u, v in expected_edge_endpoints:
assert (u,v) in filtered_subgraph.edges() or (v,u) in filtered_subgraph.edges()


def test_graph_read_and_update_specific_attrs(provider_factory):
Expand Down

0 comments on commit 82f8b62

Please sign in to comment.