Skip to content

Commit

Permalink
Fixes #4232: The apoc.vectordb.configure(WEAVIATE', ..) procedure sho…
Browse files Browse the repository at this point in the history
…uld append /v1 to url (#4248)
  • Loading branch information
vga91 authored Dec 4, 2024
1 parent 8223e27 commit b7697f4
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 36 deletions.
99 changes: 65 additions & 34 deletions extended-it/src/test/java/apoc/vectordb/WeaviateTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import apoc.util.MapUtil;
import apoc.util.TestUtil;
import org.junit.AfterClass;
import org.junit.Assume;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
Expand Down Expand Up @@ -507,41 +506,21 @@ MAPPING_KEY, map(REL_TYPE, "TEST",
public void queryVectorsWithSystemDbStorage() {
String keyConfig = "weaviate-config-foo";
String baseUrl = "http://" + HOST + "/v1";
Map<String, String> mapping = map(EMBEDDING_KEY, "vect",
NODE_LABEL, "Test",
ENTITY_KEY, "myId",
METADATA_KEY, "foo");
sysDb.executeTransactionally("CALL apoc.vectordb.configure($vectorName, $keyConfig, $databaseName, $conf)",
map("vectorName", WEAVIATE.toString(),
"keyConfig", keyConfig,
"databaseName", DEFAULT_DATABASE_NAME,
"conf", map(
"host", baseUrl,
"credentials", ADMIN_KEY,
"mapping", mapping
)
)
);

db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");

testResult(db, "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)",
map("host", keyConfig,
"conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true)
),
r -> {
Map<String, Object> row = r.next();
assertBerlinResult(row, ID_1, NODE);
assertNotNull(row.get("score"));
assertNotNull(row.get("vector"));
assertQueryVectorsWithSystemDbStorage(keyConfig, baseUrl, false);
}

row = r.next();
assertLondonResult(row, ID_2, NODE);
assertNotNull(row.get("score"));
assertNotNull(row.get("vector"));
});
@Test
public void queryVectorsWithSystemDbStorageWithUrlWithoutVersion() {
String keyConfig = "weaviate-config-foo";
String baseUrl = "http://" + HOST;
assertQueryVectorsWithSystemDbStorage(keyConfig, baseUrl, false);
}

assertNodesCreated(db);
@Test
public void queryVectorsWithSystemDbStorageWithUrlV3Version() {
String keyConfig = "weaviate-config-foo";
String baseUrl = "http://" + HOST + "/v3";
assertQueryVectorsWithSystemDbStorage(keyConfig, baseUrl, true);
}

@Test
Expand Down Expand Up @@ -575,4 +554,56 @@ WITH collect(node) as paths
),
VectorDbTestUtil::assertRagWithVectors);
}

private static void assertQueryVectorsWithSystemDbStorage(String keyConfig, String baseUrl, boolean fails) {
Map<String, String> mapping = map(EMBEDDING_KEY, "vect",
NODE_LABEL, "Test",
ENTITY_KEY, "myId",
METADATA_KEY, "foo");
sysDb.executeTransactionally("CALL apoc.vectordb.configure($vectorName, $keyConfig, $databaseName, $conf)",
map("vectorName", WEAVIATE.toString(),
"keyConfig", keyConfig,
"databaseName", DEFAULT_DATABASE_NAME,
"conf", map(
"host", baseUrl,
"credentials", ADMIN_KEY,
"mapping", mapping
)
)
);

db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");

String query = "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";
Map<String, Object> params = map("host", keyConfig,
"conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true)
);

if (fails) {
assertFails(
db,
query,
params,
"Caused by: java.io.FileNotFoundException: http://127.0.0.1:" + HOST.split(":")[1] + "/v3/graphql"
);
return;
}


testResult(db, query,
params,
r -> {
Map<String, Object> row = r.next();
assertBerlinResult(row, ID_1, NODE);
assertNotNull(row.get("score"));
assertNotNull(row.get("vector"));

row = r.next();
assertLondonResult(row, ID_2, NODE);
assertNotNull(row.get("score"));
assertNotNull(row.get("vector"));
});

assertNodesCreated(db);
}
}
6 changes: 4 additions & 2 deletions extended/src/main/java/apoc/vectordb/VectorDb.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
import static apoc.util.ExtendedUtil.setProperties;
import static apoc.util.JsonUtil.OBJECT_MAPPER;
import static apoc.util.SystemDbUtil.withSystemDb;
import static apoc.vectordb.VectorDbUtil.*;
import static apoc.vectordb.VectorDbUtil.EmbeddingResult;
import static apoc.vectordb.VectorDbUtil.appendVersionUrlIfNeeded;
import static apoc.vectordb.VectorDbUtil.getEndpoint;
import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY;

Expand Down Expand Up @@ -259,7 +261,7 @@ public void vectordb(
Node node = Util.mergeNode(transaction, label, null, Pair.of(SystemPropertyKeys.name.name(), configKey));

Map mapping = (Map) config.get("mapping");
String host = (String) config.get("host");
String host = appendVersionUrlIfNeeded(type, (String) config.get("host"));
Object credentials = config.get("credentials");

if (host != null) {
Expand Down
14 changes: 14 additions & 0 deletions extended/src/main/java/apoc/vectordb/VectorDbUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,18 @@ public static void methodAndPayloadNull(Map<String, Object> config) {
config.put(METHOD_KEY, null);
config.put(BODY_KEY, null);
}

/**
* If the vectorDb is WEAVIATE and endpoint doesn't end with `/vN`, where N is a number,
* then add `/v1` to the endpoint
*/
public static String appendVersionUrlIfNeeded(VectorDbHandler.Type type, String host) {
if (VectorDbHandler.Type.WEAVIATE == type) {
String regex = ".*(/v\\d+)$";
if (!host.matches(regex)) {
host = host + "/v1";
}
}
return host;
}
}

0 comments on commit b7697f4

Please sign in to comment.