Skip to content

Commit

Permalink
[NOID] Fixes #4231: The apoc.vectordb.milvus.query* and apoc.vectordb…
Browse files Browse the repository at this point in the history
….weaviate.query* procedures should get the fields config from metadataKey if present (#4241)

* Fixes #4231: The apoc.vectordb.milvus.query* and apoc.vectordb.weaviate.query* procedures should get the fields config from metadataKey if present

* test fixes and changes review

* fix tests
  • Loading branch information
vga91 committed Dec 18, 2024
1 parent d394eab commit cde1151
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 12 deletions.
99 changes: 99 additions & 0 deletions full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package apoc.full.it.vectordb;

import apoc.ml.Prompt;
import apoc.util.ExtendedTestUtil;
import apoc.util.TestUtil;
import apoc.util.Util;
import apoc.vectordb.Milvus;
Expand All @@ -21,6 +23,8 @@
import java.util.List;
import java.util.Map;

import static apoc.ml.Prompt.API_KEY_CONF;
import static apoc.ml.RestAPIConfig.HEADERS_KEY;
import static apoc.util.MapUtil.map;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
Expand All @@ -33,6 +37,7 @@
import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated;
import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated;
import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll;
import static apoc.vectordb.VectorDbTestUtil.ragSetup;
import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING;
import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
Expand All @@ -41,8 +46,11 @@
import static apoc.vectordb.VectorMappingConfig.ENTITY_KEY;
import static apoc.vectordb.VectorMappingConfig.METADATA_KEY;
import static apoc.vectordb.VectorMappingConfig.MODE_KEY;
import static apoc.vectordb.VectorMappingConfig.MappingMode;
import static apoc.vectordb.VectorMappingConfig.NO_FIELDS_ERROR_MSG;
import static apoc.vectordb.VectorMappingConfig.NODE_LABEL;
import static apoc.vectordb.VectorMappingConfig.REL_TYPE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
Expand Down Expand Up @@ -406,4 +414,95 @@ public void queryVectorsWithSystemDbStorage() {
assertNodesCreated(db);
}

@Test
public void queryVectorsWithRag() {
String openAIKey = ragSetup(db);

Map<String, Object> conf = map(
FIELDS_KEY, FIELDS,
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(NODE_LABEL, "Rag",
ENTITY_KEY, "readID",
METADATA_KEY, "foo")
);

testResult(db,
"CALL apoc.vectordb.milvus.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector\n" +
"WITH collect(node) as paths\n" +
"CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" +
"RETURN value"
,
map(
"host", HOST,
"conf", conf,
"confPrompt", map(API_KEY_CONF, openAIKey),
"attributes", List.of("city", "foo")
),
VectorDbTestUtil::assertRagWithVectors);
}

@Test
public void queryVectorsWithMetadataKeyNoFields() {
Map<String, Object> conf = map(
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
REL_TYPE, "TEST",
ENTITY_KEY, "readID",
METADATA_KEY, "foo"
)
);
testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)",
map("host", HOST, "conf", conf),
VectorDbTestUtil::assertMetadataFooResult);
}

@Test
public void queryVectorsWithNoMetadataKeyNoFields() {
Map<String, Object> params = map(
"host", HOST, "conf", Map.of(
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
REL_TYPE, "TEST",
ENTITY_KEY, "readID"
))
);
String query = "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";
ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG);
}

@Test
public void queryAndUpdateMetadataKeyWithoutFieldsTest() {
Map<String, Object> conf = map(
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
REL_TYPE, "TEST",
ENTITY_KEY, "readID",
METADATA_KEY, "foo"
)
);

String query = "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";

testResult(db, query,
map("host", HOST, "conf", conf),
VectorDbTestUtil::assertMetadataFooResult);
}

@Test
public void queryAndUpdateWithNoMetadataKeyNoFields() {
Map<String, Object> conf = map(
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
REL_TYPE, "TEST",
ENTITY_KEY, "readID"
)
);
db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");
Map<String, Object> params = Util.map("host", HOST,
"conf", conf);

String query = "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";

ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG);
}
}
77 changes: 77 additions & 0 deletions full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,25 @@
package apoc.full.it.vectordb;

