refactoring and improving efficiency of transform_attributes
- moved construct_sources_tree into attributes.py, where transform_attributes now calls it, and refactored it
- reworked how attributes get passed around so that reserved properties, sources, and skip-listed attributes are removed up front rather than checked against an ignore list during processing
- removed support for biolink prefixes on knowledge source attributes
- improved tests for attributes and knowledge sources
EvanDietzMorris committed Oct 3, 2024
1 parent 0d3531c commit 2ef1067
Showing 8 changed files with 184 additions and 183 deletions.
129 changes: 80 additions & 49 deletions reasoner_transpiler/attributes.py
@@ -1,4 +1,5 @@
import json
import os
from pathlib import Path

from .biolink import bmt
@@ -20,41 +21,63 @@ def get_attribute_types_from_config():

ATTRIBUTE_SKIP_LIST = []

RESERVED_NODE_PROPS = [
"id",
"name",
"labels",
"element_id"
]
RESERVED_EDGE_PROPS = [
"id",
"predicate",
"object",
"subject",
"sources"
]

# this should really be one representation or the other, or be configurable,
# but we have graphs with each now so temporarily (I hope, hope, hope) looking for both
EDGE_SOURCE_PROPS = [
"aggregator_knowledge_source",
"primary_knowledge_source",
"biolink:aggregator_knowledge_source",
"biolink:primary_knowledge_source"
]


def transform_attributes(result_item, node=False):

# make a list of attributes to ignore while processing
ignore_list = RESERVED_NODE_PROPS if node else EDGE_SOURCE_PROPS + RESERVED_EDGE_PROPS
ignore_list += ATTRIBUTE_SKIP_LIST
PRIMARY_KNOWLEDGE_SOURCE = "primary_knowledge_source"
AGGREGATOR_KNOWLEDGE_SOURCE = "aggregator_knowledge_source"

PROVENANCE_TAG = os.environ.get('PROVENANCE_TAG', 'reasoner-transpiler')


# This function takes EDGE_SOURCE_PROPS properties from results, converts them into proper
# TRAPI dictionaries, and assigns the proper upstream ids to each resource. It does not currently attempt to avoid
# duplicate aggregator results, which shouldn't exist in the graphs.
def construct_sources_tree(primary_knowledge_source, aggregator_knowledge_sources):

if not primary_knowledge_source:
return [{"resource_id": PROVENANCE_TAG,
"resource_role": "primary_knowledge_source"}]

# set the primary knowledge source
formatted_sources = [{"resource_id": primary_knowledge_source,
"resource_role": "primary_knowledge_source"}]

# walk through the aggregator lists and construct the chains of provenance
terminal_aggregators = set()
for aggregator_list in aggregator_knowledge_sources:
# each aggregator list should be in order, so we can deduce the upstream chains
last_aggregator = None
for aggregator_knowledge_source in aggregator_list:
formatted_sources.append({
"resource_id": aggregator_knowledge_source,
"resource_role": "aggregator_knowledge_source",
"upstream_resource_ids": [last_aggregator] if last_aggregator else [primary_knowledge_source]
})
last_aggregator = aggregator_knowledge_source
# store the last aggregator in the list, because this will be an upstream source for the plater one
terminal_aggregators.add(last_aggregator)
# add PROVENANCE_TAG as the most downstream aggregator,
# it will have as upstream either the primary ks or all of the furthest downstream aggregators if they exist
# this will be used by applications like Plater which need to append themselves as an aggregator
formatted_sources.append({
"resource_id": PROVENANCE_TAG,
"resource_role": "aggregator_knowledge_source",
"upstream_resource_ids": list(terminal_aggregators) if terminal_aggregators else [primary_knowledge_source]
})
return formatted_sources


def transform_attributes(result_entity, node=False):

# construct a valid TRAPI entity to return in trapi_entity
trapi_entity = {}

for attribute in ATTRIBUTE_SKIP_LIST:
result_entity.pop(attribute, None)

# an "attributes" attribute in neo4j should be a list of json strings,
# attempt to start the attributes section of transformed attributes with its contents,
# here we are assuming the attributes in "attributes" are already valid trapi
json_attributes = []
json_attributes_attribute = result_item.pop('attributes', None)
json_attributes_attribute = result_entity.pop('attributes', None)
if json_attributes_attribute:
if isinstance(json_attributes_attribute, list):
try:
@@ -64,26 +87,33 @@ def transform_attributes(result_item, node=False):
print(f'!!! JSONDecodeError while parsing attributes property, ignoring: {json_attributes_attribute}')
else:
print(f'!!! the attributes edge property should be a list, ignoring: {json_attributes_attribute}')
transformed_attributes = {
'attributes': json_attributes
}
trapi_attributes = json_attributes
else:
trapi_attributes = []

# if it's an edge handle provenance (sources) and qualifiers
if not node:
# for edges, find and format attributes that are qualifiers
qualifiers = [key for key in result_item if key not in ignore_list
and bmt.is_qualifier(key)]
transformed_attributes['qualifiers'] = [
{"qualifier_type_id": f"biolink:{key}",
"qualifier_value": value}
for key, value in result_item.items() if key in qualifiers
]
else:
qualifiers = []
# extract properties for provenance, construct the sources section
primary_knowledge_source = result_entity.pop(PRIMARY_KNOWLEDGE_SOURCE, None)
# get any properties that start with AGGREGATOR_KNOWLEDGE_SOURCE, this handles the possibility of edges
# with multiple aggregator knowledge source lists like aggregator_knowledge_source_2
aggregator_knowledge_source_keys = [ks_attribute for ks_attribute in result_entity.keys()
if ks_attribute.startswith(AGGREGATOR_KNOWLEDGE_SOURCE)]
aggregator_knowledge_sources = [result_entity[key] for key in aggregator_knowledge_source_keys]
for ks_property in aggregator_knowledge_source_keys:
result_entity.pop(ks_property)
trapi_entity["sources"] = construct_sources_tree(primary_knowledge_source, aggregator_knowledge_sources)

# find and format attributes that are qualifiers
qualifiers = [key for key in result_entity if bmt.is_qualifier(key)]
if qualifiers:
trapi_entity["qualifiers"] = [{"qualifier_type_id": f"biolink:{qualifier}",
"qualifier_value": result_entity.pop(qualifier)}
for qualifier in qualifiers]

# for attributes that aren't in ATTRIBUTE_TYPES, see if they are valid biolink attributes
# add them to ATTRIBUTE_TYPES, so we don't need to look again
for attribute in \
[key for key in result_item.keys() if key not in ignore_list + qualifiers + list(ATTRIBUTE_TYPES.keys())]:
for attribute in [key for key in result_entity.keys() if key not in list(ATTRIBUTE_TYPES.keys())]:
attribute_mapping = DEFAULT_ATTRIBUTE_TYPE
bmt_element = bmt.get_element(attribute)
if bmt_element:
Expand All @@ -94,17 +124,18 @@ def transform_attributes(result_item, node=False):
ATTRIBUTE_TYPES[attribute] = attribute_mapping

# format the rest of the attributes, look up their attribute type and value type
transformed_attributes['attributes'].extend([
trapi_attributes.extend([
{'original_attribute_name': key,
'value': value,
# the following function will return
# 'attribute_type_id': 'biolink-ified attribute type id'
# 'value_type_id': 'biolink-ified value type id'
**ATTRIBUTE_TYPES.get(key)}
for key, value in result_item.items()
if key not in ignore_list + qualifiers
for key, value in result_entity.items()
])
return transformed_attributes
if trapi_attributes:
trapi_entity["attributes"] = trapi_attributes
return trapi_entity


def set_custom_attribute_types(attribute_types: dict):
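As a quick illustration (not part of the commit), here is a rough sketch of how the reworked transform_attributes consumes a flat edge-property dict. The property names mirror the new constants above; the infores values, the p_value property, and the qualifier are invented, the qualifier is assumed to satisfy bmt.is_qualifier, and PROVENANCE_TAG is assumed to be left at its "reasoner-transpiler" default.

from reasoner_transpiler.attributes import transform_attributes

# hypothetical edge properties, as they might come back from neo4j
edge_props = {
    "primary_knowledge_source": "infores:ctd",
    "aggregator_knowledge_source": ["infores:agg-one"],
    # parallel aggregator lists are collected by the startswith() check
    "aggregator_knowledge_source_2": ["infores:agg-two"],
    "object_aspect_qualifier": "activity",  # assumed to pass bmt.is_qualifier
    "p_value": 0.05,                        # falls through to the attributes list
}
trapi_edge = transform_attributes(edge_props, node=False)

# Expected shape: trapi_edge["sources"] lists infores:ctd as the
# primary_knowledge_source; infores:agg-one and infores:agg-two as
# aggregator_knowledge_sources, each upstream of the primary; and
# reasoner-transpiler appended last, with both terminal aggregators as its
# upstream_resource_ids. trapi_edge["qualifiers"] carries the
# biolink:object_aspect_qualifier entry, and trapi_edge["attributes"] carries
# p_value with its looked-up attribute type.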
103 changes: 19 additions & 84 deletions reasoner_transpiler/cypher.py
@@ -1,16 +1,12 @@
"""Tools for compiling QGraph into Cypher query."""
import os
import json

from collections import defaultdict

from .attributes import transform_attributes, EDGE_SOURCE_PROPS
from .attributes import transform_attributes, PROVENANCE_TAG
from .matching import match_query


PROVENANCE_TAG = os.environ.get('PROVENANCE_TAG', 'reasoner-transpiler')


def nest_op(operator, *args):
"""Generate a nested set of operations from a flat expression."""
if len(args) > 2:
@@ -51,7 +47,8 @@ def assemble_results(qnodes, qedges, **kwargs):
])
if not edges_assemble:
edges_assemble = '[]'
assemble_clause = f"WITH apoc.coll.toSet({nodes_assemble}) AS nodes, apoc.coll.toSet({edges_assemble}) AS edges, collect(DISTINCT ["
assemble_clause = f"WITH apoc.coll.toSet({nodes_assemble}) AS nodes, " \
f"apoc.coll.toSet({edges_assemble}) AS edges, collect(DISTINCT ["

if nodes:
assemble_clause += ', '.join(nodes)
@@ -117,9 +114,18 @@ def transform_result(cypher_record,

nodes, edges, paths = unpack_bolt_record(cypher_record)

# Convert the list of unique result nodes from cypher results to dictionaries
# then convert them to TRAPI format, constructing the knowledge_graph["nodes"] section of the TRAPI response
kg_nodes = transform_nodes_list(nodes)
# Construct the knowledge_graph["nodes"] section of the TRAPI response
kg_nodes = {}
for cypher_node in nodes:
# Convert the list of unique result nodes from cypher results to dictionaries
node = convert_bolt_node_to_dict(cypher_node)
# Convert nodes to TRAPI format
# id, name, and labels are removed before transform_attributes
node_id = node.pop('id')
kg_nodes[node_id] = {
'name': node.pop('name'),
'categories': sorted(node.pop('labels'))}
kg_nodes[node_id].update(**transform_attributes(node, node=True))

# Convert the list of unique edges from cypher results to dictionaries
# then convert them to TRAPI format, constructing the knowledge_graph["edges"] section of the TRAPI response.
@@ -213,7 +219,8 @@ def transform_result(cypher_record,
# Check to see if the edge has subclass edges that are connected to it
subclass_edge_ids = []
superclass_node_ids = {}
for (subclass_subject_or_object, subclass_qedge_id, superclass_qnode_id) in qedges_with_attached_subclass_edges.get(qedge_id, []):
for (subclass_subject_or_object, subclass_qedge_id, superclass_qnode_id) in \
qedges_with_attached_subclass_edges.get(qedge_id, []):
# If so, check to see if there are results for it
qedge, subclass_edge_element_ids = qedge_id_to_results[subclass_qedge_id]
if subclass_edge_element_ids:
@@ -297,17 +304,6 @@ def transform_result(cypher_record,
return transformed_results


def transform_nodes_list(nodes):
kg_nodes = {}
for cypher_node in nodes:
node = convert_bolt_node_to_dict(cypher_node)
kg_nodes[node['id']] = {
'name': node['name'],
'categories': sorted(node.pop('labels')),
**transform_attributes(node, node=True)}
return kg_nodes


def transform_edges_list(edges):
# See convert_bolt_edge_to_dict() for details on the contents of edges,
# it is a list of lists (which can also be lists), representing unique edges from the graph
@@ -352,65 +348,10 @@ def transform_edges_list(edges):
return kg_edges, element_id_to_edge_id


# This function takes EDGE_SOURCE_PROPS properties from results, converts them into proper
# TRAPI dictionaries, and assigns the proper upstream ids to each resource. It does not currently attempt to avoid
# duplicate aggregator results, which shouldn't exist in the graphs.
def construct_sources_tree(sources):

# first find the primary knowledge source, there should always be one
primary_knowledge_source = None
formatted_sources = None
for resource_role, resource_id in sources:
if resource_role == "primary_knowledge_source":
primary_knowledge_source = resource_id
# add it to the formatted TRAPI output
formatted_sources = [{
"resource_id": primary_knowledge_source,
"resource_role": "primary_knowledge_source"
}]
if not primary_knowledge_source:
# we could hard fail here, every edge should have a primary ks, but I haven't fixed all the tests yet
# raise KeyError(f'primary_knowledge_source missing from sources section of cypher results! '
# f'sources: {sources}')
return []

# then find any aggregator lists
aggregator_list_sources = []
for resource_role, resource_id in sources:
# this looks weird but the idea is that you could have a few parallel lists like:
# aggregator_knowledge_source, aggregator_knowledge_source_2, aggregator_knowledge_source_3
if resource_role.startswith("aggregator_knowledge_source"):
aggregator_list_sources.append(resource_id)
# walk through the aggregator lists and construct the chains of provenance
terminal_aggregators = set()
for aggregator_list in aggregator_list_sources:
# each aggregator list should be in order, so we can deduce the upstream chains
last_aggregator = None
for aggregator_knowledge_source in aggregator_list:
formatted_sources.append({
"resource_id": aggregator_knowledge_source,
"resource_role": "aggregator_knowledge_source",
"upstream_resource_ids": [last_aggregator] if last_aggregator else [primary_knowledge_source]
})
last_aggregator = aggregator_knowledge_source
# store the last aggregator in the list, because this will be an upstream source for the plater one
terminal_aggregators.add(last_aggregator)
# add PROVENANCE_TAG as the most downstream aggregator,
# it will have as upstream either the primary ks or all of the furthest downstream aggregators if they exist
# this will be used by applications like Plater which need to append themselves as an aggregator
formatted_sources.append({
"resource_id": PROVENANCE_TAG,
"resource_role": "aggregator_knowledge_source",
"upstream_resource_ids": list(terminal_aggregators) if terminal_aggregators else [primary_knowledge_source]
})
return list(formatted_sources)


def convert_bolt_node_to_dict(bolt_node):
if not bolt_node:
return None
node = {key: value for key, value in bolt_node.items()}
# node['element_id'] = bolt_node.element_id
node['labels'] = bolt_node.labels
return node

@@ -442,16 +383,10 @@ def convert_bolt_edge_to_trapi(bolt_edge):
# edge_props - any other properties from the edge
edge_props = {**bolt_edge[4]}

# get the id if there is one on the edge
# retrieve and remove the id if there is one on the edge
edge_id = edge_props.pop('id', None)

# get properties matching EDGE_SOURCE_PROPS keys, remove biolink: if needed,
# then pass (key, value) tuples to construct_sources_tree for formatting, constructing the sources section
converted_edge['sources'] = construct_sources_tree([
(edge_source_prop.removeprefix('biolink:'), edge_props.pop(edge_source_prop))
for edge_source_prop in EDGE_SOURCE_PROPS if edge_source_prop in edge_props])

# convert all remaining attributes to TRAPI format, constructing the attributes section
# convert all remaining attributes to TRAPI format, constructing the attributes and sources sections
converted_edge.update(transform_attributes(edge_props, node=False))

# return the edge id if there was one, and a TRAPI edge
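For comparison with the inlined node handling in transform_result above, a minimal sketch of the node path (not part of the commit; the node values are invented, and the plain dict stands in for what convert_bolt_node_to_dict returns):

from reasoner_transpiler.attributes import transform_attributes

# a node dict as convert_bolt_node_to_dict might produce it (values made up)
node = {
    "id": "MONDO:0005148",
    "name": "type 2 diabetes mellitus",
    "labels": {"biolink:Disease", "biolink:NamedThing"},
    "equivalent_identifiers": ["MONDO:0005148"],
}
node_id = node.pop("id")
kg_node = {"name": node.pop("name"),
           "categories": sorted(node.pop("labels"))}
# with id, name, and labels popped first, transform_attributes no longer needs
# a reserved-property ignore list; whatever remains becomes TRAPI attributes
kg_node.update(**transform_attributes(node, node=True))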
14 changes: 8 additions & 6 deletions tests/initialize_db.py
@@ -4,8 +4,9 @@
import logging
import time

import neo4j.exceptions
from neo4j import GraphDatabase
from neo4j.exceptions import ServiceUnavailable, DatabaseUnavailable
from neo4j.exceptions import ServiceUnavailable, DatabaseUnavailable, ClientError

LOGGER = logging.getLogger(__name__)

@@ -19,11 +20,9 @@ def get_driver(url):
while True:
try:
driver = GraphDatabase.driver(url, auth=("neo4j", "plater_testing_pw"))
# make sure we can start and finish a session
with driver.session() as session:
session.run("SHOW PROCEDURES")
driver.verify_connectivity()
return driver
except (OSError, ServiceUnavailable, DatabaseUnavailable) as err:
except (OSError, ServiceUnavailable, DatabaseUnavailable, ClientError) as err:
if seconds >= 256:
raise err
LOGGER.error(
@@ -53,6 +52,7 @@ def main(hash: str = None):
"name: row.name, id: row.id"
"}, apoc.convert.fromJsonMap(row.props))) YIELD node "
"RETURN count(*)")
print(f'Nodes added: {result.single()["count(*)"]}')
result.consume() # this looks like it doesn't do anything, but it's needed to throw errors if they occur
result = session.run(f"LOAD CSV WITH HEADERS FROM \"{edge_file}\" "
"AS edge "
@@ -62,8 +62,10 @@
"apoc.map.merge({predicate: edge.predicate, id: edge.id}, "
"apoc.convert.fromJsonMap(edge.props)), object) YIELD rel "
"RETURN count(*)")
result.consume() # this looks like it doesn't do anything, but it's needed to throw errors if they occur
print(f'Edges added: {result.single()["count(*)"]}')
result.consume() # this looks like it doesn't do anything, but it's needed to throw errors if they occur

driver.close()
LOGGER.info("Done. Neo4j is ready for testing.")


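One detail worth calling out in the reordered lines above: a neo4j Result is a stream, so single() has to read its record before consume() drains the stream. A sketch of the pattern (query text and variable names assumed):

result = session.run("MATCH (n) RETURN count(*)")
node_count = result.single()["count(*)"]  # read the record while it is still available
result.consume()  # then drain the stream so any deferred server errors surface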