Skip to content

Commit

Permalink
Fixes #4231: The apoc.vectordb.milvus.query* and apoc.vectordb.weavia…
Browse files Browse the repository at this point in the history
…te.query* procedures should get the fields config from metadataKey if present (#4241)

* Fixes #4231: The apoc.vectordb.milvus.query* and apoc.vectordb.weaviate.query* procedures should get the fields config from metadataKey if present

* test fixes and changes review

* fix tests
  • Loading branch information
vga91 authored Dec 10, 2024
1 parent 00d75d3 commit 5748e1c
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 10 deletions.
68 changes: 68 additions & 0 deletions extended-it/src/test/java/apoc/vectordb/MilvusTest.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package apoc.vectordb;

import apoc.ml.Prompt;
import apoc.util.ExtendedTestUtil;
import apoc.util.TestUtil;
import apoc.util.Util;
import org.junit.AfterClass;
Expand Down Expand Up @@ -42,11 +43,14 @@
import static apoc.vectordb.VectorMappingConfig.METADATA_KEY;
import static apoc.vectordb.VectorMappingConfig.MODE_KEY;
import static apoc.vectordb.VectorMappingConfig.MappingMode;
import static apoc.vectordb.VectorMappingConfig.NO_FIELDS_ERROR_MSG;
import static apoc.vectordb.VectorMappingConfig.NODE_LABEL;
import static apoc.vectordb.VectorMappingConfig.REL_TYPE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME;
import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME;

Expand Down Expand Up @@ -488,4 +492,68 @@ WITH collect(node) as paths
VectorDbTestUtil::assertRagWithVectors);
}

// Calls apoc.vectordb.milvus.query with a mapping that sets METADATA_KEY to "foo" and a
// config that contains no fields entry; rows are verified by assertMetadataFooResult.
@Test
public void queryVectorsWithMetadataKeyNoFields() {
    Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            REL_TYPE, "TEST",
            ENTITY_KEY, "readID",
            METADATA_KEY, "foo");
    Map<String, Object> params = map(
            "host", HOST,
            "conf", map(ALL_RESULTS_KEY, true, MAPPING_KEY, mapping));

    String cypher = "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";
    testResult(db, cypher, params, VectorDbTestUtil::assertMetadataFooResult);
}

// Calls apoc.vectordb.milvus.query with a mapping that has no METADATA_KEY and a config
// with no fields entry, and hands the query to ExtendedTestUtil.assertFails together with
// NO_FIELDS_ERROR_MSG.
@Test
public void queryVectorsWithNoMetadataKeyNoFields() {
    Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            REL_TYPE, "TEST",
            ENTITY_KEY, "readID");
    Map<String, Object> params = map(
            "host", HOST,
            "conf", Map.of(ALL_RESULTS_KEY, true, MAPPING_KEY, mapping));

    ExtendedTestUtil.assertFails(
            db,
            "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)",
            params,
            NO_FIELDS_ERROR_MSG);
}

// Calls apoc.vectordb.milvus.queryAndUpdate with a mapping that sets METADATA_KEY to "foo"
// and a config that contains no fields entry; rows are verified by assertMetadataFooResult.
@Test
public void queryAndUpdateMetadataKeyWithoutFieldsTest() {
    Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            REL_TYPE, "TEST",
            ENTITY_KEY, "readID",
            METADATA_KEY, "foo");
    Map<String, Object> params = map(
            "host", HOST,
            "conf", map(ALL_RESULTS_KEY, true, MAPPING_KEY, mapping));

    testResult(
            db,
            "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)",
            params,
            VectorDbTestUtil::assertMetadataFooResult);
}

// Creates two :Test nodes, then calls apoc.vectordb.milvus.queryAndUpdate with a mapping
// that has no METADATA_KEY and a config with no fields entry; the query is handed to
// ExtendedTestUtil.assertFails together with NO_FIELDS_ERROR_MSG.
@Test
public void queryAndUpdateWithNoMetadataKeyNoFields() {
    db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");

    Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            REL_TYPE, "TEST",
            ENTITY_KEY, "readID");
    Map<String, Object> embeddingConf = map(
            ALL_RESULTS_KEY, true,
            MAPPING_KEY, mapping);
    Map<String, Object> params = Util.map("host", HOST, "conf", embeddingConf);

    ExtendedTestUtil.assertFails(
            db,
            "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)",
            params,
            NO_FIELDS_ERROR_MSG);
}
}
58 changes: 58 additions & 0 deletions extended-it/src/test/java/apoc/vectordb/WeaviateTest.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package apoc.vectordb;

import apoc.ml.Prompt;
import apoc.util.ExtendedTestUtil;
import apoc.util.MapUtil;
import apoc.util.TestUtil;
import org.junit.AfterClass;
Expand Down Expand Up @@ -606,4 +607,61 @@ private static void assertQueryVectorsWithSystemDbStorage(String keyConfig, Stri

assertNodesCreated(db);
}