import apoc.ml.Prompt;
import apoc.util.ExtendedTestUtil;
import apoc.util.MapUtil;
import apoc.util.TestUtil;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.neo4j.dbms.api.DatabaseManagementService;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.ResourceIterator;
import org.neo4j.test.TestDatabaseManagementServiceBuilder;
import org.testcontainers.weaviate.WeaviateContainer;

import java.util.List;
import java.util.Map;

import static apoc.ml.Prompt.API_KEY_CONF;
import static apoc.ml.RestAPIConfig.HEADERS_KEY;
import static apoc.util.ExtendedTestUtil.assertFails;
Expand Down Expand Up @@ -658,4 +678,61 @@ public void queryVectorsWithRag() {
"attributes", List.of("city", "foo")),
VectorDbTestUtil::assertRagWithVectors);
}

@Test
public void queryVectorsWithMetadataKeyNoFields() {
testResult(db, "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
" YIELD score, vector, id, metadata RETURN * ORDER BY id",
map("host", HOST, "conf", map(
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
NODE_LABEL, "Test",
ENTITY_KEY, "myId",
METADATA_KEY, "foo"
),
HEADERS_KEY, ADMIN_AUTHORIZATION)),
VectorDbTestUtil::assertMetadataFooResult);
}

@Test
public void queryVectorsWithNoMetadataKeyNoFields() {
Map<String, Object> params = map("host", HOST, "conf", map(
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
NODE_LABEL, "Test",
ENTITY_KEY, "myId"
),
HEADERS_KEY, ADMIN_AUTHORIZATION));
String query = "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
" YIELD score, vector, id, metadata RETURN * ORDER BY id";
ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG);
}

@Test
public void queryAndUpdateMetadataKeyWithoutFieldsTest() {
db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");
Map<String, Object> conf = map(ALL_RESULTS_KEY, true,
HEADERS_KEY, ADMIN_AUTHORIZATION,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
NODE_LABEL, "Test",
ENTITY_KEY, "myId",
METADATA_KEY, "foo"));
testResult(db, "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
" YIELD score, vector, id, metadata, node RETURN * ORDER BY id",
map("host", HOST, "conf", conf),
VectorDbTestUtil::assertMetadataFooResult);
}

@Test
public void queryAndUpdateWithCreateNodeUsingExistingNodeFailWithNoMetadataKeyAndNoFields() {
db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");
Map<String, Object> params = map("host", HOST,
"conf", Map.of(ALL_RESULTS_KEY, true,
HEADERS_KEY, ADMIN_AUTHORIZATION,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
NODE_LABEL, "Test",
ENTITY_KEY, "myId")));
String query = "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) YIELD score, vector, id, metadata, node RETURN * ORDER BY id";
ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG);
}
}
14 changes: 7 additions & 7 deletions full/src/main/java/apoc/vectordb/MilvusHandler.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package apoc.vectordb;

import static apoc.util.MapUtil.map;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static apoc.vectordb.VectorDbUtil.addMetadataKeyToFields;
import static apoc.vectordb.VectorEmbeddingConfig.META_AS_SUBKEY_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.SCORE_KEY;

import apoc.util.UrlResolver;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;

