diff --git a/libs/community/langchain_community/graphs/neo4j_graph.py b/libs/community/langchain_community/graphs/neo4j_graph.py index 8d7da7fe91b41..b7037621ce135 100644 --- a/libs/community/langchain_community/graphs/neo4j_graph.py +++ b/libs/community/langchain_community/graphs/neo4j_graph.py @@ -5,10 +5,13 @@ from langchain_community.graphs.graph_document import GraphDocument from langchain_community.graphs.graph_store import GraphStore +BASE_ENTITY_LABEL = "__Entity__" + node_properties_query = """ CALL apoc.meta.data() YIELD label, other, elementType, type, property -WHERE NOT type = "RELATIONSHIP" AND elementType = "node" +WHERE NOT type = "RELATIONSHIP" AND elementType = "node" + AND NOT label IN [$BASE_ENTITY_LABEL] WITH label AS nodeLabels, collect({property:property, type:type}) AS properties RETURN {labels: nodeLabels, properties: properties} AS output @@ -27,9 +30,18 @@ YIELD label, other, elementType, type, property WHERE type = "RELATIONSHIP" AND elementType = "node" UNWIND other AS other_node +WITH * WHERE NOT label IN [$BASE_ENTITY_LABEL] + AND NOT other_node IN [$BASE_ENTITY_LABEL] RETURN {start: label, type: property, end: toString(other_node)} AS output """ +include_docs_query = ( + "CREATE (d:Document) " + "SET d.text = $document.page_content " + "SET d += $document.metadata " + "WITH d " +) + def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]: """Sanitize the input dictionary. @@ -63,6 +75,53 @@ def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]: return new_dict +def _get_node_import_query(baseEntityLabel: bool, include_source: bool) -> str: + if baseEntityLabel: + return ( + f"{include_docs_query if include_source else ''}" + "UNWIND $data AS row " + f"MERGE (source:`{BASE_ENTITY_LABEL}` {{id: row.id}}) " + "SET source += row.properties " + f"{'MERGE (d)-[:MENTIONS]->(source) ' if include_source else ''}" + "WITH source, row " + "CALL apoc.create.addLabels( source, [row.type] ) YIELD node " + "RETURN distinct 'done' AS result" + ) + else: + return ( + f"{include_docs_query if include_source else ''}" + "UNWIND $data AS row " + "CALL apoc.merge.node([row.type], {id: row.id}, " + "row.properties, {}) YIELD node " + f"{'MERGE (d)-[:MENTIONS]->(node) ' if include_source else ''}" + "RETURN distinct 'done' AS result" + ) + + +def _get_rel_import_query(baseEntityLabel: bool) -> str: + if baseEntityLabel: + return ( + "UNWIND $data AS row " + f"MERGE (source:`{BASE_ENTITY_LABEL}` {{id: row.source}}) " + f"MERGE (target:`{BASE_ENTITY_LABEL}` {{id: row.target}}) " + "WITH source, target, row " + "CALL apoc.merge.relationship(source, row.type, " + "{}, row.properties, target) YIELD rel " + "RETURN distinct 'done'" + ) + else: + return ( + "UNWIND $data AS row " + "CALL apoc.merge.node([row.source_label], {id: row.source}," + "{}, {}) YIELD node as source " + "CALL apoc.merge.node([row.target_label], {id: row.target}," + "{}, {}) YIELD node as target " + "CALL apoc.merge.relationship(source, row.type, " + "{}, row.properties, target) YIELD rel " + "RETURN distinct 'done'" + ) + + class Neo4jGraph(GraphStore): """Neo4j database wrapper for various graph operations. @@ -173,14 +232,42 @@ def refresh_schema(self) -> None: """ Refreshes the Neo4j graph schema information. """ - node_properties = [el["output"] for el in self.query(node_properties_query)] - rel_properties = [el["output"] for el in self.query(rel_properties_query)] - relationships = [el["output"] for el in self.query(rel_query)] + from neo4j.exceptions import ClientError + + node_properties = [ + el["output"] + for el in self.query( + node_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL} + ) + ] + rel_properties = [ + el["output"] + for el in self.query( + rel_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL} + ) + ] + relationships = [ + el["output"] + for el in self.query( + rel_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL} + ) + ] + + # Get constraints & indexes + try: + constraint = self.query("SHOW CONSTRAINTS") + index = self.query("SHOW INDEXES YIELD *") + except ( + ClientError + ): # Read-only user might not have access to schema information + constraint = [] + index = [] self.structured_schema = { "node_props": {el["labels"]: el["properties"] for el in node_properties}, "rel_props": {el["type"]: el["properties"] for el in rel_properties}, "relationships": relationships, + "metadata": {"constraint": constraint, "index": index}, } # Format node properties @@ -216,28 +303,51 @@ def refresh_schema(self) -> None: ) def add_graph_documents( - self, graph_documents: List[GraphDocument], include_source: bool = False + self, + graph_documents: List[GraphDocument], + include_source: bool = False, + baseEntityLabel: bool = False, ) -> None: """ - Take GraphDocument as input as uses it to construct a graph. + This method constructs nodes and relationships in the graph based on the + provided GraphDocument objects. + + Parameters: + - graph_documents (List[GraphDocument]): A list of GraphDocument objects + that contain the nodes and relationships to be added to the graph. Each + GraphDocument should encapsulate the structure of part of the graph, + including nodes, relationships, and the source document information. + - include_source (bool, optional): If True, stores the source document + and links it to nodes in the graph using the MENTIONS relationship. + This is useful for tracing back the origin of data. Defaults to False. + - baseEntityLabel (bool, optional): If True, each newly created node + gets a secondary __Entity__ label, which is indexed and improves import + speed and performance. Defaults to False. """ - for document in graph_documents: - include_docs_query = ( - "CREATE (d:Document) " - "SET d.text = $document.page_content " - "SET d += $document.metadata " - "WITH d " + if baseEntityLabel: # Check if constraint already exists + constraint_exists = any( + [ + el["labelsOrTypes"] == [BASE_ENTITY_LABEL] + and el["properties"] == ["id"] + for el in self.structured_schema.get("metadata", {}).get( + "constraint" + ) + ] ) + if not constraint_exists: + # Create constraint + self.query( + f"CREATE CONSTRAINT IF NOT EXISTS FOR (b:{BASE_ENTITY_LABEL}) " + "REQUIRE b.id IS UNIQUE;" + ) + self.refresh_schema() # Refresh constraint information + + node_import_query = _get_node_import_query(baseEntityLabel, include_source) + rel_import_query = _get_rel_import_query(baseEntityLabel) + for document in graph_documents: # Import nodes self.query( - ( - f"{include_docs_query if include_source else ''}" - "UNWIND $data AS row " - "CALL apoc.merge.node([row.type], {id: row.id}, " - "row.properties, {}) YIELD node " - f"{'MERGE (d)-[:MENTIONS]->(node) ' if include_source else ''}" - "RETURN distinct 'done' AS result" - ), + node_import_query, { "data": [el.__dict__ for el in document.nodes], "document": document.source.__dict__, @@ -245,14 +355,7 @@ def add_graph_documents( ) # Import relationships self.query( - "UNWIND $data AS row " - "CALL apoc.merge.node([row.source_label], {id: row.source}," - "{}, {}) YIELD node as source " - "CALL apoc.merge.node([row.target_label], {id: row.target}," - "{}, {}) YIELD node as target " - "CALL apoc.merge.relationship(source, row.type, " - "{}, row.properties, target) YIELD rel " - "RETURN distinct 'done'", + rel_import_query, { "data": [ { diff --git a/libs/community/tests/integration_tests/graphs/test_neo4j.py b/libs/community/tests/integration_tests/graphs/test_neo4j.py index a209c56d099b8..1fcb4b7fbb636 100644 --- a/libs/community/tests/integration_tests/graphs/test_neo4j.py +++ b/libs/community/tests/integration_tests/graphs/test_neo4j.py @@ -1,12 +1,30 @@ import os +from langchain_core.documents import Document + from langchain_community.graphs import Neo4jGraph +from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship from langchain_community.graphs.neo4j_graph import ( + BASE_ENTITY_LABEL, node_properties_query, rel_properties_query, rel_query, ) +test_data = [ + GraphDocument( + nodes=[Node(id="foo", type="foo"), Node(id="bar", type="bar")], + relationships=[ + Relationship( + source=Node(id="foo", type="foo"), + target=Node(id="bar", type="bar"), + type="REL", + ) + ], + source=Document(page_content="source document"), + ) +] + def test_cypher_return_correct_schema() -> None: """Test that chain returns direct results.""" @@ -37,9 +55,15 @@ def test_cypher_return_correct_schema() -> None: # Refresh schema information graph.refresh_schema() - node_properties = graph.query(node_properties_query) - relationships_properties = graph.query(rel_properties_query) - relationships = graph.query(rel_query) + node_properties = graph.query( + node_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL} + ) + relationships_properties = graph.query( + rel_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL} + ) + relationships = graph.query( + rel_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL} + ) expected_node_properties = [ { @@ -116,3 +140,112 @@ def test_neo4j_sanitize_values() -> None: output = graph.query("RETURN range(0,130,1) AS result") assert output == [{}] + + +def test_neo4j_add_data() -> None: + """Test that neo4j correctly import graph document.""" + url = os.environ.get("NEO4J_URI") + username = os.environ.get("NEO4J_USERNAME") + password = os.environ.get("NEO4J_PASSWORD") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True) + # Delete all nodes in the graph + graph.query("MATCH (n) DETACH DELETE n") + # Remove all constraints + graph.query("CALL apoc.schema.assert({}, {})") + graph.refresh_schema() + # Create two nodes and a relationship + graph.add_graph_documents(test_data) + output = graph.query( + "MATCH (n) RETURN labels(n) AS label, count(*) AS count ORDER BY label" + ) + assert output == [{"label": ["bar"], "count": 1}, {"label": ["foo"], "count": 1}] + assert graph.structured_schema["metadata"]["constraint"] == [] + + +def test_neo4j_add_data_source() -> None: + """Test that neo4j correctly import graph document with source.""" + url = os.environ.get("NEO4J_URI") + username = os.environ.get("NEO4J_USERNAME") + password = os.environ.get("NEO4J_PASSWORD") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True) + # Delete all nodes in the graph + graph.query("MATCH (n) DETACH DELETE n") + # Remove all constraints + graph.query("CALL apoc.schema.assert({}, {})") + graph.refresh_schema() + # Create two nodes and a relationship + graph.add_graph_documents(test_data, include_source=True) + output = graph.query( + "MATCH (n) RETURN labels(n) AS label, count(*) AS count ORDER BY label" + ) + assert output == [ + {"label": ["Document"], "count": 1}, + {"label": ["bar"], "count": 1}, + {"label": ["foo"], "count": 1}, + ] + assert graph.structured_schema["metadata"]["constraint"] == [] + + +def test_neo4j_add_data_base() -> None: + """Test that neo4j correctly import graph document with base_entity.""" + url = os.environ.get("NEO4J_URI") + username = os.environ.get("NEO4J_USERNAME") + password = os.environ.get("NEO4J_PASSWORD") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True) + # Delete all nodes in the graph + graph.query("MATCH (n) DETACH DELETE n") + # Remove all constraints + graph.query("CALL apoc.schema.assert({}, {})") + graph.refresh_schema() + # Create two nodes and a relationship + graph.add_graph_documents(test_data, baseEntityLabel=True) + output = graph.query( + "MATCH (n) RETURN apoc.coll.sort(labels(n)) AS label, " + "count(*) AS count ORDER BY label" + ) + assert output == [ + {"label": [BASE_ENTITY_LABEL, "bar"], "count": 1}, + {"label": [BASE_ENTITY_LABEL, "foo"], "count": 1}, + ] + assert graph.structured_schema["metadata"]["constraint"] != [] + + +def test_neo4j_add_data_base_source() -> None: + """Test that neo4j correctly import graph document with base_entity and source.""" + url = os.environ.get("NEO4J_URI") + username = os.environ.get("NEO4J_USERNAME") + password = os.environ.get("NEO4J_PASSWORD") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True) + # Delete all nodes in the graph + graph.query("MATCH (n) DETACH DELETE n") + # Remove all constraints + graph.query("CALL apoc.schema.assert({}, {})") + graph.refresh_schema() + # Create two nodes and a relationship + graph.add_graph_documents(test_data, baseEntityLabel=True, include_source=True) + output = graph.query( + "MATCH (n) RETURN apoc.coll.sort(labels(n)) AS label, " + "count(*) AS count ORDER BY label" + ) + assert output == [ + {"label": ["Document"], "count": 1}, + {"label": [BASE_ENTITY_LABEL, "bar"], "count": 1}, + {"label": [BASE_ENTITY_LABEL, "foo"], "count": 1}, + ] + assert graph.structured_schema["metadata"]["constraint"] != []