// Calls apoc.vectordb.weaviate.query with a mapping that sets METADATA_KEY to "foo" and a
// config that contains no fields entry; rows are verified by assertMetadataFooResult.
@Test
public void queryVectorsWithMetadataKeyNoFields() {
    Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            NODE_LABEL, "Test",
            ENTITY_KEY, "myId",
            METADATA_KEY, "foo");
    Map<String, Object> params = map(
            "host", HOST,
            "conf", map(
                    ALL_RESULTS_KEY, true,
                    MAPPING_KEY, mapping,
                    HEADERS_KEY, ADMIN_AUTHORIZATION));

    String cypher = "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
            " YIELD score, vector, id, metadata RETURN * ORDER BY id";
    testResult(db, cypher, params, VectorDbTestUtil::assertMetadataFooResult);
}

// Calls apoc.vectordb.weaviate.query with a mapping that has no METADATA_KEY and a config
// with no fields entry, and hands the query to ExtendedTestUtil.assertFails together with
// NO_FIELDS_ERROR_MSG.
@Test
public void queryVectorsWithNoMetadataKeyNoFields() {
    Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            NODE_LABEL, "Test",
            ENTITY_KEY, "myId");
    Map<String, Object> embeddingConf = map(
            ALL_RESULTS_KEY, true,
            MAPPING_KEY, mapping,
            HEADERS_KEY, ADMIN_AUTHORIZATION);
    Map<String, Object> params = map("host", HOST, "conf", embeddingConf);

    ExtendedTestUtil.assertFails(
            db,
            "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
                    " YIELD score, vector, id, metadata RETURN * ORDER BY id",
            params,
            NO_FIELDS_ERROR_MSG);
}

// Creates two :Test nodes, then calls apoc.vectordb.weaviate.queryAndUpdate with a mapping
// that sets METADATA_KEY to "foo" and a config that contains no fields entry; rows are
// verified by assertMetadataFooResult.
@Test
public void queryAndUpdateMetadataKeyWithoutFieldsTest() {
    db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");

    Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            NODE_LABEL, "Test",
            ENTITY_KEY, "myId",
            METADATA_KEY, "foo");
    Map<String, Object> params = map(
            "host", HOST,
            "conf", map(
                    ALL_RESULTS_KEY, true,
                    HEADERS_KEY, ADMIN_AUTHORIZATION,
                    MAPPING_KEY, mapping));

    String cypher = "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
            " YIELD score, vector, id, metadata, node RETURN * ORDER BY id";
    testResult(db, cypher, params, VectorDbTestUtil::assertMetadataFooResult);
}

// Creates two :Test nodes, then calls apoc.vectordb.weaviate.queryAndUpdate with a mapping
// that has no METADATA_KEY and a config with no fields entry; the query is handed to
// ExtendedTestUtil.assertFails together with NO_FIELDS_ERROR_MSG.
@Test
public void queryAndUpdateWithCreateNodeUsingExistingNodeFailWithNoMetadataKeyAndNoFields() {
    db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");

    Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            NODE_LABEL, "Test",
            ENTITY_KEY, "myId");
    Map<String, Object> params = map(
            "host", HOST,
            "conf", Map.of(
                    ALL_RESULTS_KEY, true,
                    HEADERS_KEY, ADMIN_AUTHORIZATION,
                    MAPPING_KEY, mapping));

    ExtendedTestUtil.assertFails(
            db,
            "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) YIELD score, vector, id, metadata, node RETURN * ORDER BY id",
            params,
            NO_FIELDS_ERROR_MSG);
}
}
8 changes: 3 additions & 5 deletions extended/src/main/java/apoc/vectordb/MilvusHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import java.util.Map;

import static apoc.util.MapUtil.map;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static apoc.vectordb.VectorDbUtil.addMetadataKeyToFields;
import static apoc.vectordb.VectorEmbeddingConfig.META_AS_SUBKEY_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.SCORE_KEY;

