Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an option for indexed generic label when import neo4j graph documents #18122

Merged
merged 9 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 131 additions & 28 deletions libs/community/langchain_community/graphs/neo4j_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
from langchain_community.graphs.graph_document import GraphDocument
from langchain_community.graphs.graph_store import GraphStore

BASE_ENTITY_LABEL = "__Entity__"

node_properties_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
AND NOT label IN [$BASE_ENTITY_LABEL]
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
RETURN {labels: nodeLabels, properties: properties} AS output

Expand All @@ -27,9 +30,18 @@
YIELD label, other, elementType, type, property
WHERE type = "RELATIONSHIP" AND elementType = "node"
UNWIND other AS other_node
WITH * WHERE NOT label IN [$BASE_ENTITY_LABEL]
AND NOT other_node IN [$BASE_ENTITY_LABEL]
RETURN {start: label, type: property, end: toString(other_node)} AS output
"""

include_docs_query = (
"CREATE (d:Document) "
"SET d.text = $document.page_content "
"SET d += $document.metadata "
"WITH d "
)


def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]:
"""Sanitize the input dictionary.
Expand Down Expand Up @@ -63,6 +75,53 @@ def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]:
return new_dict


def _get_node_import_query(baseEntityLabel: bool, include_source: bool) -> str:
if baseEntityLabel:
return (
f"{include_docs_query if include_source else ''}"
"UNWIND $data AS row "
f"MERGE (source:`{BASE_ENTITY_LABEL}` {{id: row.id}}) "
"SET source += row.properties "
f"{'MERGE (d)-[:MENTIONS]->(source) ' if include_source else ''}"
"WITH source, row "
"CALL apoc.create.addLabels( source, [row.type] ) YIELD node "
"RETURN distinct 'done' AS result"
)
else:
return (
f"{include_docs_query if include_source else ''}"
"UNWIND $data AS row "
"CALL apoc.merge.node([row.type], {id: row.id}, "
"row.properties, {}) YIELD node "
f"{'MERGE (d)-[:MENTIONS]->(node) ' if include_source else ''}"
"RETURN distinct 'done' AS result"
)


def _get_rel_import_query(baseEntityLabel: bool) -> str:
if baseEntityLabel:
return (
"UNWIND $data AS row "
f"MERGE (source:`{BASE_ENTITY_LABEL}` {{id: row.source}}) "
f"MERGE (target:`{BASE_ENTITY_LABEL}` {{id: row.target}}) "
"WITH source, target, row "
"CALL apoc.merge.relationship(source, row.type, "
"{}, row.properties, target) YIELD rel "
"RETURN distinct 'done'"
)
else:
return (
"UNWIND $data AS row "
"CALL apoc.merge.node([row.source_label], {id: row.source},"
"{}, {}) YIELD node as source "
"CALL apoc.merge.node([row.target_label], {id: row.target},"
"{}, {}) YIELD node as target "
"CALL apoc.merge.relationship(source, row.type, "
"{}, row.properties, target) YIELD rel "
"RETURN distinct 'done'"
)


class Neo4jGraph(GraphStore):
"""Neo4j database wrapper for various graph operations.

Expand Down Expand Up @@ -173,14 +232,42 @@ def refresh_schema(self) -> None:
"""
Refreshes the Neo4j graph schema information.
"""
node_properties = [el["output"] for el in self.query(node_properties_query)]
rel_properties = [el["output"] for el in self.query(rel_properties_query)]
relationships = [el["output"] for el in self.query(rel_query)]
from neo4j.exceptions import ClientError

node_properties = [
el["output"]
for el in self.query(
node_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
)
]
rel_properties = [
el["output"]
for el in self.query(
rel_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
)
]
relationships = [
el["output"]
for el in self.query(
rel_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
)
]

# Get constraints & indexes
try:
constraint = self.query("SHOW CONSTRAINTS")
index = self.query("SHOW INDEXES YIELD *")
except (
ClientError
): # Read-only user might not have access to schema information
constraint = []
index = []

self.structured_schema = {
"node_props": {el["labels"]: el["properties"] for el in node_properties},
"rel_props": {el["type"]: el["properties"] for el in rel_properties},
"relationships": relationships,
"metadata": {"constraint": constraint, "index": index},
}

# Format node properties
Expand Down Expand Up @@ -216,43 +303,59 @@ def refresh_schema(self) -> None:
)

def add_graph_documents(
self, graph_documents: List[GraphDocument], include_source: bool = False
self,
graph_documents: List[GraphDocument],
include_source: bool = False,
baseEntityLabel: bool = False,
) -> None:
"""
Take GraphDocument as input as uses it to construct a graph.
This method constructs nodes and relationships in the graph based on the
provided GraphDocument objects.