public class MilvusHandler implements VectorDbHandler {
Expand All @@ -34,7 +36,7 @@ static class MilvusEmbeddingHandler implements VectorEmbeddingHandler {
@Override
public <T> VectorEmbeddingConfig fromGet(
Map<String, Object> config, ProcedureCallContext procedureCallContext, List<T> ids, String collection) {
List<String> fields = procedureCallContext.outputFields().toList();
List<String> fields = procedureCallContext.outputFields().collect(Collectors.toList());
Map<String, Object> additionalBodies = map("id", ids);

return getVectorEmbeddingConfig(config, fields, collection, additionalBodies);
Expand All @@ -50,7 +52,7 @@ public VectorEmbeddingConfig fromQuery(
String collection) {
config.putIfAbsent(SCORE_KEY, "distance");

List<String> fields = procedureCallContext.outputFields().toList();
List<String> fields = procedureCallContext.outputFields().collect(Collectors.toList());
Map<String, Object> additionalBodies = map("data", List.of(vector), "limit", limit);
if (filter != null) {
additionalBodies.put("filter", filter);
Expand All @@ -66,10 +68,8 @@ private VectorEmbeddingConfig getVectorEmbeddingConfig(
Map<String, Object> additionalBodies) {
config.putIfAbsent(META_AS_SUBKEY_KEY, false);

List listFields = (List) config.get(FIELDS_KEY);
if (listFields == null) {
throw new RuntimeException("You have to define `field` list of parameter to be returned");
}
List listFields = addMetadataKeyToFields(config);

if (procFields.contains("vector") && !listFields.contains("vector")) {
listFields.add("vector");
}
Expand Down
27 changes: 27 additions & 0 deletions full/src/main/java/apoc/vectordb/VectorDbUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@
import static apoc.ml.RestAPIConfig.BASE_URL_KEY;
import static apoc.ml.RestAPIConfig.ENDPOINT_KEY;
import static apoc.util.SystemDbUtil.withSystemDb;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY;
import static apoc.vectordb.VectorMappingConfig.METADATA_KEY;
import static apoc.vectordb.VectorMappingConfig.NO_FIELDS_ERROR_MSG;

import apoc.SystemPropertyKeys;
import apoc.util.Util;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.graphdb.Label;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Relationship;
Expand Down Expand Up @@ -111,6 +118,26 @@ public static void checkMappingConf(Map<String, Object> configuration, String pr
}
}

public static List addMetadataKeyToFields(Map<String, Object> config) {
List listFields = (List) config.getOrDefault(FIELDS_KEY, new ArrayList<>());

Map<String, Object> mapping = (Map<String, Object>) config.get(MAPPING_KEY);

String metadataKey = mapping == null
? null
: (String) mapping.get(METADATA_KEY);

if (CollectionUtils.isEmpty(listFields)) {

if (StringUtils.isEmpty(metadataKey)) {
throw new RuntimeException(NO_FIELDS_ERROR_MSG);
}
listFields.add(metadataKey);
}

return listFields;
}

/**
* If the vectorDb is WEAVIATE and endpoint doesn't end with `/vN`, where N is a number,
* then add `/v1` to the endpoint
Expand Down
8 changes: 3 additions & 5 deletions full/src/main/java/apoc/vectordb/WeaviateHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import static apoc.ml.RestAPIConfig.BODY_KEY;
import static apoc.ml.RestAPIConfig.METHOD_KEY;
import static apoc.util.MapUtil.map;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static apoc.vectordb.VectorDbUtil.addMetadataKeyToFields;
import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY;

Expand Down Expand Up @@ -53,10 +53,8 @@ public VectorEmbeddingConfig fromQuery(
config.putIfAbsent(METHOD_KEY, "POST");
VectorEmbeddingConfig vectorEmbeddingConfig = getVectorEmbeddingConfig(config);

List list = (List) config.get(FIELDS_KEY);
if (list == null) {
throw new RuntimeException("You have to define `field` list of parameter to be returned");
}
List list = addMetadataKeyToFields(config);

Object fieldList = String.join("\n", list);

filter = filter == null ? "" : ", where: " + filter;
Expand Down
20 changes: 20 additions & 0 deletions full/src/test/java/apoc/vectordb/VectorDbTestUtil.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
package apoc.vectordb;

import apoc.util.MapUtil;
import org.junit.Assume;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.ResourceIterator;
import org.neo4j.graphdb.Result;

import java.util.Map;

import static apoc.vectordb.VectorEmbeddingConfig.DEFAULT_METADATA;
import static apoc.util.TestUtil.testResult;
import static apoc.util.Util.map;
import static org.junit.Assert.assertEquals;
Expand Down Expand Up @@ -122,4 +132,14 @@ public static String ragSetup(GraphDatabaseService db) {
db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})");
return openAIKey;
}

public static void assertMetadataFooResult(Result r) {
Map<String, Object> row = r.next();
Map<String, Object> metadata = (Map<String, Object>) row.get(DEFAULT_METADATA);
assertEquals("one", metadata.get("foo"));
row = r.next();
metadata = (Map<String, Object>) row.get(DEFAULT_METADATA);
assertEquals("two", metadata.get("foo"));
assertFalse(r.hasNext());
}
}

0 comments on commit cde1151

Please sign in to comment.