Skip to content

Commit

Permalink
Fixes #4231: The apoc.vectordb.milvus.query* and apoc.vectordb.weavia…
Browse files Browse the repository at this point in the history
…te.query* procedures should get the fields config from metadataKey if present
  • Loading branch information
vga91 committed Nov 29, 2024
1 parent b7d8a60 commit aefcee1
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 9 deletions.
66 changes: 66 additions & 0 deletions extended-it/src/test/java/apoc/vectordb/MilvusTest.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package apoc.vectordb;

import apoc.ml.Prompt;
import apoc.util.ExtendedTestUtil;
import apoc.util.TestUtil;
import apoc.util.Util;
import org.junit.AfterClass;
Expand Down Expand Up @@ -42,11 +43,14 @@
import static apoc.vectordb.VectorMappingConfig.METADATA_KEY;
import static apoc.vectordb.VectorMappingConfig.MODE_KEY;
import static apoc.vectordb.VectorMappingConfig.MappingMode;
import static apoc.vectordb.VectorMappingConfig.NO_FIELDS_ERROR_MSG;
import static apoc.vectordb.VectorMappingConfig.NODE_LABEL;
import static apoc.vectordb.VectorMappingConfig.REL_TYPE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME;
import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME;

Expand Down Expand Up @@ -488,4 +492,66 @@ WITH collect(node) as paths
VectorDbTestUtil::assertRagWithVectors);
}

@Test
public void queryVectorsWithMetadataKeyNoFields() {
Map<String, Object> conf = map(
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
REL_TYPE, "TEST",
ENTITY_KEY, "readID",
METADATA_KEY, "foo"
)
);
testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)",
map("host", HOST, "conf", conf),
VectorDbTestUtil::assertMetadataOneResult);
}

@Test
public void queryVectorsWithMetadataKeyAndOneField() {
Map<String, Object> conf = map(
FIELDS_KEY, List.of("city"),
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
REL_TYPE, "TEST",
ENTITY_KEY, "readID",
METADATA_KEY, "foo"
)
);
testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)",
map("host", HOST, "conf", conf),
VectorDbTestUtil::assertMetadataOneResult);
}

@Test
public void queryVectorsWithNoMetadataKeyNoFields() {
Map<String, Object> params = map(
"host", HOST, "conf", Map.of(
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
REL_TYPE, "TEST",
ENTITY_KEY, "readID"
))
);
String query = "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";
ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG);
}

@Test
public void queryAndUpdateMetadataKeyWithoutFieldsTest() {
Map<String, Object> conf = map(
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
REL_TYPE, "TEST",
ENTITY_KEY, "readID",
METADATA_KEY, "foo"
)
);

String query = "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";

testResult(db, query,
map("host", HOST, "conf", conf),
VectorDbTestUtil::assertMetadataOneResult);
}
}
103 changes: 102 additions & 1 deletion extended-it/src/test/java/apoc/vectordb/WeaviateTest.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package apoc.vectordb;

import apoc.ml.Prompt;
import apoc.util.ExtendedTestUtil;
import apoc.util.MapUtil;
import apoc.util.TestUtil;
import org.junit.AfterClass;
import org.junit.Assume;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
Expand Down Expand Up @@ -562,4 +562,105 @@ WITH collect(node) as paths
),
VectorDbTestUtil::assertRagWithVectors);
}

@Test
public void queryVectorsWithMetadataKeyNoFields() {
testResult(db, "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
" YIELD score, vector, id, metadata RETURN * ORDER BY id",
map("host", HOST, "conf", map(
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
NODE_LABEL, "Test",
ENTITY_KEY, "myId",
METADATA_KEY, "foo"
),
HEADERS_KEY, ADMIN_AUTHORIZATION)),
VectorDbTestUtil::assertMetadataOneResult);
}

@Test
public void queryVectorsWithMetadataKeyAndOneField() {
testResult(db, "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
" YIELD score, vector, id, metadata RETURN * ORDER BY id",
map("host", HOST, "conf", map(
FIELDS_KEY, List.of("city"),
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
NODE_LABEL, "Test",
ENTITY_KEY, "myId",
METADATA_KEY, "foo"
),
HEADERS_KEY, ADMIN_AUTHORIZATION)),
VectorDbTestUtil::assertMetadataOneResult);
}

