From 07e259f2e6850740096380200ab43d2b5e733734 Mon Sep 17 00:00:00 2001 From: Michael Chin Date: Wed, 15 Dec 2021 18:11:43 -0800 Subject: [PATCH] Fix OC group-by-label, add OC group-by-depth, add/revise unit tests --- src/graph_notebook/magics/graph_magic.py | 6 +- .../network/opencypher/OCNetwork.py | 27 ++- .../opencypher/test_opencypher_network.py | 214 +++++++++++++++++- 3 files changed, 233 insertions(+), 14 deletions(-) diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py index 5795fee3..2316caf5 100644 --- a/src/graph_notebook/magics/graph_magic.py +++ b/src/graph_notebook/magics/graph_magic.py @@ -258,7 +258,7 @@ def sparql(self, line='', cell='', local_ns: dict = None): parser.add_argument('-t', '--tooltip-property', type=str, default='', help='Property to display the value of on each node tooltip.') parser.add_argument('-te', '--edge-tooltip-property', type=str, default='', - help='Property to display the value of on each node tooltip.') + help='Property to display the value of on each edge tooltip.') parser.add_argument('-l', '--label-max-length', type=int, default=10, help='Specifies max length of vertex labels, in characters. Default is 10') parser.add_argument('-le', '--edge-label-max-length', type=int, default=10, @@ -447,7 +447,7 @@ def gremlin(self, line, cell, local_ns: dict = None): help='Property to display the value of on each node tooltip. If not specified, tooltip ' 'will default to the node label value.') parser.add_argument('-te', '--edge-tooltip-property', type=str, default='', - help='Property to display the value of on each node tooltip. If not specified, tooltip ' + help='Property to display the value of on each edge tooltip. If not specified, tooltip ' 'will default to the edge label value.') parser.add_argument('-l', '--label-max-length', type=int, default=10, help='Specifies max length of vertex label, in characters. Default is 10') @@ -1591,7 +1591,7 @@ def handle_opencypher_query(self, line, cell, local_ns): help='Property to display the value of on each node tooltip. If not specified, tooltip ' 'will default to the node label value.') parser.add_argument('-te', '--edge-tooltip-property', type=str, default='', - help='Property to display the value of on each node tooltip. If not specified, tooltip ' + help='Property to display the value of on each edge tooltip. If not specified, tooltip ' 'will default to the edge label value.') parser.add_argument('-l', '--label-max-length', type=int, default=10, help='Specifies max length of vertex label, in characters. Default is 10') diff --git a/src/graph_notebook/network/opencypher/OCNetwork.py b/src/graph_notebook/network/opencypher/OCNetwork.py index 7e536023..b13a89c1 100644 --- a/src/graph_notebook/network/opencypher/OCNetwork.py +++ b/src/graph_notebook/network/opencypher/OCNetwork.py @@ -107,12 +107,16 @@ def get_edge_property_value(self, data: dict, rel: dict, custom_property): return display_label - def parse_node(self, node: dict): + def parse_node(self, node: dict, path_index: int = -2): """This parses the node parameter and adds the node to the network diagram Args: node (dict): The node dictionary to parse + path_index: Position of the element in the path traversal order """ + + depth_group = "__DEPTH-" + str(path_index//2) + "__" + # generate placeholder tooltip from label; if not present, amalgamate node property values instead if LABEL_KEY in node.keys(): title_plc = node[LABEL_KEY][0] @@ -127,6 +131,8 @@ def parse_node(self, node: dict): group = node[LABEL_KEY][0] elif self.group_by_property in [ID_KEY, 'id']: group = node[ID_KEY] + elif self.group_by_property == "TRAVERSAL_DEPTH": + group = depth_group elif self.group_by_property in node[PROPERTIES_KEY]: group = node[PROPERTIES_KEY][self.group_by_property] else: @@ -137,12 +143,14 @@ def parse_node(self, node: dict): try: if str(node[LABEL_KEY][0]) in self.group_by_property and len(node[LABEL_KEY]) > 0: key = node[LABEL_KEY][0] - if self.group_by_property[key]['groupby'] in [LABEL_KEY, 'labels']: + if self.group_by_property[key] in [LABEL_KEY, 'labels']: group = node[LABEL_KEY][0] + elif self.group_by_property[key] in [ID_KEY, 'id']: + group = node[ID_KEY] + elif self.group_by_property[key] == "TRAVERSAL_DEPTH": + group = depth_group else: - group = node[PROPERTIES_KEY][self.group_by_property[key]['groupby']] - elif ID_KEY in self.group_by_property: - group = node[ID_KEY] + group = node[PROPERTIES_KEY][self.group_by_property[key]] else: group = DEFAULT_GRP except KeyError: @@ -170,15 +178,16 @@ def parse_rel(self, rel): self.add_edge(from_id=rel[START_KEY], to_id=rel[END_KEY], edge_id=rel[ID_KEY], label=edge_label, title=edge_title, data=data) - def process_result(self, res: dict): + def process_result(self, res: dict, path_index: int = -2): """Determines the type of element passed in and processes it appropriately Args: res (dict): The dictionary to parse + path_index: Position of the element in the path traversal order """ if ENTITY_KEY in res: if res[ENTITY_KEY] == NODE_ENTITY_TYPE: - self.parse_node(res) + self.parse_node(res, path_index) else: self.parse_rel(res) @@ -194,9 +203,9 @@ def add_results(self, results): if type(res[k]) is dict: self.process_result(res[k]) elif type(res[k]) is list: - for res_sublist in res[k]: + for path_index, res_sublist in enumerate(res[k]): try: - self.process_result(res_sublist) + self.process_result(res_sublist, path_index) except TypeError as e: logger.debug(f'Property {res_sublist} in list results set is invalid, skipping') logger.debug(f'Error: {e}') diff --git a/test/unit/network/opencypher/test_opencypher_network.py b/test/unit/network/opencypher/test_opencypher_network.py index 34babe84..87a5b4f4 100644 --- a/test/unit/network/opencypher/test_opencypher_network.py +++ b/test/unit/network/opencypher/test_opencypher_network.py @@ -480,6 +480,75 @@ def test_group_with_groupby_label(self): self.assertEqual(node1['group'], 'US-AK') self.assertEqual(node2['group'], 'US-TX') + def test_group_with_groupby_depth(self): + res = { + "results": [ + { + "p": [ + { + "~id": "3", + "~entityType": "node", + "~labels": [ + "airport" + ], + "~properties": { + "desc": "Austin Bergstrom International Airport", + } + }, + { + "~id": "3820", + "~entityType": "relationship", + "~start": "3", + "~end": "23", + "~type": "route", + "~properties": { + "dist": 1500 + } + }, + { + "~id": "23", + "~entityType": "node", + "~labels": [ + "airport" + ], + "~properties": { + "desc": "San Francisco International Airport", + } + }, + { + "~id": "7541", + "~entityType": "relationship", + "~start": "23", + "~end": "55", + "~type": "route", + "~properties": { + "dist": 7420 + } + }, + { + "~id": "55", + "~entityType": "node", + "~labels": [ + "airport" + ], + "~properties": { + "desc": "Sydney Kingsford Smith", + } + } + ] + } + ] + } + + gn = OCNetwork(group_by_property='TRAVERSAL_DEPTH') + gn.add_results(res) + node1 = gn.graph.nodes.get('3') + node2 = gn.graph.nodes.get('23') + node3 = gn.graph.nodes.get('55') + self.assertEqual(node1['group'], '__DEPTH-0__') + self.assertEqual(node2['group'], '__DEPTH-1__') + self.assertEqual(node3['group'], '__DEPTH-2__') + def test_path_with_default_groupby(self): res = { "results": [ @@ -600,11 +669,152 @@ def test_group_with_groupby_properties_json_single_label(self): ] } - gn = OCNetwork(group_by_property='{"airport":{"groupby":"code"}}') + gn = OCNetwork(group_by_property='{"airport":"code"}') gn.add_results(res) node1 = gn.graph.nodes.get('22') self.assertEqual(node1['group'], 'SEA') + def test_group_with_groupby_properties_json_label_value(self): + res = { + "results": [ + { + "a": { + "~id": "22", + "~entityType": "node", + "~labels": [ + "airport" + ], + "~properties": { + "runways": 3, + "code": "SEA" + } + } + } + ] + } + + gn = OCNetwork(group_by_property='{"airport":"~labels"}') + gn.add_results(res) + node1 = gn.graph.nodes.get('22') + self.assertEqual(node1['group'], 'airport') + + def test_group_with_groupby_properties_json_ID_value(self): + res = { + "results": [ + { + "a": { + "~id": "22", + "~entityType": "node", + "~labels": [ + "airport" + ], + "~properties": { + "runways": 3, + "code": "SEA" + } + } + } + ] + } + + gn = OCNetwork(group_by_property='{"airport":"~id"}') + gn.add_results(res) + node1 = gn.graph.nodes.get('22') + self.assertEqual(node1['group'], '22') + + def test_group_with_groupby_properties_json_depth(self): + res = { + "results": [ + { + "p": [ + { + "~id": "3", + "~entityType": "node", + "~labels": [ + "airport" + ], + "~properties": { + "desc": "Austin Bergstrom International Airport", + } + }, + { + "~id": "3820", + "~entityType": "relationship", + "~start": "3", + "~end": "23", + "~type": "route", + "~properties": { + "dist": 1500 + } + }, + { + "~id": "23", + "~entityType": "node", + "~labels": [ + "airport" + ], + "~properties": { + "desc": "San Francisco International Airport", + } + }, + { + "~id": "7541", + "~entityType": "relationship", + "~start": "23", + "~end": "55", + "~type": "route", + "~properties": { + "dist": 7420 + } + }, + { + "~id": "55", + "~entityType": "node", + "~labels": [ + "airport" + ], + "~properties": { + "desc": "Sydney Kingsford Smith", + } + } + ] + } + ] + } + + gn = OCNetwork(group_by_property='{"airport":"TRAVERSAL_DEPTH"}') + gn.add_results(res) + node1 = gn.graph.nodes.get('3') + node2 = gn.graph.nodes.get('23') + node3 = gn.graph.nodes.get('55') + self.assertEqual(node1['group'], '__DEPTH-0__') + self.assertEqual(node2['group'], '__DEPTH-1__') + self.assertEqual(node3['group'], '__DEPTH-2__') + + def test_group_with_groupby_properties_json_invalid(self): + res = { + "results": [ + { + "a": { + "~id": "22", + "~entityType": "node", + "~labels": [ + "airport" + ], + "~properties": { + "runways": 3, + "code": "SEA" + } + } + } + ] + } + + gn = OCNetwork(group_by_property='{"airport":"elevation"}') + gn.add_results(res) + node1 = gn.graph.nodes.get('22') + self.assertEqual(node1['group'], 'DEFAULT_GROUP') + def test_group_with_groupby_properties_json_multiple_labels(self): path = { "results": [ @@ -654,7 +864,7 @@ def test_group_with_groupby_properties_json_multiple_labels(self): ] } - gn = OCNetwork(group_by_property='{"airport":{"groupby":"code"},"country":{"groupby":"desc"}}') + gn = OCNetwork(group_by_property='{"airport":"code","country":"desc"}') gn.add_results(path) node1 = gn.graph.nodes.get('2') node2 = gn.graph.nodes.get('3670')