From 8d37efdb509a174bf2b834fc98a925ce31b87f64 Mon Sep 17 00:00:00 2001 From: Anush Date: Thu, 21 Dec 2023 13:24:20 +0530 Subject: [PATCH] v1.7.0 (#13) * feat: ShardKeyFactory * feat: ShardKeySelectorFactory * feat: shard key ops * feat: shard key ops overloads * feat: discover ops * refactor: upade discoverBatchAsync * feat: targetVectorFactory, VectorFactory * test: discover searches * refactor: VectorsFactory to use vector() * Apply suggestions from code review Co-authored-by: Russ Cam * chore: review changes * docs: Update README.md version --------- Co-authored-by: Anush Co-authored-by: Russ Cam --- README.md | 8 +- build.gradle | 2 +- gradle.properties | 2 +- .../java/io/qdrant/client/QdrantClient.java | 164 ++++++++++++++++++ .../io/qdrant/client/ShardKeyFactory.java | 31 ++++ .../client/ShardKeySelectorFactory.java | 54 ++++++ .../io/qdrant/client/TargetVectorFactory.java | 32 ++++ .../java/io/qdrant/client/VectorFactory.java | 52 ++++++ .../java/io/qdrant/client/VectorsFactory.java | 19 +- .../java/io/qdrant/client/PointsTest.java | 54 +++++- 10 files changed, 393 insertions(+), 25 deletions(-) create mode 100644 src/main/java/io/qdrant/client/ShardKeyFactory.java create mode 100644 src/main/java/io/qdrant/client/ShardKeySelectorFactory.java create mode 100644 src/main/java/io/qdrant/client/TargetVectorFactory.java create mode 100644 src/main/java/io/qdrant/client/VectorFactory.java diff --git a/README.md b/README.md index 2c8757c..51c562b 100644 --- a/README.md +++ b/README.md @@ -23,8 +23,6 @@ Java client library with handy utility methods and overloads for interfacing wit ## 📥 Installation -> Not yet published. - > [!IMPORTANT] > Requires Java 8 or above. @@ -36,20 +34,20 @@ To install the library, add the following lines to your build config file. io.qdrant client - 1.7-SNAPSHOT + 1.7.0 ``` #### Scala SBT ```sbt -libraryDependencies += "io.qdrant" % "client" % "1.7-SNAPSHOT" +libraryDependencies += "io.qdrant" % "client" % "1.7.0" ``` #### Gradle ```gradle -implementation 'io.qdrant:client:1.7-SNAPSHOT' +implementation 'io.qdrant:client:1.7.0' ``` ## 📖 Documentation diff --git a/build.gradle b/build.gradle index 722d6c0..189e243 100644 --- a/build.gradle +++ b/build.gradle @@ -228,4 +228,4 @@ publishing { repositories { mavenLocal() } -} \ No newline at end of file +} diff --git a/gradle.properties b/gradle.properties index 44e0858..a609339 100644 --- a/gradle.properties +++ b/gradle.properties @@ -5,4 +5,4 @@ qdrantProtosVersion=v1.7.0 qdrantVersion=v1.7.0 # The version of the client to generate -packageVersion=1.7-SNAPSHOT \ No newline at end of file +packageVersion=1.7.0 \ No newline at end of file diff --git a/src/main/java/io/qdrant/client/QdrantClient.java b/src/main/java/io/qdrant/client/QdrantClient.java index 680f29e..a8dcae4 100644 --- a/src/main/java/io/qdrant/client/QdrantClient.java +++ b/src/main/java/io/qdrant/client/QdrantClient.java @@ -10,6 +10,7 @@ import io.qdrant.client.grpc.CollectionsGrpc; import io.qdrant.client.grpc.PointsGrpc; import io.qdrant.client.grpc.SnapshotsGrpc; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,8 +29,16 @@ import static io.qdrant.client.grpc.Collections.CollectionOperationResponse; import static io.qdrant.client.grpc.Collections.CreateAlias; import static io.qdrant.client.grpc.Collections.CreateCollection; +import static io.qdrant.client.grpc.Collections.CreateShardKeyRequest; +import static io.qdrant.client.grpc.Collections.CreateShardKeyResponse; import static io.qdrant.client.grpc.Collections.DeleteAlias; import static io.qdrant.client.grpc.Collections.DeleteCollection; +import static io.qdrant.client.grpc.Collections.DeleteShardKeyRequest; +import static io.qdrant.client.grpc.Collections.DeleteShardKeyResponse; +import static io.qdrant.client.grpc.Points.DiscoverBatchPoints; +import static io.qdrant.client.grpc.Points.DiscoverBatchResponse; +import static io.qdrant.client.grpc.Points.DiscoverPoints; +import static io.qdrant.client.grpc.Points.DiscoverResponse; import static io.qdrant.client.grpc.Collections.GetCollectionInfoRequest; import static io.qdrant.client.grpc.Collections.GetCollectionInfoResponse; import static io.qdrant.client.grpc.Collections.ListAliasesRequest; @@ -40,6 +49,7 @@ import static io.qdrant.client.grpc.Collections.PayloadIndexParams; import static io.qdrant.client.grpc.Collections.PayloadSchemaType; import static io.qdrant.client.grpc.Collections.RenameAlias; +import static io.qdrant.client.grpc.Collections.ShardKey; import static io.qdrant.client.grpc.Collections.UpdateCollection; import static io.qdrant.client.grpc.Collections.VectorParams; import static io.qdrant.client.grpc.Collections.VectorParamsMap; @@ -665,6 +675,78 @@ public ListenableFuture> listAliasesAsync(@Nullable Durat //endregion + //region ShardKey Management + + /** + * Creates a shard key for a collection. + * + * @param createShardKey The request object for the operation. + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture createShardKeyAsync(CreateShardKeyRequest createShardKey) { + return createShardKeyAsync(createShardKey, null); + } + + /** + * Creates a shard key for a collection. + * + * @param createShardKey The request object for the operation. + * @param timeout The timeout for the call. + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture createShardKeyAsync(CreateShardKeyRequest createShardKey, @Nullable Duration timeout) { + String collectionName = createShardKey.getCollectionName(); + Preconditions.checkArgument(!collectionName.isEmpty(), "Collection name must not be empty"); + ShardKey shardKey = createShardKey.getRequest().getShardKey(); + logger.debug("Create shard key '{}' for '{}'", shardKey, collectionName); + + ListenableFuture future = getCollections(timeout).createShardKey(createShardKey); + addLogFailureCallback(future, "Create shard key"); + return Futures.transform(future, response -> { + if (!response.getResult()) { + logger.error("Shard key could not be created for '{}'", collectionName); + throw new QdrantException("Shard key " + shardKey + " could not be created for " + collectionName); + } + return response; + }, MoreExecutors.directExecutor()); + } + + /** + * Deletes a shard key for a collection. + * + * @param deleteShardKey The request object for the operation. + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture deleteShardKeyAsync(DeleteShardKeyRequest deleteShardKey) { + return deleteShardKeyAsync(deleteShardKey, null); + } + + /** + * Deletes a shard key for a collection. + * + * @param deleteShardKey The request object for the operation. + * @param timeout The timeout for the call. + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture deleteShardKeyAsync(DeleteShardKeyRequest deleteShardKey, @Nullable Duration timeout) { + String collectionName = deleteShardKey.getCollectionName(); + Preconditions.checkArgument(!collectionName.isEmpty(), "Collection name must not be empty"); + ShardKey shardKey = deleteShardKey.getRequest().getShardKey(); + logger.debug("Delete shard key '{}' for '{}'", shardKey, collectionName); + + ListenableFuture future = getCollections(timeout).deleteShardKey(deleteShardKey); + addLogFailureCallback(future, "Delete shard key"); + return Futures.transform(future, response -> { + if (!response.getResult()) { + logger.error("Shard key '{}' could not be deleted for '{}'", shardKey, collectionName); + throw new QdrantException("Shard key " + shardKey + " could not be created for " + collectionName); + } + return response; + }, MoreExecutors.directExecutor()); + } + + //endregion + //region Point Management /** @@ -2153,6 +2235,88 @@ public ListenableFuture> recommendGroupsAsync(RecommendPointGro MoreExecutors.directExecutor()); } + /** + * Use the context and a target to find the most similar points to the target. + * Constraints by the context. + * + * @param request The discover points request + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> discoverAsync(DiscoverPoints request) { + return discoverAsync(request, null); + } + + /** + * Use the context and a target to find the most similar points to the target. + * Constraints by the context. + * + * @param request The discover points request + * @param timeout The timeout for the call. + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> discoverAsync(DiscoverPoints request, @Nullable Duration timeout) { + String collectionName = request.getCollectionName(); + Preconditions.checkArgument(!collectionName.isEmpty(), "Collection name must not be empty"); + logger.debug("Discover on '{}'", collectionName); + ListenableFuture future = getPoints(timeout).discover(request); + addLogFailureCallback(future, "Discover"); + return Futures.transform( + future, + response -> response.getResultList(), + MoreExecutors.directExecutor()); + } + + /** + * Use the context and a target to find the most similar points to the target in + * a batch. + * Constrained by the context. + * + * @param collectionName The name of the collection + * @param discoverSearches The list for discover point searches + * @param readConsistency Options for specifying read consistency guarantees + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> discoverBatchAsync( + String collectionName, + List discoverSearches, + @Nullable ReadConsistency readConsistency) { + return discoverBatchAsync(collectionName, discoverSearches, readConsistency, null); + } + + /** + * Use the context and a target to find the most similar points to the target in + * a batch. + * Constrained by the context. + * + * @param collectionName The name of the collection + * @param discoverSearches The list for discover point searches + * @param readConsistency Options for specifying read consistency guarantees + * @param timeout The timeout for the call. + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> discoverBatchAsync( + String collectionName, + List discoverSearches, + @Nullable ReadConsistency readConsistency, + @Nullable Duration timeout) { + Preconditions.checkArgument(!collectionName.isEmpty(), "Collection name must not be empty"); + + DiscoverBatchPoints.Builder requestBuilder = DiscoverBatchPoints.newBuilder() + .setCollectionName(collectionName) + .addAllDiscoverPoints(discoverSearches); + + if (readConsistency != null) { + requestBuilder.setReadConsistency(readConsistency); + } + logger.debug("Discover batch on '{}'", collectionName); + ListenableFuture future = getPoints(timeout).discoverBatch(requestBuilder.build()); + addLogFailureCallback(future, "Discover batch"); + return Futures.transform( + future, + response -> response.getResultList(), + MoreExecutors.directExecutor()); + } + /** * Count the points in a collection. The count is exact * diff --git a/src/main/java/io/qdrant/client/ShardKeyFactory.java b/src/main/java/io/qdrant/client/ShardKeyFactory.java new file mode 100644 index 0000000..262a4ed --- /dev/null +++ b/src/main/java/io/qdrant/client/ShardKeyFactory.java @@ -0,0 +1,31 @@ +package io.qdrant.client; + +import io.qdrant.client.grpc.Collections.ShardKey; + +/** + * Convenience methods for constructing {@link ShardKey} + */ +public final class ShardKeyFactory { + private ShardKeyFactory() { + } + + /** + * Creates a {@link ShardKey} based on a keyword. + * + * @param keyword The keyword to create the shard key from + * @return The {@link ShardKey} object + */ + public static ShardKey shardKey(String keyword) { + return ShardKey.newBuilder().setKeyword(keyword).build(); + } + + /** + * Creates a {@link ShardKey} based on a number. + * + * @param number The number to create the shard key from + * @return The {@link ShardKey} object + */ + public static ShardKey shardKey(long number) { + return ShardKey.newBuilder().setNumber(number).build(); + } +} diff --git a/src/main/java/io/qdrant/client/ShardKeySelectorFactory.java b/src/main/java/io/qdrant/client/ShardKeySelectorFactory.java new file mode 100644 index 0000000..a171f11 --- /dev/null +++ b/src/main/java/io/qdrant/client/ShardKeySelectorFactory.java @@ -0,0 +1,54 @@ +package io.qdrant.client; + +import io.qdrant.client.grpc.Collections.ShardKey; +import io.qdrant.client.grpc.Points.ShardKeySelector; + +import static io.qdrant.client.ShardKeyFactory.shardKey; + +import java.util.Arrays; + +/** + * Convenience methods for constructing {@link ShardKeySelector} + */ +public class ShardKeySelectorFactory { + private ShardKeySelectorFactory() { + } + + /** + * Creates a {@link ShardKeySelector} with the given shard keys. + * + * @param shardKeys The shard keys to include in the selector. + * @return The created {@link ShardKeySelector} object. + */ + public static ShardKeySelector shardKeySelector(ShardKey... shardKeys) { + return ShardKeySelector.newBuilder().addAllShardKeys(Arrays.asList(shardKeys)).build(); + } + + /** + * Creates a {@link ShardKeySelector} with the given shard key keywords. + * + * @param keywords The shard key keywords to include in the selector. + * @return The created {@link ShardKeySelector} object. + */ + public static ShardKeySelector shardKeySelector(String... keywords) { + ShardKeySelector.Builder builder = ShardKeySelector.newBuilder(); + for (String keyword : keywords) { + builder.addShardKeys(shardKey(keyword)); + } + return builder.build(); + } + + /** + * Creates a {@link ShardKeySelector} with the given shard key numbers. + * + * @param numbers The shard key numbers to include in the selector. + * @return The created {@link ShardKeySelector} object. + */ + public static ShardKeySelector shardKeySelector(long... numbers) { + ShardKeySelector.Builder builder = ShardKeySelector.newBuilder(); + for (long number : numbers) { + builder.addShardKeys(shardKey(number)); + } + return builder.build(); + } +} diff --git a/src/main/java/io/qdrant/client/TargetVectorFactory.java b/src/main/java/io/qdrant/client/TargetVectorFactory.java new file mode 100644 index 0000000..f62c32b --- /dev/null +++ b/src/main/java/io/qdrant/client/TargetVectorFactory.java @@ -0,0 +1,32 @@ +package io.qdrant.client; + +import io.qdrant.client.grpc.Points.PointId; +import io.qdrant.client.grpc.Points.TargetVector; +import io.qdrant.client.grpc.Points.Vector; +import io.qdrant.client.grpc.Points.VectorExample; + +/** + * Convenience methods for constructing {@link TargetVector} + */ +public class TargetVectorFactory { + private TargetVectorFactory() { + } + + /** + * Creates a TargetVector from a point ID + * @param id The point ID to use + * @return A new instance of {@link TargetVector} + */ + public static TargetVector targetVector(PointId id) { + return TargetVector.newBuilder().setSingle(VectorExample.newBuilder().setId(id)).build(); + } + + /** + * Creates a TargetVector from a Vector + * @param vector The Vector value to use + * @return A new instance of {@link TargetVector} + */ + public static TargetVector targetVector(Vector vector) { + return TargetVector.newBuilder().setSingle(VectorExample.newBuilder().setVector(vector)).build(); + } +} diff --git a/src/main/java/io/qdrant/client/VectorFactory.java b/src/main/java/io/qdrant/client/VectorFactory.java new file mode 100644 index 0000000..c82ce12 --- /dev/null +++ b/src/main/java/io/qdrant/client/VectorFactory.java @@ -0,0 +1,52 @@ +package io.qdrant.client; + +import com.google.common.primitives.Floats; + +import java.util.List; + +import static io.qdrant.client.grpc.Points.SparseIndices; +import static io.qdrant.client.grpc.Points.Vector; + +/** + * Convenience methods for constructing {@link Vector} + */ +public final class VectorFactory { + private VectorFactory() { + } + + /** + * Creates a vector from a list of floats + * + * @param values A map of vector names to values + * @return A new instance of {@link Vector} + */ + public static Vector vector(List values) { + return Vector.newBuilder().addAllData(values).build(); + } + + /** + * Creates a vector from a list of floats + * + * @param values A list of values + * @return A new instance of {@link Vector} + */ + public static Vector vector(float... values) { + return Vector.newBuilder() + .addAllData(Floats.asList(values)) + .build(); + } + + /** + * Creates a sparse vector from a list of floats and integers as indices + * + * @param vector The list of floats representing the vector. + * @param indices The list of integers representing the indices. + * @return A new instance of {@link Vector} + */ + public static Vector vector(List vector, List indices) { + return Vector.newBuilder() + .addAllData(vector) + .setIndices(SparseIndices.newBuilder().addAllData(indices).build()) + .build(); + } +} diff --git a/src/main/java/io/qdrant/client/VectorsFactory.java b/src/main/java/io/qdrant/client/VectorsFactory.java index 0e2c348..6fe9864 100644 --- a/src/main/java/io/qdrant/client/VectorsFactory.java +++ b/src/main/java/io/qdrant/client/VectorsFactory.java @@ -1,13 +1,12 @@ package io.qdrant.client; import com.google.common.collect.Maps; -import com.google.common.primitives.Floats; import java.util.List; import java.util.Map; +import static io.qdrant.client.VectorFactory.vector; import static io.qdrant.client.grpc.Points.NamedVectors; -import static io.qdrant.client.grpc.Points.Vector; import static io.qdrant.client.grpc.Points.Vectors; /** @@ -25,9 +24,7 @@ private VectorsFactory() { public static Vectors namedVectors(Map> values) { return Vectors.newBuilder() .setVectors(NamedVectors.newBuilder() - .putAllVectors(Maps.transformValues(values, v -> Vector.newBuilder() - .addAllData(v) - .build())) + .putAllVectors(Maps.transformValues(values, v -> vector(v))) ) .build(); } @@ -37,11 +34,9 @@ public static Vectors namedVectors(Map> values) { * @param values A list of values * @return a new instance of {@link Vectors} */ - public static Vectors vector(List values) { + public static Vectors vectors(List values) { return Vectors.newBuilder() - .setVector(Vector.newBuilder() - .addAllData(values) - .build()) + .setVector(vector(values)) .build(); } @@ -50,11 +45,9 @@ public static Vectors vector(List values) { * @param values A list of values * @return a new instance of {@link Vectors} */ - public static Vectors vector(float... values) { + public static Vectors vectors(float... values) { return Vectors.newBuilder() - .setVector(Vector.newBuilder() - .addAllData(Floats.asList(values)) - .build()) + .setVector(vector(values)) .build(); } } diff --git a/src/test/java/io/qdrant/client/PointsTest.java b/src/test/java/io/qdrant/client/PointsTest.java index 82c30c5..75cb540 100644 --- a/src/test/java/io/qdrant/client/PointsTest.java +++ b/src/test/java/io/qdrant/client/PointsTest.java @@ -13,6 +13,7 @@ import org.testcontainers.shaded.com.google.common.collect.ImmutableMap; import org.testcontainers.shaded.com.google.common.collect.ImmutableSet; import io.qdrant.client.container.QdrantContainer; +import io.qdrant.client.grpc.Points.DiscoverPoints; import java.util.List; import java.util.concurrent.ExecutionException; @@ -45,8 +46,9 @@ import static io.qdrant.client.ConditionFactory.hasId; import static io.qdrant.client.ConditionFactory.matchKeyword; import static io.qdrant.client.PointIdFactory.id; +import static io.qdrant.client.TargetVectorFactory.targetVector; import static io.qdrant.client.ValueFactory.value; -import static io.qdrant.client.VectorsFactory.vector; +import static io.qdrant.client.VectorFactory.vector; @Testcontainers class PointsTest { @@ -307,7 +309,7 @@ public void searchGroups() throws ExecutionException, InterruptedException { ImmutableList.of( PointStruct.newBuilder() .setId(id(10)) - .setVectors(VectorsFactory.vector(30f, 31f)) + .setVectors(VectorsFactory.vectors(30f, 31f)) .putAllPayload(ImmutableMap.of("foo", value("hello"))) .build() ) @@ -404,7 +406,7 @@ public void recommendGroups() throws ExecutionException, InterruptedException { ImmutableList.of( PointStruct.newBuilder() .setId(id(10)) - .setVectors(VectorsFactory.vector(30f, 31f)) + .setVectors(VectorsFactory.vectors(30f, 31f)) .putAllPayload(ImmutableMap.of("foo", value("hello"))) .build() ) @@ -423,6 +425,48 @@ public void recommendGroups() throws ExecutionException, InterruptedException { assertEquals(2, groups.get(0).getHitsCount()); } + @Test + public void discover() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + List points = client.discoverAsync(DiscoverPoints.newBuilder() + .setCollectionName(testName) + .setTarget(targetVector(vector(ImmutableList.of(10.4f, 11.4f)))) + .setLimit(1) + .build()).get(); + + assertEquals(1, points.size()); + assertEquals(id(9), points.get(0).getId()); + } + + @Test + public void discoverBatch() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + List batchResults = client.discoverBatchAsync( + testName, + ImmutableList.of( + DiscoverPoints.newBuilder() + .setCollectionName(testName) + .setTarget(targetVector(vector(ImmutableList.of(10.4f, 11.4f)))) + .setLimit(1) + .build(), + DiscoverPoints.newBuilder() + .setCollectionName(testName) + .setTarget(targetVector(vector(ImmutableList.of(3.5f, 4.5f)))) + .setLimit(1) + .build()), + null).get(); + + assertEquals(2, batchResults.size()); + BatchResult result = batchResults.get(0); + assertEquals(1, result.getResultCount()); + assertEquals(id(9), result.getResult(0).getId()); + result = batchResults.get(1); + assertEquals(1, result.getResultCount()); + assertEquals(id(8), result.getResult(0).getId()); + } + @Test public void count() throws ExecutionException, InterruptedException { createAndSeedCollection(testName); @@ -512,7 +556,7 @@ private void createAndSeedCollection(String collectionName) throws ExecutionExce UpdateResult result = client.upsertAsync(collectionName, ImmutableList.of( PointStruct.newBuilder() .setId(id(8)) - .setVectors(VectorsFactory.vector(ImmutableList.of(3.5f, 4.5f))) + .setVectors(VectorsFactory.vectors(ImmutableList.of(3.5f, 4.5f))) .putAllPayload(ImmutableMap.of( "foo", value("hello"), "bar", value(1) @@ -520,7 +564,7 @@ private void createAndSeedCollection(String collectionName) throws ExecutionExce .build(), PointStruct.newBuilder() .setId(id(9)) - .setVectors(VectorsFactory.vector(ImmutableList.of(10.5f, 11.5f))) + .setVectors(VectorsFactory.vectors(ImmutableList.of(10.5f, 11.5f))) .putAllPayload(ImmutableMap.of( "foo", value("goodbye"), "bar", value(2)