@Test
public void queryVectorsWithNoMetadataKeyNoFields() {
Map<String, Object> params = map("host", HOST, "conf", map(
ALL_RESULTS_KEY, true,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
NODE_LABEL, "Test",
ENTITY_KEY, "myId"
),
HEADERS_KEY, ADMIN_AUTHORIZATION));
String query = "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
" YIELD score, vector, id, metadata RETURN * ORDER BY id";
ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG);
}

@Test
public void queryVectorsWithCreateNodeUsingExistingNodeWithMetadataKeyAndOneField () {
db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");
Map<String, Object> conf = map(ALL_RESULTS_KEY, true,
FIELDS_KEY, List.of("city"),
HEADERS_KEY, ADMIN_AUTHORIZATION,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
NODE_LABEL, "Test",
ENTITY_KEY, "myId",
METADATA_KEY, "foo"));
testResult(db, "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
" YIELD score, vector, id, metadata, node RETURN * ORDER BY id",
map("host", HOST, "conf", conf),
r -> {
Map<String, Object> row = r.next();
assertBerlinResult(row, ID_1, NODE);
assertNotNull(row.get("score"));
assertNotNull(row.get("vector"));

row = r.next();
assertLondonResult(row, ID_2, NODE);
assertNotNull(row.get("score"));
assertNotNull(row.get("vector"));
});

assertNodesCreated(db);
}

@Test
public void queryVectorsWithCreateNodeUsingExistingNodeWithMetadataKeyAndNoFields () {
db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");
Map<String, Object> conf = map(ALL_RESULTS_KEY, true,
HEADERS_KEY, ADMIN_AUTHORIZATION,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
NODE_LABEL, "Test",
ENTITY_KEY, "myId",
METADATA_KEY, "foo"));
testResult(db, "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
" YIELD score, vector, id, metadata, node RETURN * ORDER BY id",
map("host", HOST, "conf", conf),
VectorDbTestUtil::assertMetadataOneResult);
}

@Test
public void queryVectorsWithCreateNodeUsingExistingNodeFailWithNoMetadataKeyAndNoFields () {
db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");
Map<String, Object> params = map("host", HOST,
"conf", Map.of(ALL_RESULTS_KEY, true,
HEADERS_KEY, ADMIN_AUTHORIZATION,
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
NODE_LABEL, "Test",
ENTITY_KEY, "myId")));
String query = "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) YIELD score, vector, id, metadata, node RETURN * ORDER BY id";
ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG);
}
}
12 changes: 8 additions & 4 deletions extended/src/main/java/apoc/vectordb/MilvusHandler.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package apoc.vectordb;

import apoc.util.CollectionUtils;
import apoc.util.UrlResolver;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;

import java.util.List;
import java.util.Map;

import static apoc.util.MapUtil.map;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static apoc.vectordb.VectorDbUtil.addMetadataKeyToFields;
import static apoc.vectordb.VectorEmbeddingConfig.META_AS_SUBKEY_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.SCORE_KEY;
import static apoc.vectordb.VectorMappingConfig.NO_FIELDS_ERROR_MSG;