Parameters:
- graph_documents (List[GraphDocument]): A list of GraphDocument objects
that contain the nodes and relationships to be added to the graph. Each
GraphDocument should encapsulate the structure of part of the graph,
including nodes, relationships, and the source document information.
- include_source (bool, optional): If True, stores the source document
and links it to nodes in the graph using the MENTIONS relationship.
This is useful for tracing back the origin of data. Defaults to False.
- baseEntityLabel (bool, optional): If True, each newly created node
gets a secondary __Entity__ label, which is indexed and improves import
speed and performance. Defaults to False.
"""
for document in graph_documents:
include_docs_query = (
"CREATE (d:Document) "
"SET d.text = $document.page_content "
"SET d += $document.metadata "
"WITH d "
if baseEntityLabel: # Check if constraint already exists
constraint_exists = any(
[
el["labelsOrTypes"] == [BASE_ENTITY_LABEL]
and el["properties"] == ["id"]
for el in self.structured_schema.get("metadata", {}).get(
"constraint"
)
]
)
if not constraint_exists:
# Create constraint
self.query(
f"CREATE CONSTRAINT IF NOT EXISTS FOR (b:{BASE_ENTITY_LABEL}) "
"REQUIRE b.id IS UNIQUE;"
)
self.refresh_schema() # Refresh constraint information

node_import_query = _get_node_import_query(baseEntityLabel, include_source)
rel_import_query = _get_rel_import_query(baseEntityLabel)
for document in graph_documents:
# Import nodes
self.query(
(
f"{include_docs_query if include_source else ''}"
"UNWIND $data AS row "
"CALL apoc.merge.node([row.type], {id: row.id}, "
"row.properties, {}) YIELD node "
f"{'MERGE (d)-[:MENTIONS]->(node) ' if include_source else ''}"
"RETURN distinct 'done' AS result"
),
node_import_query,
{
"data": [el.__dict__ for el in document.nodes],
"document": document.source.__dict__,
},
)
# Import relationships
self.query(
"UNWIND $data AS row "
"CALL apoc.merge.node([row.source_label], {id: row.source},"
"{}, {}) YIELD node as source "
"CALL apoc.merge.node([row.target_label], {id: row.target},"
"{}, {}) YIELD node as target "
"CALL apoc.merge.relationship(source, row.type, "
"{}, row.properties, target) YIELD rel "
"RETURN distinct 'done'",
rel_import_query,
{
"data": [
{
Expand Down
139 changes: 136 additions & 3 deletions libs/community/tests/integration_tests/graphs/test_neo4j.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
import os

from langchain_core.documents import Document

from langchain_community.graphs import Neo4jGraph
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_community.graphs.neo4j_graph import (
BASE_ENTITY_LABEL,
node_properties_query,
rel_properties_query,
rel_query,
)

test_data = [
GraphDocument(
nodes=[Node(id="foo", type="foo"), Node(id="bar", type="bar")],
relationships=[
Relationship(
source=Node(id="foo", type="foo"),
target=Node(id="bar", type="bar"),
type="REL",
)
],
source=Document(page_content="source document"),
)
]


def test_cypher_return_correct_schema() -> None:
"""Test that chain returns direct results."""
Expand Down Expand Up @@ -37,9 +55,15 @@ def test_cypher_return_correct_schema() -> None:
# Refresh schema information
graph.refresh_schema()

node_properties = graph.query(node_properties_query)
relationships_properties = graph.query(rel_properties_query)
relationships = graph.query(rel_query)
node_properties = graph.query(
node_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
)
relationships_properties = graph.query(
rel_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
)
relationships = graph.query(
rel_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
)

expected_node_properties = [
{
Expand Down Expand Up @@ -116,3 +140,112 @@ def test_neo4j_sanitize_values() -> None:

output = graph.query("RETURN range(0,130,1) AS result")
assert output == [{}]


def test_neo4j_add_data() -> None:
"""Test that neo4j correctly import graph document."""
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None

graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True)
# Delete all nodes in the graph
graph.query("MATCH (n) DETACH DELETE n")
# Remove all constraints
graph.query("CALL apoc.schema.assert({}, {})")
graph.refresh_schema()
# Create two nodes and a relationship
graph.add_graph_documents(test_data)
output = graph.query(
"MATCH (n) RETURN labels(n) AS label, count(*) AS count ORDER BY label"
)
assert output == [{"label": ["bar"], "count": 1}, {"label": ["foo"], "count": 1}]
assert graph.structured_schema["metadata"]["constraint"] == []


def test_neo4j_add_data_source() -> None:
"""Test that neo4j correctly import graph document with source."""
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None

graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True)
# Delete all nodes in the graph
graph.query("MATCH (n) DETACH DELETE n")
# Remove all constraints
graph.query("CALL apoc.schema.assert({}, {})")
graph.refresh_schema()
# Create two nodes and a relationship
graph.add_graph_documents(test_data, include_source=True)
output = graph.query(
"MATCH (n) RETURN labels(n) AS label, count(*) AS count ORDER BY label"
)
assert output == [
{"label": ["Document"], "count": 1},
{"label": ["bar"], "count": 1},
{"label": ["foo"], "count": 1},
]
assert graph.structured_schema["metadata"]["constraint"] == []


def test_neo4j_add_data_base() -> None:
"""Test that neo4j correctly import graph document with base_entity."""
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None

graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True)
# Delete all nodes in the graph
graph.query("MATCH (n) DETACH DELETE n")
# Remove all constraints
graph.query("CALL apoc.schema.assert({}, {})")
graph.refresh_schema()
# Create two nodes and a relationship
graph.add_graph_documents(test_data, baseEntityLabel=True)
output = graph.query(
"MATCH (n) RETURN apoc.coll.sort(labels(n)) AS label, "
"count(*) AS count ORDER BY label"
)
assert output == [
{"label": [BASE_ENTITY_LABEL, "bar"], "count": 1},
{"label": [BASE_ENTITY_LABEL, "foo"], "count": 1},
]
assert graph.structured_schema["metadata"]["constraint"] != []


def test_neo4j_add_data_base_source() -> None:
"""Test that neo4j correctly import graph document with base_entity and source."""
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None

graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True)
# Delete all nodes in the graph
graph.query("MATCH (n) DETACH DELETE n")
# Remove all constraints
graph.query("CALL apoc.schema.assert({}, {})")
graph.refresh_schema()
# Create two nodes and a relationship
graph.add_graph_documents(test_data, baseEntityLabel=True, include_source=True)
output = graph.query(
"MATCH (n) RETURN apoc.coll.sort(labels(n)) AS label, "
"count(*) AS count ORDER BY label"
)
assert output == [
{"label": ["Document"], "count": 1},
{"label": [BASE_ENTITY_LABEL, "bar"], "count": 1},
{"label": [BASE_ENTITY_LABEL, "foo"], "count": 1},
]
assert graph.structured_schema["metadata"]["constraint"] != []
Loading