diff --git a/full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java b/full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java index 893f5064d8..e236dccd16 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java @@ -1,5 +1,7 @@ package apoc.full.it.vectordb; +import apoc.ml.Prompt; +import apoc.util.ExtendedTestUtil; import apoc.util.TestUtil; import apoc.util.Util; import apoc.vectordb.Milvus; @@ -21,6 +23,8 @@ import java.util.List; import java.util.Map; +import static apoc.ml.Prompt.API_KEY_CONF; +import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.MapUtil.map; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; @@ -33,6 +37,7 @@ import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorDbTestUtil.ragSetup; import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; @@ -41,8 +46,11 @@ import static apoc.vectordb.VectorMappingConfig.ENTITY_KEY; import static apoc.vectordb.VectorMappingConfig.METADATA_KEY; import static apoc.vectordb.VectorMappingConfig.MODE_KEY; +import static apoc.vectordb.VectorMappingConfig.MappingMode; +import static apoc.vectordb.VectorMappingConfig.NO_FIELDS_ERROR_MSG; import static apoc.vectordb.VectorMappingConfig.NODE_LABEL; import static apoc.vectordb.VectorMappingConfig.REL_TYPE; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; @@ -406,4 +414,95 @@ public void queryVectorsWithSystemDbStorage() { assertNodesCreated(db); } + @Test + public void queryVectorsWithRag() { + String openAIKey = ragSetup(db); + + Map conf = map( + FIELDS_KEY, FIELDS, + ALL_RESULTS_KEY, true, + MAPPING_KEY, map(NODE_LABEL, "Rag", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, + "CALL apoc.vectordb.milvus.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector\n" + + "WITH collect(node) as paths\n" + + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" + + "RETURN value" + , + map( + "host", HOST, + "conf", conf, + "confPrompt", map(API_KEY_CONF, openAIKey), + "attributes", List.of("city", "foo") + ), + VectorDbTestUtil::assertRagWithVectors); + } + + @Test + public void queryVectorsWithMetadataKeyNoFields() { + Map conf = map( + ALL_RESULTS_KEY, true, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + REL_TYPE, "TEST", + ENTITY_KEY, "readID", + METADATA_KEY, "foo" + ) + ); + testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", + map("host", HOST, "conf", conf), + VectorDbTestUtil::assertMetadataFooResult); + } + + @Test + public void queryVectorsWithNoMetadataKeyNoFields() { + Map params = map( + "host", HOST, "conf", Map.of( + ALL_RESULTS_KEY, true, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + REL_TYPE, "TEST", + ENTITY_KEY, "readID" + )) + ); + String query = "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)"; + ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG); + } + + @Test + public void queryAndUpdateMetadataKeyWithoutFieldsTest() { + Map conf = map( + ALL_RESULTS_KEY, true, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + REL_TYPE, "TEST", + ENTITY_KEY, "readID", + METADATA_KEY, "foo" + ) + ); + + String query = "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)"; + + testResult(db, query, + map("host", HOST, "conf", conf), + VectorDbTestUtil::assertMetadataFooResult); + } + + @Test + public void queryAndUpdateWithNoMetadataKeyNoFields() { + Map conf = map( + ALL_RESULTS_KEY, true, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + REL_TYPE, "TEST", + ENTITY_KEY, "readID" + ) + ); + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + Map params = Util.map("host", HOST, + "conf", conf); + + String query = "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)"; + + ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG); + } } diff --git a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java index 5f7cff3858..710889521e 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java @@ -1,5 +1,25 @@ package apoc.full.it.vectordb; +import apoc.ml.Prompt; +import apoc.util.ExtendedTestUtil; +import apoc.util.MapUtil; +import apoc.util.TestUtil; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.neo4j.dbms.api.DatabaseManagementService; +import org.neo4j.graphdb.Entity; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.ResourceIterator; +import org.neo4j.test.TestDatabaseManagementServiceBuilder; +import org.testcontainers.weaviate.WeaviateContainer; + +import java.util.List; +import java.util.Map; + import static apoc.ml.Prompt.API_KEY_CONF; import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.ExtendedTestUtil.assertFails; @@ -658,4 +678,61 @@ public void queryVectorsWithRag() { "attributes", List.of("city", "foo")), VectorDbTestUtil::assertRagWithVectors); } + + @Test + public void queryVectorsWithMetadataKeyNoFields() { + testResult(db, "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + + " YIELD score, vector, id, metadata RETURN * ORDER BY id", + map("host", HOST, "conf", map( + ALL_RESULTS_KEY, true, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Test", + ENTITY_KEY, "myId", + METADATA_KEY, "foo" + ), + HEADERS_KEY, ADMIN_AUTHORIZATION)), + VectorDbTestUtil::assertMetadataFooResult); + } + + @Test + public void queryVectorsWithNoMetadataKeyNoFields() { + Map params = map("host", HOST, "conf", map( + ALL_RESULTS_KEY, true, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Test", + ENTITY_KEY, "myId" + ), + HEADERS_KEY, ADMIN_AUTHORIZATION)); + String query = "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + + " YIELD score, vector, id, metadata RETURN * ORDER BY id"; + ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG); + } + + @Test + public void queryAndUpdateMetadataKeyWithoutFieldsTest() { + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + Map conf = map(ALL_RESULTS_KEY, true, + HEADERS_KEY, ADMIN_AUTHORIZATION, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Test", + ENTITY_KEY, "myId", + METADATA_KEY, "foo")); + testResult(db, "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + + " YIELD score, vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "conf", conf), + VectorDbTestUtil::assertMetadataFooResult); + } + + @Test + public void queryAndUpdateWithCreateNodeUsingExistingNodeFailWithNoMetadataKeyAndNoFields() { + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + Map params = map("host", HOST, + "conf", Map.of(ALL_RESULTS_KEY, true, + HEADERS_KEY, ADMIN_AUTHORIZATION, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Test", + ENTITY_KEY, "myId"))); + String query = "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) YIELD score, vector, id, metadata, node RETURN * ORDER BY id"; + ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG); + } } diff --git a/full/src/main/java/apoc/vectordb/MilvusHandler.java b/full/src/main/java/apoc/vectordb/MilvusHandler.java index 2ea22544b1..9643fd1631 100644 --- a/full/src/main/java/apoc/vectordb/MilvusHandler.java +++ b/full/src/main/java/apoc/vectordb/MilvusHandler.java @@ -1,13 +1,15 @@ package apoc.vectordb; import static apoc.util.MapUtil.map; -import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; +import static apoc.vectordb.VectorDbUtil.addMetadataKeyToFields; import static apoc.vectordb.VectorEmbeddingConfig.META_AS_SUBKEY_KEY; import static apoc.vectordb.VectorEmbeddingConfig.SCORE_KEY; import apoc.util.UrlResolver; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; + import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; public class MilvusHandler implements VectorDbHandler { @@ -34,7 +36,7 @@ static class MilvusEmbeddingHandler implements VectorEmbeddingHandler { @Override public VectorEmbeddingConfig fromGet( Map config, ProcedureCallContext procedureCallContext, List ids, String collection) { - List fields = procedureCallContext.outputFields().toList(); + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); Map additionalBodies = map("id", ids); return getVectorEmbeddingConfig(config, fields, collection, additionalBodies); @@ -50,7 +52,7 @@ public VectorEmbeddingConfig fromQuery( String collection) { config.putIfAbsent(SCORE_KEY, "distance"); - List fields = procedureCallContext.outputFields().toList(); + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); Map additionalBodies = map("data", List.of(vector), "limit", limit); if (filter != null) { additionalBodies.put("filter", filter); @@ -66,10 +68,8 @@ private VectorEmbeddingConfig getVectorEmbeddingConfig( Map additionalBodies) { config.putIfAbsent(META_AS_SUBKEY_KEY, false); - List listFields = (List) config.get(FIELDS_KEY); - if (listFields == null) { - throw new RuntimeException("You have to define `field` list of parameter to be returned"); - } + List listFields = addMetadataKeyToFields(config); + if (procFields.contains("vector") && !listFields.contains("vector")) { listFields.add("vector"); } diff --git a/full/src/main/java/apoc/vectordb/VectorDbUtil.java b/full/src/main/java/apoc/vectordb/VectorDbUtil.java index 6a4540970d..54c8d9efa7 100644 --- a/full/src/main/java/apoc/vectordb/VectorDbUtil.java +++ b/full/src/main/java/apoc/vectordb/VectorDbUtil.java @@ -3,14 +3,21 @@ import static apoc.ml.RestAPIConfig.BASE_URL_KEY; import static apoc.ml.RestAPIConfig.ENDPOINT_KEY; import static apoc.util.SystemDbUtil.withSystemDb; +import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; +import static apoc.vectordb.VectorMappingConfig.METADATA_KEY; +import static apoc.vectordb.VectorMappingConfig.NO_FIELDS_ERROR_MSG; import apoc.SystemPropertyKeys; import apoc.util.Util; + +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.commons.collections.MapUtils; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import org.neo4j.graphdb.Label; import org.neo4j.graphdb.Node; import org.neo4j.graphdb.Relationship; @@ -111,6 +118,26 @@ public static void checkMappingConf(Map configuration, String pr } } + public static List addMetadataKeyToFields(Map config) { + List listFields = (List) config.getOrDefault(FIELDS_KEY, new ArrayList<>()); + + Map mapping = (Map) config.get(MAPPING_KEY); + + String metadataKey = mapping == null + ? null + : (String) mapping.get(METADATA_KEY); + + if (CollectionUtils.isEmpty(listFields)) { + + if (StringUtils.isEmpty(metadataKey)) { + throw new RuntimeException(NO_FIELDS_ERROR_MSG); + } + listFields.add(metadataKey); + } + + return listFields; + } + /** * If the vectorDb is WEAVIATE and endpoint doesn't end with `/vN`, where N is a number, * then add `/v1` to the endpoint diff --git a/full/src/main/java/apoc/vectordb/WeaviateHandler.java b/full/src/main/java/apoc/vectordb/WeaviateHandler.java index 0cf7d856c0..e4186b5d76 100644 --- a/full/src/main/java/apoc/vectordb/WeaviateHandler.java +++ b/full/src/main/java/apoc/vectordb/WeaviateHandler.java @@ -3,7 +3,7 @@ import static apoc.ml.RestAPIConfig.BODY_KEY; import static apoc.ml.RestAPIConfig.METHOD_KEY; import static apoc.util.MapUtil.map; -import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; +import static apoc.vectordb.VectorDbUtil.addMetadataKeyToFields; import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY; import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY; @@ -53,10 +53,8 @@ public VectorEmbeddingConfig fromQuery( config.putIfAbsent(METHOD_KEY, "POST"); VectorEmbeddingConfig vectorEmbeddingConfig = getVectorEmbeddingConfig(config); - List list = (List) config.get(FIELDS_KEY); - if (list == null) { - throw new RuntimeException("You have to define `field` list of parameter to be returned"); - } + List list = addMetadataKeyToFields(config); + Object fieldList = String.join("\n", list); filter = filter == null ? "" : ", where: " + filter; diff --git a/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java b/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java index c430ecb19c..bcd40e8ec9 100644 --- a/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java +++ b/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java @@ -1,5 +1,15 @@ package apoc.vectordb; +import apoc.util.MapUtil; +import org.junit.Assume; +import org.neo4j.graphdb.Entity; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.ResourceIterator; +import org.neo4j.graphdb.Result; + +import java.util.Map; + +import static apoc.vectordb.VectorEmbeddingConfig.DEFAULT_METADATA; import static apoc.util.TestUtil.testResult; import static apoc.util.Util.map; import static org.junit.Assert.assertEquals; @@ -122,4 +132,14 @@ public static String ragSetup(GraphDatabaseService db) { db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})"); return openAIKey; } + + public static void assertMetadataFooResult(Result r) { + Map row = r.next(); + Map metadata = (Map) row.get(DEFAULT_METADATA); + assertEquals("one", metadata.get("foo")); + row = r.next(); + metadata = (Map) row.get(DEFAULT_METADATA); + assertEquals("two", metadata.get("foo")); + assertFalse(r.hasNext()); + } }