public class MilvusHandler implements VectorDbHandler {

Expand Down Expand Up @@ -57,10 +59,12 @@ public VectorEmbeddingConfig fromQuery(Map<String, Object> config, ProcedureCall
private VectorEmbeddingConfig getVectorEmbeddingConfig(Map<String, Object> config, List<String> procFields, String collection, Map<String, Object> additionalBodies) {
config.putIfAbsent(META_AS_SUBKEY_KEY, false);

List listFields = (List) config.get(FIELDS_KEY);
if (listFields == null) {
throw new RuntimeException("You have to define `field` list of parameter to be returned");
List listFields = addMetadataKeyToFields(config);

if (CollectionUtils.isEmpty(listFields)) {
throw new RuntimeException(NO_FIELDS_ERROR_MSG);
}

if (procFields.contains("vector") && !listFields.contains("vector")) {
listFields.add("vector");
}
Expand Down
18 changes: 18 additions & 0 deletions extended/src/main/java/apoc/vectordb/VectorDbUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import apoc.SystemPropertyKeys;
import apoc.util.Util;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.graphdb.Label;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Relationship;

import java.net.HttpURLConnection;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -20,7 +22,9 @@
import static apoc.ml.RestAPIConfig.ENDPOINT_KEY;
import static apoc.ml.RestAPIConfig.METHOD_KEY;
import static apoc.util.SystemDbUtil.withSystemDb;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY;
import static apoc.vectordb.VectorMappingConfig.MODE_KEY;
import static apoc.vectordb.VectorMappingConfig.MappingMode.READ_ONLY;

Expand Down Expand Up @@ -130,4 +134,18 @@ public static void methodAndPayloadNull(Map<String, Object> config) {
config.put(METHOD_KEY, null);
config.put(BODY_KEY, null);
}

public static List addMetadataKeyToFields(Map<String, Object> config) {
List listFields = (List) config.getOrDefault(FIELDS_KEY, new ArrayList<>());

Map<String, Object> mapping = (Map<String, Object>) config.get(MAPPING_KEY);

String metadataKey = MapUtils.getString(mapping, METADATA_KEY);

if (StringUtils.isNotEmpty(metadataKey)) {
listFields.add(metadataKey);
}

return listFields;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ enum MappingMode {
public static final String EMBEDDING_KEY = "embeddingKey";
public static final String SIMILARITY_KEY = "similarity";
public static final String MODE_KEY = "mode";
public static final String NO_FIELDS_ERROR_MSG = "You have to define `field` list of parameter to be returned";

private final String metadataKey;
private final String entityKey;
Expand Down
12 changes: 8 additions & 4 deletions extended/src/main/java/apoc/vectordb/WeaviateHandler.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package apoc.vectordb;

import apoc.util.CollectionUtils;
import apoc.util.UrlResolver;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;

Expand All @@ -9,9 +10,10 @@
import static apoc.ml.RestAPIConfig.BODY_KEY;
import static apoc.ml.RestAPIConfig.METHOD_KEY;
import static apoc.util.MapUtil.map;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static apoc.vectordb.VectorDbUtil.addMetadataKeyToFields;
import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY;
import static apoc.vectordb.VectorMappingConfig.NO_FIELDS_ERROR_MSG;

public class WeaviateHandler implements VectorDbHandler {

Expand Down Expand Up @@ -47,10 +49,12 @@ public VectorEmbeddingConfig fromQuery(Map<String, Object> config, ProcedureCall
config.putIfAbsent(METHOD_KEY, "POST");
VectorEmbeddingConfig vectorEmbeddingConfig = getVectorEmbeddingConfig(config);

List list = (List) config.get(FIELDS_KEY);
if (list == null) {
throw new RuntimeException("You have to define `field` list of parameter to be returned");
List list = addMetadataKeyToFields(config);

if (CollectionUtils.isEmpty(list)) {
throw new RuntimeException(NO_FIELDS_ERROR_MSG);
}

Object fieldList = String.join("\n", list);

filter = filter == null
Expand Down
10 changes: 10 additions & 0 deletions extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.util.Map;

import static apoc.vectordb.VectorEmbeddingConfig.DEFAULT_METADATA;
import static apoc.util.TestUtil.testResult;
import static apoc.util.Util.map;
import static org.junit.Assert.assertEquals;
Expand Down Expand Up @@ -114,4 +115,13 @@ public static String ragSetup(GraphDatabaseService db) {
db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})");
return openAIKey;
}

public static void assertMetadataOneResult(Result r) {
Map<String, Object> row = r.next();
Map<String, Object> metadata = (Map<String, Object>) row.get(DEFAULT_METADATA);
assertNotNull(metadata);
assertEquals("one", metadata.get("foo"));
row = r.next();
assertFalse(r.hasNext());
}
}

0 comments on commit aefcee1

Please sign in to comment.