Expand Down Expand Up @@ -57,10 +57,8 @@ public VectorEmbeddingConfig fromQuery(Map<String, Object> config, ProcedureCall
private VectorEmbeddingConfig getVectorEmbeddingConfig(Map<String, Object> config, List<String> procFields, String collection, Map<String, Object> additionalBodies) {
config.putIfAbsent(META_AS_SUBKEY_KEY, false);

List listFields = (List) config.get(FIELDS_KEY);
if (listFields == null) {
throw new RuntimeException("You have to define `field` list of parameter to be returned");
}
List listFields = addMetadataKeyToFields(config);

if (procFields.contains("vector") && !listFields.contains("vector")) {
listFields.add("vector");
}
Expand Down
26 changes: 26 additions & 0 deletions extended/src/main/java/apoc/vectordb/VectorDbUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@

import apoc.ExtendedSystemPropertyKeys;
import apoc.SystemPropertyKeys;
import apoc.util.CollectionUtils;
import apoc.util.ExtendedMapUtils;
import apoc.util.Util;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.graphdb.Label;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Relationship;

import java.net.HttpURLConnection;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -20,9 +23,12 @@
import static apoc.ml.RestAPIConfig.ENDPOINT_KEY;
import static apoc.ml.RestAPIConfig.METHOD_KEY;
import static apoc.util.SystemDbUtil.withSystemDb;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY;
import static apoc.vectordb.VectorMappingConfig.MODE_KEY;
import static apoc.vectordb.VectorMappingConfig.MappingMode.READ_ONLY;
import static apoc.vectordb.VectorMappingConfig.NO_FIELDS_ERROR_MSG;

public class VectorDbUtil {

Expand Down Expand Up @@ -136,6 +142,26 @@ public static void methodAndPayloadNull(Map<String, Object> config) {
config.put(BODY_KEY, null);
}

/**
 * Resolves the list of fields to be requested, falling back to the mapping's metadata key.
 *
 * <p>The list stored under {@code FIELDS_KEY} in {@code config} is used when it is non-empty.
 * When it is absent, {@code null} or empty, the {@code METADATA_KEY} entry of the
 * {@code MAPPING_KEY} sub-map is added to the list as the only field.
 *
 * @param config the configuration map; may hold a fields list and a mapping sub-map
 * @return the fields list; when {@code config} holds a non-null list, that same instance is
 *         returned (and, if it was empty, the metadata key has been appended to it)
 * @throws RuntimeException with {@code NO_FIELDS_ERROR_MSG} when there are no fields and the
 *         metadata key is missing or empty
 */
public static List addMetadataKeyToFields(Map<String, Object> config) {
    List listFields = (List) config.get(FIELDS_KEY);
    if (listFields == null) {
        // Covers both an absent key and a key explicitly mapped to null:
        // Map.getOrDefault only falls back for an absent key, so a null value
        // used to leak through and fail on listFields.add(...) below.
        listFields = new ArrayList<>();
    }

    Map<String, Object> mapping = (Map<String, Object>) config.get(MAPPING_KEY);

    String metadataKey = mapping == null
            ? null
            : (String) mapping.get(METADATA_KEY);

    if (listFields.isEmpty()) {
        if (StringUtils.isEmpty(metadataKey)) {
            throw new RuntimeException(NO_FIELDS_ERROR_MSG);
        }
        listFields.add(metadataKey);
    }

    return listFields;
}

/**
* If the vectorDb is WEAVIATE and endpoint doesn't end with `/vN`, where N is a number,
* then add `/v1` to the endpoint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ enum MappingMode {
public static final String EMBEDDING_KEY = "embeddingKey";
public static final String SIMILARITY_KEY = "similarity";
public static final String MODE_KEY = "mode";
public static final String NO_FIELDS_ERROR_MSG = "You need to define either the 'field' list parameter, or the 'metadataKey' string parameter within the `embeddingConfig` parameter";

private final String metadataKey;
private final String entityKey;
Expand Down
8 changes: 3 additions & 5 deletions extended/src/main/java/apoc/vectordb/WeaviateHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import static apoc.ml.RestAPIConfig.BODY_KEY;
import static apoc.ml.RestAPIConfig.METHOD_KEY;
import static apoc.util.MapUtil.map;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static apoc.vectordb.VectorDbUtil.addMetadataKeyToFields;
import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY;

Expand Down Expand Up @@ -47,10 +47,8 @@ public VectorEmbeddingConfig fromQuery(Map<String, Object> config, ProcedureCall
config.putIfAbsent(METHOD_KEY, "POST");
VectorEmbeddingConfig vectorEmbeddingConfig = getVectorEmbeddingConfig(config);

List list = (List) config.get(FIELDS_KEY);
if (list == null) {
throw new RuntimeException("You have to define `field` list of parameter to be returned");
}
List list = addMetadataKeyToFields(config);

Object fieldList = String.join("\n", list);

filter = filter == null
Expand Down
11 changes: 11 additions & 0 deletions extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.util.Map;

import static apoc.vectordb.VectorEmbeddingConfig.DEFAULT_METADATA;
import static apoc.util.TestUtil.testResult;
import static apoc.util.Util.map;
import static org.junit.Assert.assertEquals;
Expand Down Expand Up @@ -114,4 +115,14 @@ public static String ragSetup(GraphDatabaseService db) {
db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})");
return openAIKey;
}

/**
 * Asserts that the result yields exactly two rows and that the map found under
 * {@code DEFAULT_METADATA} in each row has {@code "foo"} equal to {@code "one"} and then
 * {@code "two"}, in that order.
 *
 * @param r the query result to consume
 */
public static void assertMetadataFooResult(Result r) {
    for (String expectedFoo : new String[] {"one", "two"}) {
        Map<String, Object> currentRow = r.next();
        Map<String, Object> meta = (Map<String, Object>) currentRow.get(DEFAULT_METADATA);
        assertEquals(expectedFoo, meta.get("foo"));
    }
    // no further rows are expected after the two checked above
    assertFalse(r.hasNext());
}
}

0 comments on commit 5748e1c

Please sign in to comment.